Mirror of https://github.com/m-bain/whisperX.git, synced 2025-07-01 18:17:27 -04:00

Move load_model after WhisperModel
whisperx/asr.py: 186 changed lines (93 additions, 93 deletions)
@@ -22,99 +22,6 @@ def find_numeral_symbol_tokens(tokenizer):
             numeral_symbol_tokens.append(i)
     return numeral_symbol_tokens
-
-def load_model(whisper_arch,
-               device,
-               device_index=0,
-               compute_type="float16",
-               asr_options=None,
-               language : Optional[str] = None,
-               vad_options=None,
-               model : Optional[WhisperModel] = None,
-               task="transcribe",
-               download_root=None,
-               threads=4):
-    '''Load a Whisper model for inference.
-    Args:
-        whisper_arch: str - The name of the Whisper model to load.
-        device: str - The device to load the model on.
-        compute_type: str - The compute type to use for the model.
-        options: dict - A dictionary of options to use for the model.
-        language: str - The language of the model. (use English for now)
-        model: Optional[WhisperModel] - The WhisperModel instance to use.
-        download_root: Optional[str] - The root directory to download the model to.
-        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
-    Returns:
-        A Whisper pipeline.
-    '''
-
-    if whisper_arch.endswith(".en"):
-        language = "en"
-
-    model = model or WhisperModel(whisper_arch,
-                         device=device,
-                         device_index=device_index,
-                         compute_type=compute_type,
-                         download_root=download_root,
-                         cpu_threads=threads)
-    if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
-    else:
-        print("No language specified, language will be first be detected for each audio file (increases inference time).")
-        tokenizer = None
-
-    default_asr_options = {
-        "beam_size": 5,
-        "best_of": 5,
-        "patience": 1,
-        "length_penalty": 1,
-        "repetition_penalty": 1,
-        "no_repeat_ngram_size": 0,
-        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
-        "compression_ratio_threshold": 2.4,
-        "log_prob_threshold": -1.0,
-        "no_speech_threshold": 0.6,
-        "condition_on_previous_text": False,
-        "prompt_reset_on_temperature": 0.5,
-        "initial_prompt": None,
-        "prefix": None,
-        "suppress_blank": True,
-        "suppress_tokens": [-1],
-        "without_timestamps": True,
-        "max_initial_timestamp": 0.0,
-        "word_timestamps": False,
-        "prepend_punctuations": "\"'“¿([{-",
-        "append_punctuations": "\"'.。,,!!??::”)]}、",
-        "suppress_numerals": False,
-    }
-
-    if asr_options is not None:
-        default_asr_options.update(asr_options)
-
-    suppress_numerals = default_asr_options["suppress_numerals"]
-    del default_asr_options["suppress_numerals"]
-
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
-
-    default_vad_options = {
-        "vad_onset": 0.500,
-        "vad_offset": 0.363
-    }
-
-    if vad_options is not None:
-        default_vad_options.update(vad_options)
-
-    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
-
-    return FasterWhisperPipeline(
-        model=model,
-        vad=vad_model,
-        options=default_asr_options,
-        tokenizer=tokenizer,
-        language=language,
-        suppress_numerals=suppress_numerals,
-        vad_params=default_vad_options,
-    )
-
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
@@ -341,3 +248,96 @@ class FasterWhisperPipeline(Pipeline):
         language = language_token[2:-2]
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language
+
+
+def load_model(whisper_arch,
+               device,
+               device_index=0,
+               compute_type="float16",
+               asr_options=None,
+               language : Optional[str] = None,
+               vad_options=None,
+               model : Optional[WhisperModel] = None,
+               task="transcribe",
+               download_root=None,
+               threads=4):
+    '''Load a Whisper model for inference.
+    Args:
+        whisper_arch: str - The name of the Whisper model to load.
+        device: str - The device to load the model on.
+        compute_type: str - The compute type to use for the model.
+        options: dict - A dictionary of options to use for the model.
+        language: str - The language of the model. (use English for now)
+        model: Optional[WhisperModel] - The WhisperModel instance to use.
+        download_root: Optional[str] - The root directory to download the model to.
+        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+    Returns:
+        A Whisper pipeline.
+    '''
+
+    if whisper_arch.endswith(".en"):
+        language = "en"
+
+    model = model or WhisperModel(whisper_arch,
+                         device=device,
+                         device_index=device_index,
+                         compute_type=compute_type,
+                         download_root=download_root,
+                         cpu_threads=threads)
+    if language is not None:
+        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+    else:
+        print("No language specified, language will be first be detected for each audio file (increases inference time).")
+        tokenizer = None
+
+    default_asr_options = {
+        "beam_size": 5,
+        "best_of": 5,
+        "patience": 1,
+        "length_penalty": 1,
+        "repetition_penalty": 1,
+        "no_repeat_ngram_size": 0,
+        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+        "compression_ratio_threshold": 2.4,
+        "log_prob_threshold": -1.0,
+        "no_speech_threshold": 0.6,
+        "condition_on_previous_text": False,
+        "prompt_reset_on_temperature": 0.5,
+        "initial_prompt": None,
+        "prefix": None,
+        "suppress_blank": True,
+        "suppress_tokens": [-1],
+        "without_timestamps": True,
+        "max_initial_timestamp": 0.0,
+        "word_timestamps": False,
+        "prepend_punctuations": "\"'“¿([{-",
+        "append_punctuations": "\"'.。,,!!??::”)]}、",
+        "suppress_numerals": False,
+    }
+
+    if asr_options is not None:
+        default_asr_options.update(asr_options)
+
+    suppress_numerals = default_asr_options["suppress_numerals"]
+    del default_asr_options["suppress_numerals"]
+
+    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+
+    default_vad_options = {
+        "vad_onset": 0.500,
+        "vad_offset": 0.363
+    }
+
+    if vad_options is not None:
+        default_vad_options.update(vad_options)
+
+    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+
+    return FasterWhisperPipeline(
+        model=model,
+        vad=vad_model,
+        options=default_asr_options,
+        tokenizer=tokenizer,
+        language=language,
+        suppress_numerals=suppress_numerals,
+        vad_params=default_vad_options,
+    )
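
Note: the function body and signature are unchanged by this commit; only its position in asr.py moves, so callers are unaffected. A minimal usage sketch follows, assuming the package re-exports load_model and load_audio at the top level as in the whisperX README; the model name, device, and option values are illustrative, not part of this commit.

import whisperx

# Illustrative values only; any checkpoint/device/compute type accepted by
# faster-whisper can be passed through.
model = whisperx.load_model(
    "large-v2",
    device="cuda",
    compute_type="float16",
    language="en",                    # fixes the tokenizer language, skipping per-file detection
    asr_options={"beam_size": 5},     # merged into default_asr_options
    vad_options={"vad_onset": 0.5},   # merged into default_vad_options
)

audio = whisperx.load_audio("audio.wav")
result = model.transcribe(audio, batch_size=16)
print(result["segments"])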