From e909f2f766b23b2000f2d95df41f9b844ac53e49 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Tue, 20 Dec 2022 19:54:55 +0000
Subject: [PATCH] support huggingface + model select based on lang.

---
 README.md              |  3 +-
 whisperx/transcribe.py | 93 +++++++++++++++++++++++++++++++-----------
 2 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index adfff48..b170c94 100644
--- a/README.md
+++ b/README.md
@@ -122,8 +122,7 @@ https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-
 
 [x] Subtitle .ass output
 
-[ ] Automatic align model selection based on language detection
-
+[x] Automatic align model selection based on language detection
 
 [ ] Incorporating word-level speaker diarization
 
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 7a0e401..f847b25 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -17,8 +17,19 @@ from .utils import exact_div, format_timestamp, optional_int, optional_float, st
 if TYPE_CHECKING:
     from .model import Whisper
 
-hugginface_models = ["jonatasgrosman/wav2vec2-large-xlsr-53-japanese"]
-asian_languages = ["ja"]
+LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
+
+DEFAULT_ALIGN_MODELS_TORCH = {
+    "en": "WAV2VEC2_ASR_BASE_960H",
+    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
+    "de": "VOXPOPULI_ASR_BASE_10K_DE",
+    "es": "VOXPOPULI_ASR_BASE_10K_ES",
+    "it": "VOXPOPULI_ASR_BASE_10K_IT",
+}
+
+DEFAULT_ALIGN_MODELS_HF = {
+    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
+}
 
 
 def transcribe(
@@ -255,7 +266,7 @@ def align(
     transcript: Iterator[dict],
     language: str,
     model: torch.nn.Module,
-    model_dictionary: dict,
+    align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
     extend_duration: float = 0.0,
@@ -272,6 +283,10 @@
 
     MAX_DURATION = audio.shape[1] / SAMPLE_RATE
 
+    model_dictionary = align_model_metadata['dictionary']
+    model_lang = align_model_metadata['language']
+    model_type = align_model_metadata['type']
+
     prev_t2 = 0
     word_segments_list = []
     for idx, segment in enumerate(transcript):
@@ -285,14 +300,16 @@
         waveform_segment = audio[:, f1:f2]
 
         with torch.inference_mode():
-            if language not in asian_languages:
+            if model_type == "torchaudio":
                 emissions, _ = model(waveform_segment.to(device))
-            else:
+            elif model_type == "huggingface":
                 emissions = model(waveform_segment.to(device)).logits
+            else:
+                raise NotImplementedError(f"Align model of type {model_type} not supported.")
             emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
         transcription = segment['text'].strip()
-        if language not in asian_languages:
+        if language not in LANGUAGES_WITHOUT_SPACES:
             t_words = transcription.split(' ')
         else:
             t_words = [c for c in transcription]
@@ -359,6 +376,41 @@
 
     return {"segments": transcript, "word_segments": word_segments_list}
 
+def load_align_model(language_code, device, model_name=None):
+    if model_name is None:
+        # use default model
+        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
+            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
+        elif language_code in DEFAULT_ALIGN_MODELS_HF:
+            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
+        else:
+            print(f"There is no default alignment model set for this language ({language_code}).\
+                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+            raise ValueError(f"No default align-model for language: {language_code}")
+
+    if model_name in torchaudio.pipelines.__all__:
+        pipeline_type = "torchaudio"
+        bundle = torchaudio.pipelines.__dict__[model_name]
+        align_model = bundle.get_model().to(device)
+        labels = bundle.get_labels()
+        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
+    else:
+        try:
+            processor = AutoProcessor.from_pretrained(model_name)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        except Exception as e:
+            print(e)
+            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
+            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
+        pipeline_type = "huggingface"
+        align_model = align_model.to(device)
+        labels = processor.tokenizer.get_vocab()
+        align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
+
+    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
+
+    return align_model, align_metadata
+
 
 def cli():
     from . import available_models
@@ -368,7 +420,7 @@
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     # alignment params
-    parser.add_argument("--align_model", default="WAV2VEC2_ASR_BASE_960H", help="Name of phoneme-level ASR model to do alignment")
+    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
     parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment")
     parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment")
     parser.add_argument("--drop_non_aligned", action="store_true", help="For word .srt, whether to drop non aliged words, or merge them into neighbouring.")
@@ -430,24 +482,19 @@
     from . import load_model
     model = load_model(model_name, device=device, download_root=model_dir)
 
-    if align_model in torchaudio.pipelines.__all__:
-        bundle = torchaudio.pipelines.__dict__[align_model]
-        align_model = bundle.get_model().to(device)
-        labels = bundle.get_labels()
-        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
-    elif align_model in hugginface_models:
-        processor = AutoProcessor.from_pretrained(align_model)
-        align_model = Wav2Vec2ForCTC.from_pretrained(align_model).to(device)
-        align_model.to(device)
-        labels = processor.tokenizer.get_vocab()
-        align_dictionary = processor.tokenizer.get_vocab()
-    else:
-        print(f'Align model "{align_model}" is not supported, choose from:\n {torchaudio.pipelines.__all__ + wa2vec2_models_on_hugginface} \n\
-            See details here https://pytorch.org/audio/stable/pipelines.html#id14')
-        raise ValueError(f'Align model "{align_model}" not supported')
+
+    align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
+    align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
+
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        result_aligned = align(result["segments"], result["language"], align_model, align_dictionary, audio_path, device,
+
+        if result["language"] != align_metadata["language"]:
+            # load new language
+            print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
+            align_model, align_metadata = load_align_model(result["language"], device)
+
+        result_aligned = align(result["segments"], result["language"], align_model, align_metadata, audio_path, device,
                                 extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
         audio_basename = os.path.basename(audio_path)
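
For reviewers, below is a minimal usage sketch of how the pieces added in this patch fit together from Python, mirroring the new `cli()` flow: transcribe, pick an alignment model for the detected language with `load_align_model`, then pass the returned metadata to `align`. The import paths, the "base" model size, and the audio path are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of the new alignment flow (assumed imports and file names).
import torch

from whisperx import load_model                                       # Whisper model loader (unchanged)
from whisperx.transcribe import transcribe, align, load_align_model   # align/load_align_model changed/added here

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_path = "example.wav"  # placeholder input file

model = load_model("base", device=device)
result = transcribe(model, audio_path)          # result has "segments" and "language"

# Select a default wav2vec2.0 alignment model (torchaudio or huggingface) for the
# detected language; raises ValueError if no default exists and no model_name is given.
align_model, align_metadata = load_align_model(result["language"], device)

result_aligned = align(result["segments"], result["language"], align_model,
                       align_metadata, audio_path, device)

for word in result_aligned["word_segments"]:
    print(word)
```

Note that passing `--align_model` (or `model_name=` above) skips the per-language defaults entirely: `load_align_model` then loads that name directly from torchaudio pipelines or, failing that, from huggingface.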