diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 1d7d1ac..ae61ed1 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -26,6 +26,7 @@ def cli():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
+    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
@@ -90,6 +91,7 @@ def cli():
     model_name: str = args.pop("model")
     batch_size: int = args.pop("batch_size")
     model_dir: str = args.pop("model_dir")
+    model_cache_only: bool = args.pop("model_cache_only")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
@@ -177,7 +179,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
+    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
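
For reference, a minimal sketch of what the new flag amounts to at the `load_model` call site. The keyword arguments mirror the call in the diff above; the concrete values (`"small"`, `"cpu"`, `"int8"`, `./models`) are hypothetical, and the sketch assumes `load_model` forwards `local_files_only` to faster-whisper's model loader, as the flag's help text implies:

```python
import whisperx

# Equivalent CLI invocation (the flag is parsed with str2bool, so a
# True/False string is expected):
#   whisperx audio.wav --model small --model_dir ./models --model_cache_only True
model = whisperx.load_model(
    "small",                   # --model
    device="cpu",
    compute_type="int8",       # hypothetical choice for a CPU run
    download_root="./models",  # --model_dir
    local_files_only=True,     # --model_cache_only: resolve the model from the
                               # local cache instead of downloading it
)
```

With `local_files_only=True`, a missing model should raise an error rather than trigger a download, which is useful for offline or air-gapped runs where the model files were cached ahead of time.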