Merge pull request #266 from sorgfresser/main

Add device_index option
Max Bain, 2023-05-20 14:45:34 +01:00 (committed by GitHub)
2 changed files with 5 additions and 3 deletions


@@ -13,7 +13,7 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
 
-def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
+def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr_options=None, language=None,
                vad_options=None, model=None, task="transcribe"):
     '''Load a Whisper model for inference.
     Args:
@@ -29,7 +29,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
     if whisper_arch.endswith(".en"):
         language = "en"
 
-    model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+    model = WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type)
     if language is not None:
         tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
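
For illustration, a minimal sketch of calling the patched loader with the new parameter, assuming load_model is re-exported at the whisperx package level (the import path is not shown in this diff):

import whisperx

# Load the model on the second CUDA device (index 1).
# device_index defaults to 0, so existing callers keep the old behaviour.
model = whisperx.load_model("small", device="cuda", device_index=1,
                            compute_type="float16")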


@@ -21,6 +21,7 @@ def cli():
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
     parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
     parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
@@ -78,6 +79,7 @@ def cli():
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
+    device_index: int = args.pop("device_index")
     compute_type: str = args.pop("compute_type")
     # model_flush: bool = args.pop("model_flush")
@@ -144,7 +146,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
+    model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
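
On the command line the new flag maps to the same parameter, e.g. whisperx audio.wav --device cuda --device_index 1. Internally, device_index is simply forwarded to faster-whisper's WhisperModel, where CTranslate2 uses it to choose the GPU; a minimal sketch of the equivalent direct call (faster-whisper also accepts a list of IDs here, though this patch passes a single int):

from faster_whisper import WhisperModel

# Equivalent direct call: place the model on GPU 1 instead of GPU 0.
# device_index may also be a list of IDs to load the model on several GPUs.
model = WhisperModel("small", device="cuda", device_index=1,
                     compute_type="float16")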