From 53396adb210d1db07f4400bb29e8aa8c0ae88af5 Mon Sep 17 00:00:00 2001
From: Simon
Date: Sat, 20 May 2023 13:02:46 +0200
Subject: [PATCH 1/2] add device_index

---
 whisperx/asr.py        | 4 ++--
 whisperx/transcribe.py | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/whisperx/asr.py b/whisperx/asr.py
index 88d5bf6..470e701 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -13,7 +13,7 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
 
-def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
+def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr_options=None, language=None,
                vad_options=None, model=None, task="transcribe"):
     '''Load a Whisper model for inference.
     Args:
@@ -29,7 +29,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
     if whisper_arch.endswith(".en"):
         language = "en"
 
-    model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+    model = WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type)
     if language is not None:
         tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 3edc746..4432abe 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -21,6 +21,7 @@ def cli():
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=None, type=int, help="device index to use for FasterWhisper inference")
     parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
     parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
 
@@ -78,6 +79,7 @@ def cli():
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
+    device_index: int = args.pop("device_index")
     compute_type: str = args.pop("compute_type")
 
     # model_flush: bool = args.pop("model_flush")
@@ -144,7 +146,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+    model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)

From 74b98ebfaab771f4078c7ffe973117257667dda2 Mon Sep 17 00:00:00 2001
From: Simon
Date: Sat, 20 May 2023 13:11:30 +0200
Subject: [PATCH 2/2] ensure device_index not None

---
 whisperx/transcribe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 4432abe..691e3f9 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -21,7 +21,7 @@ def cli():
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-    parser.add_argument("--device_index", default=None, type=int, help="device index to use for FasterWhisper inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
     parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
     parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
 
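
For reference, a minimal sketch of how the new option is exercised from Python,
assuming load_model and load_audio are imported from the whisperx.asr and
whisperx.audio modules these patches touch; the device index 1, the input file
name, and the batch_size value are illustrative, not prescribed by the patches:

    from whisperx.asr import load_model
    from whisperx.audio import load_audio

    # Select the second CUDA device; device_index is forwarded to
    # faster-whisper's WhisperModel(..., device_index=...) by patch 1/2.
    model = load_model("small", device="cuda", device_index=1,
                       compute_type="float16")

    audio = load_audio("audio.wav")  # hypothetical input file
    # Transcribe via the pipeline object returned by load_model.
    result = model.transcribe(audio, batch_size=8)
    print(result["segments"])

From the command line, the equivalent selection is
whisperx audio.wav --device cuda --device_index 1
where --device_index defaults to 0 after patch 2/2, so omitting it always
passes a valid integer to faster-whisper rather than None.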