diff --git a/whisperx/asr.py b/whisperx/asr.py index 1734fb9..3b86634 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -29,7 +29,7 @@ def load_model(whisper_arch, asr_options=None, language : Optional[str] = None, vad_options=None, - model=None, + model : Optional[WhisperModel] = None, task="transcribe", download_root=None, threads=4): @@ -40,6 +40,7 @@ def load_model(whisper_arch, compute_type: str - The compute type to use for the model. options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) + model: Optional[WhisperModel] - The WhisperModel instance to use. download_root: Optional[str] - The root directory to download the model to. threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers. Returns: @@ -49,7 +50,7 @@ def load_model(whisper_arch, if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, + model = model or WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type,