diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 6fff837..edd2764 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -81,6 +81,7 @@ def cli():
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     batch_size: int = args.pop("batch_size")
+    model_dir: str = args.pop("model_dir")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
@@ -166,7 +167,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
+    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)