mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)
support huggingface + model select based on lang.
@@ -122,8 +122,7 @@ https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-
 [x] Subtitle .ass output
-[ ] Automatic align model selection based on language detection
+[x] Automatic align model selection based on language detection
 [ ] Incorporating word-level speaker diarization
@@ -17,8 +17,19 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st
 
 if TYPE_CHECKING:
     from .model import Whisper
 
-hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
-asian_languages = ["ja"]
+LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
+
+DEFAULT_ALIGN_MODELS_TORCH = {
+    "en": "WAV2VEC2_ASR_BASE_960H",
+    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
+    "de": "VOXPOPULI_ASR_BASE_10K_DE",
+    "es": "VOXPOPULI_ASR_BASE_10K_ES",
+    "it": "VOXPOPULI_ASR_BASE_10K_IT",
+}
+
+DEFAULT_ALIGN_MODELS_HF = {
+    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
+}
 
 def transcribe(
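Note: the two new dicts split the defaults by backend, torchaudio pipeline bundles for English and the VoxPopuli languages, and a HuggingFace checkpoint for Japanese. A minimal sketch of how a language code resolves to a model name, mirroring the lookup inside load_align_model further down (dicts abridged for the sketch):

DEFAULT_ALIGN_MODELS_TORCH = {"en": "WAV2VEC2_ASR_BASE_960H", "fr": "VOXPOPULI_ASR_BASE_10K_FR"}
DEFAULT_ALIGN_MODELS_HF = {"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"}

def default_align_model(language_code):
    # torchaudio bundles take priority; HF checkpoints cover the rest
    if language_code in DEFAULT_ALIGN_MODELS_TORCH:
        return DEFAULT_ALIGN_MODELS_TORCH[language_code]
    if language_code in DEFAULT_ALIGN_MODELS_HF:
        return DEFAULT_ALIGN_MODELS_HF[language_code]
    raise ValueError(f"No default align-model for language: {language_code}")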
@@ -255,7 +266,7 @@ def align(
     transcript: Iterator[dict],
     language: str,
     model: torch.nn.Module,
-    model_dictionary: dict,
+    align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
     extend_duration: float = 0.0,
@@ -272,6 +283,10 @@ def align(
 
     MAX_DURATION = audio.shape[1] / SAMPLE_RATE
 
+    model_dictionary = align_model_metadata['dictionary']
+    model_lang = align_model_metadata['language']
+    model_type = align_model_metadata['type']
+
     prev_t2 = 0
     word_segments_list = []
    for idx, segment in enumerate(transcript):
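Note: the metadata dict unpacked here is the one built by load_align_model further down; its shape, with illustrative values, looks like:

align_model_metadata = {
    "language": "en",                 # language the aligner was loaded for
    "dictionary": {"a": 4, "b": 17},  # char -> token id (illustrative values)
    "type": "torchaudio",             # "torchaudio" or "huggingface"
}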
@@ -285,14 +300,16 @@ def align(
 
         waveform_segment = audio[:, f1:f2]
         with torch.inference_mode():
-            if language not in asian_languages:
+            if model_type == "torchaudio":
                 emissions, _ = model(waveform_segment.to(device))
-            else:
+            elif model_type == "huggingface":
                 emissions = model(waveform_segment.to(device)).logits
+            else:
+                raise NotImplementedError(f"Align model of type {model_type} not supported.")
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
         transcription = segment['text'].strip()
-        if language not in asian_languages:
+        if language not in LANGUAGES_WITHOUT_SPACES:
             t_words = transcription.split(' ')
         else:
             t_words = [c for c in transcription]
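Note: the LANGUAGES_WITHOUT_SPACES check decides the alignment unit, words where spaces delimit them, characters for Japanese and Chinese. A runnable sketch of that tokenization step in isolation (the helper name is hypothetical):

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

def tokenize_for_alignment(text, language):
    # Word-level units for languages that use spaces,
    # character-level units otherwise.
    if language not in LANGUAGES_WITHOUT_SPACES:
        return text.split(' ')
    return [c for c in text]

print(tokenize_for_alignment("hello world", "en"))  # ['hello', 'world']
print(tokenize_for_alignment("こんにちは", "ja"))   # ['こ', 'ん', 'に', 'ち', 'は']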
@@ -359,6 +376,41 @@ def align(
 
     return {"segments": transcript, "word_segments": word_segments_list}
 
+def load_align_model(language_code, device, model_name=None):
+    if model_name is None:
+        # use default model
+        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
+            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
+        elif language_code in DEFAULT_ALIGN_MODELS_HF:
+            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
+        else:
+            print(f"There is no default alignment model set for this language ({language_code}).\
+                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+            raise ValueError(f"No default align-model for language: {language_code}")
+
+    if model_name in torchaudio.pipelines.__all__:
+        pipeline_type = "torchaudio"
+        bundle = torchaudio.pipelines.__dict__[model_name]
+        align_model = bundle.get_model().to(device)
+        labels = bundle.get_labels()
+        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    else:
+        try:
+            processor = AutoProcessor.from_pretrained(model_name)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        except Exception as e:
+            print(e)
+            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
+            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
+        pipeline_type = "huggingface"
+        align_model = align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}
+
+    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
+
+    return align_model, align_metadata
+
 def cli():
     from . import available_models
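Note: a minimal end-to-end sketch of the new flow. It assumes the package re-exports transcribe, load_align_model and align in the upstream whisper layout, and "audio.wav" is a hypothetical input file:

import whisperx

device = "cuda"
audio_path = "audio.wav"  # hypothetical input file

model = whisperx.load_model("base", device=device)
result = whisperx.transcribe(model, audio_path)

# the aligner is now picked from the detected language via the default dicts
align_model, align_metadata = whisperx.load_align_model(result["language"], device)

result_aligned = whisperx.align(result["segments"], result["language"],
                                align_model, align_metadata, audio_path, device)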
@@ -368,7 +420,7 @@ def cli():
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     # alignment params
-    parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
+    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
     parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
     parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
     parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
@@ -430,24 +482,19 @@ def cli():
 
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
-    if align_model in torchaudio.pipelines.__all__:
-        bundle = torchaudio.pipelines.__dict__[align_model]
-        align_model = bundle.get_model().to(device)
-        labels = bundle.get_labels()
-        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
-    elif align_model in hugginface_models:
-        processor = AutoProcessor.from_pretrained(align_model)
-        align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
-        align_model.to(device)
-        labels = processor.tokenizer.get_vocab()
-        align_dictionary = processor.tokenizer.get_vocab()
-    else:
-        print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface} \n\
-            See details here https://pytorch.org/audio/stable/pipelines.html#id14')
-        raise ValueError(f'Align model "{align_model}" not supported')
+
+    align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
+    align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
 
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
+
+        if result["language"] != align_metadata["language"]:
+            # load new language
+            print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
+            align_model, align_metadata = load_align_model(result["language"], device)
+
+        result_aligned = align(result["segments"], result["language"], align_model, align_metadata, audio_path, device,
             extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
         audio_basename = os.path.basename(audio_path)
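Note: the loop above swaps aligners only when the detected language changes between files. The same guard as a standalone helper (the name and factoring are hypothetical, not part of this commit):

def ensure_align_model(detected_language, align_model, align_metadata, device):
    # Reload only when the detected language differs from the one
    # the current aligner was loaded for.
    if detected_language != align_metadata["language"]:
        align_model, align_metadata = load_align_model(detected_language, device)
    return align_model, align_metadata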