diff --git a/whisperx/asr.py b/whisperx/asr.py index ecc2765..49e8efd 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -31,7 +31,8 @@ def load_model(whisper_arch, vad_options=None, model=None, task="transcribe", - download_root=None): + download_root=None, + threads=4): '''Load a Whisper model for inference. Args: whisper_arch: str - The name of the Whisper model to load. @@ -40,6 +41,7 @@ def load_model(whisper_arch, options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) download_root: Optional[str] - The root directory to download the model to. + threads: int - The number of CPU threads to use per worker, i.e. the total is this value multiplied by the number of workers. Returns: A Whisper pipeline. ''' @@ -51,7 +53,8 @@ def load_model(whisper_arch, device=device, device_index=device_index, compute_type=compute_type, - download_root=download_root) + download_root=download_root, + cpu_threads=threads) if language is not None: tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 1cc144e..e0101c1 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -119,8 +119,10 @@ def cli(): else: temperature = [temperature] + faster_whisper_threads = 4 if (threads := args.pop("threads")) > 0: torch.set_num_threads(threads) + faster_whisper_threads = threads asr_options = { "beam_size": args.pop("beam_size"), @@ -150,7 +152,7 @@ def cli(): results = [] tmp_results = [] # model = load_model(model_name, device=device, download_root=model_dir) - model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) + model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, 
language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads) for audio_path in args.pop("audio"): audio = load_audio(audio_path)