From 48d651e5eadb7b91575d6e8e212777ebd386473b Mon Sep 17 00:00:00 2001 From: kaka1909 Date: Thu, 16 Nov 2023 15:29:24 +0800 Subject: [PATCH] Update asr.py and make the model parameter be used --- whisperx/asr.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 1734fb9..3b86634 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -29,7 +29,7 @@ def load_model(whisper_arch, asr_options=None, language : Optional[str] = None, vad_options=None, - model=None, + model : Optional[WhisperModel] = None, task="transcribe", download_root=None, threads=4): @@ -40,6 +40,7 @@ def load_model(whisper_arch, compute_type: str - The compute type to use for the model. options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) + model: Optional[WhisperModel] - The WhisperModel instance to use. download_root: Optional[str] - The root directory to download the model to. threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers. Returns: @@ -49,7 +50,7 @@ def load_model(whisper_arch, if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, + model = model or WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type,