From bd3aa03b6f701584d360d3978a1fa57b2cd63d48 Mon Sep 17 00:00:00 2001
From: Douglas Trajano
Date: Thu, 16 Nov 2023 08:59:28 -0300
Subject: [PATCH] Move load_model after WhisperModel

---
 whisperx/asr.py | 186 ++++++++++++++++++++++++------------------------
 1 file changed, 93 insertions(+), 93 deletions(-)

diff --git a/whisperx/asr.py b/whisperx/asr.py
index 3b86634..94e0311 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -22,99 +22,6 @@ def find_numeral_symbol_tokens(tokenizer):
             numeral_symbol_tokens.append(i)
     return numeral_symbol_tokens
 
-def load_model(whisper_arch,
-               device,
-               device_index=0,
-               compute_type="float16",
-               asr_options=None,
-               language : Optional[str] = None,
-               vad_options=None,
-               model : Optional[WhisperModel] = None,
-               task="transcribe",
-               download_root=None,
-               threads=4):
-    '''Load a Whisper model for inference.
-    Args:
-        whisper_arch: str - The name of the Whisper model to load.
-        device: str - The device to load the model on.
-        compute_type: str - The compute type to use for the model.
-        options: dict - A dictionary of options to use for the model.
-        language: str - The language of the model. (use English for now)
-        model: Optional[WhisperModel] - The WhisperModel instance to use.
-        download_root: Optional[str] - The root directory to download the model to.
-        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
-    Returns:
-        A Whisper pipeline.
-    '''
-
-    if whisper_arch.endswith(".en"):
-        language = "en"
-
-    model = model or WhisperModel(whisper_arch,
-                         device=device,
-                         device_index=device_index,
-                         compute_type=compute_type,
-                         download_root=download_root,
-                         cpu_threads=threads)
-    if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
-    else:
-        print("No language specified, language will be first be detected for each audio file (increases inference time).")
-        tokenizer = None
-
-    default_asr_options = {
-        "beam_size": 5,
-        "best_of": 5,
-        "patience": 1,
-        "length_penalty": 1,
-        "repetition_penalty": 1,
-        "no_repeat_ngram_size": 0,
-        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
-        "compression_ratio_threshold": 2.4,
-        "log_prob_threshold": -1.0,
-        "no_speech_threshold": 0.6,
-        "condition_on_previous_text": False,
-        "prompt_reset_on_temperature": 0.5,
-        "initial_prompt": None,
-        "prefix": None,
-        "suppress_blank": True,
-        "suppress_tokens": [-1],
-        "without_timestamps": True,
-        "max_initial_timestamp": 0.0,
-        "word_timestamps": False,
-        "prepend_punctuations": "\"'“¿([{-",
-        "append_punctuations": "\"'.。,,!!??::”)]}、",
-        "suppress_numerals": False,
-    }
-
-    if asr_options is not None:
-        default_asr_options.update(asr_options)
-
-    suppress_numerals = default_asr_options["suppress_numerals"]
-    del default_asr_options["suppress_numerals"]
-
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
-
-    default_vad_options = {
-        "vad_onset": 0.500,
-        "vad_offset": 0.363
-    }
-
-    if vad_options is not None:
-        default_vad_options.update(vad_options)
-
-    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
-
-    return FasterWhisperPipeline(
-        model=model,
-        vad=vad_model,
-        options=default_asr_options,
-        tokenizer=tokenizer,
-        language=language,
-        suppress_numerals=suppress_numerals,
-        vad_params=default_vad_options,
-    )
-
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
@@ -341,3 +248,96 @@ class FasterWhisperPipeline(Pipeline):
         language = language_token[2:-2]
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language
+
+def load_model(whisper_arch,
+               device,
+               device_index=0,
+               compute_type="float16",
+               asr_options=None,
+               language : Optional[str] = None,
+               vad_options=None,
+               model : Optional[WhisperModel] = None,
+               task="transcribe",
+               download_root=None,
+               threads=4):
+    '''Load a Whisper model for inference.
+    Args:
+        whisper_arch: str - The name of the Whisper model to load.
+        device: str - The device to load the model on.
+        compute_type: str - The compute type to use for the model.
+        asr_options: dict - A dictionary of options to use for the model.
+        language: str - The language of the model. (use English for now)
+        model: Optional[WhisperModel] - The WhisperModel instance to use.
+        download_root: Optional[str] - The root directory to download the model to.
+        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+    Returns:
+        A Whisper pipeline.
+    '''
+
+    if whisper_arch.endswith(".en"):
+        language = "en"
+
+    model = model or WhisperModel(whisper_arch,
+                         device=device,
+                         device_index=device_index,
+                         compute_type=compute_type,
+                         download_root=download_root,
+                         cpu_threads=threads)
+    if language is not None:
+        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+    else:
+        print("No language specified, language will first be detected for each audio file (increases inference time).")
+        tokenizer = None
+
+    default_asr_options = {
+        "beam_size": 5,
+        "best_of": 5,
+        "patience": 1,
+        "length_penalty": 1,
+        "repetition_penalty": 1,
+        "no_repeat_ngram_size": 0,
+        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+        "compression_ratio_threshold": 2.4,
+        "log_prob_threshold": -1.0,
+        "no_speech_threshold": 0.6,
+        "condition_on_previous_text": False,
+        "prompt_reset_on_temperature": 0.5,
+        "initial_prompt": None,
+        "prefix": None,
+        "suppress_blank": True,
+        "suppress_tokens": [-1],
+        "without_timestamps": True,
+        "max_initial_timestamp": 0.0,
+        "word_timestamps": False,
+        "prepend_punctuations": "\"'“¿([{-",
+        "append_punctuations": "\"'.。,,!!??::”)]}、",
+        "suppress_numerals": False,
+    }
+
+    if asr_options is not None:
+        default_asr_options.update(asr_options)
+
+    suppress_numerals = default_asr_options["suppress_numerals"]
+    del default_asr_options["suppress_numerals"]
+
+    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+
+    default_vad_options = {
+        "vad_onset": 0.500,
+        "vad_offset": 0.363
+    }
+
+    if vad_options is not None:
+        default_vad_options.update(vad_options)
+
+    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+
+    return FasterWhisperPipeline(
+        model=model,
+        vad=vad_model,
+        options=default_asr_options,
+        tokenizer=tokenizer,
+        language=language,
+        suppress_numerals=suppress_numerals,
+        vad_params=default_vad_options,
+    )
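
The relocated load_model keeps its signature and return value, so existing callers are unaffected. A minimal usage sketch, assuming load_model is still re-exported at the package level (whisperx.load_model, as in the project README) and that a local audio.mp3 file exists; the asr_options/vad_options keys come straight from the defaults shown in the patch:

import whisperx

device = "cuda"  # or "cpu" (then use compute_type="int8")

# Load the audio and build the FasterWhisperPipeline defined in whisperx/asr.py.
audio = whisperx.load_audio("audio.mp3")  # assumed local file
model = whisperx.load_model(
    "large-v2",
    device,
    compute_type="float16",
    asr_options={"suppress_numerals": True},          # merged into default_asr_options
    vad_options={"vad_onset": 0.5, "vad_offset": 0.363},  # merged into default_vad_options
)

# Batched transcription; segments carry start/end times and text.
result = model.transcribe(audio, batch_size=16)
print(result["segments"])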