diff --git a/whisperx/asr.py b/whisperx/asr.py index dba8271..ac816ff 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -262,6 +262,7 @@ def load_model(whisper_arch, compute_type="float16", asr_options=None, language : Optional[str] = None, + vad_model=None, vad_options=None, model : Optional[WhisperModel] = None, task="transcribe", @@ -337,7 +338,10 @@ def load_model(whisper_arch, if vad_options is not None: default_vad_options.update(vad_options) - vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) + if vad_model is not None: + vad_model = vad_model + else: + vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) return FasterWhisperPipeline( model=model,