mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge pull request #584 from DougTrajano/patch-1
Move load_model after WhisperModel
This commit is contained in:
186
whisperx/asr.py
186
whisperx/asr.py
@ -22,99 +22,6 @@ def find_numeral_symbol_tokens(tokenizer):
|
||||
numeral_symbol_tokens.append(i)
|
||||
return numeral_symbol_tokens
|
||||
|
||||
def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language : Optional[str] = None,
               vad_options=None,
               model : Optional["WhisperModel"] = None,
               task="transcribe",
               download_root=None,
               threads=4):
    '''Load a Whisper model for inference.

    Args:
        whisper_arch: str - The name of the Whisper model to load. A name ending
            in ".en" forces `language` to "en".
        device: str - The device to load the model on. Also used (via
            torch.device) for the VAD model.
        device_index: int - Device index forwarded to WhisperModel.
        compute_type: str - The compute type to use for the model.
        asr_options: Optional[dict] - Overrides merged on top of the default
            transcription options below. May contain "suppress_numerals", which
            is removed from the transcription options and handed to the
            pipeline separately.
        language: Optional[str] - Language code used to build the tokenizer.
            When None, no tokenizer is built here and a notice is printed.
        vad_options: Optional[dict] - Overrides merged on top of the default
            "vad_onset" / "vad_offset" values.
        model: Optional[WhisperModel] - The WhisperModel instance to use. When
            given, device_index, compute_type, download_root and threads are
            not used to build a new model.
        task: str - Task passed to the tokenizer.
        download_root: Optional[str] - The root directory to download the model to.
        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
    Returns:
        A Whisper pipeline.
    '''

    if whisper_arch.endswith(".en"):
        language = "en"

    # Only build a new model when the caller did not supply one. An explicit
    # None check avoids relying on the truthiness of a model object.
    if model is None:
        model = WhisperModel(whisper_arch,
                             device=device,
                             device_index=device_index,
                             compute_type=compute_type,
                             download_root=download_root,
                             cpu_threads=threads)

    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        print("No language specified, language will be first be detected for each audio file (increases inference time).")
        tokenizer = None

    asr_settings = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
    }

    if asr_options is not None:
        asr_settings.update(asr_options)

    # "suppress_numerals" is a pipeline-level switch, so it is taken out of the
    # dict before the remaining keys are expanded into TranscriptionOptions.
    suppress_numerals = asr_settings.pop("suppress_numerals")

    transcription_options = faster_whisper.transcribe.TranscriptionOptions(**asr_settings)

    default_vad_options = {
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
@ -341,3 +248,96 @@ class FasterWhisperPipeline(Pipeline):
|
||||
language = language_token[2:-2]
|
||||
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
|
||||
return language
|
||||
|
||||
def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language : Optional[str] = None,
               vad_options=None,
               model : Optional[WhisperModel] = None,
               task="transcribe",
               download_root=None,
               threads=4):
    '''Load a Whisper model for inference.

    Args:
        whisper_arch: str - The name of the Whisper model to load. A name ending
            in ".en" forces `language` to "en".
        device: str - The device to load the model on. Also used (via
            torch.device) for the VAD model.
        device_index: int - Device index forwarded to WhisperModel.
        compute_type: str - The compute type to use for the model.
        asr_options: Optional[dict] - Overrides merged on top of the default
            transcription options below. May contain "suppress_numerals", which
            is removed from the transcription options and handed to the
            pipeline separately.
        language: Optional[str] - Language code used to build the tokenizer.
            When None, no tokenizer is built here and a notice is printed.
        vad_options: Optional[dict] - Overrides merged on top of the default
            "vad_onset" / "vad_offset" values.
        model: Optional[WhisperModel] - The WhisperModel instance to use. When
            given, device_index, compute_type, download_root and threads are
            not used to build a new model.
        task: str - Task passed to the tokenizer.
        download_root: Optional[str] - The root directory to download the model to.
        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
    Returns:
        A Whisper pipeline.
    '''

    if whisper_arch.endswith(".en"):
        language = "en"

    # Only build a new model when the caller did not supply one. An explicit
    # None check avoids relying on the truthiness of a model object.
    if model is None:
        model = WhisperModel(whisper_arch,
                             device=device,
                             device_index=device_index,
                             compute_type=compute_type,
                             download_root=download_root,
                             cpu_threads=threads)

    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        print("No language specified, language will be first be detected for each audio file (increases inference time).")
        tokenizer = None

    asr_settings = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
    }

    if asr_options is not None:
        asr_settings.update(asr_options)

    # "suppress_numerals" is a pipeline-level switch, so it is taken out of the
    # dict before the remaining keys are expanded into TranscriptionOptions.
    suppress_numerals = asr_settings.pop("suppress_numerals")

    transcription_options = faster_whisper.transcribe.TranscriptionOptions(**asr_settings)

    default_vad_options = {
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
|
||||
|
Reference in New Issue
Block a user