diff --git a/whisperx/asr.py b/whisperx/asr.py
index 713531c..d0e6962 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -13,8 +13,16 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
 
-def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr_options=None, language=None,
-               vad_options=None, model=None, task="transcribe"):
+def load_model(whisper_arch,
+               device,
+               device_index=0,
+               compute_type="float16",
+               asr_options=None,
+               language=None,
+               vad_options=None,
+               model=None,
+               task="transcribe",
+               download_root=None):
     '''Load a Whisper model for inference.
     Args:
         whisper_arch: str - The name of the Whisper model to load.
@@ -22,14 +30,19 @@ def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr
         compute_type: str - The compute type to use for the model.
         options: dict - A dictionary of options to use for the model.
         language: str - The language of the model. (use English for now)
+        download_root: Optional[str] - The root directory to download the model to.
     Returns:
         A Whisper pipeline.
-    '''    
+    '''
 
     if whisper_arch.endswith(".en"):
         language = "en"
 
-    model = WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type)
+    model = WhisperModel(whisper_arch,
+                         device=device,
+                         device_index=device_index,
+                         compute_type=compute_type,
+                         download_root=download_root)
     if language is not None:
         tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
@@ -114,7 +127,7 @@ class WhisperModel(faster_whisper.WhisperModel):
                 # suppress_tokens=options.suppress_tokens,
                 # max_initial_timestamp_index=max_initial_timestamp_index,
             )
-        
+
         tokens_batch = [x.sequences_ids[0] for x in result]
 
         def decode_batch(tokens: List[List[int]]) -> str:
@@ -127,7 +140,7 @@ class WhisperModel(faster_whisper.WhisperModel):
         text = decode_batch(tokens_batch)
 
         return text
-    
+
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
@@ -136,9 +149,9 @@ class WhisperModel(faster_whisper.WhisperModel):
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
-        
+
         return self.model.encode(features, to_cpu=to_cpu)
-    
+
 class FasterWhisperPipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -176,7 +189,7 @@ class FasterWhisperPipeline(Pipeline):
             self.device = torch.device(f"cuda:{device}")
         else:
             self.device = device
-        
+
         super(Pipeline, self).__init__()
         self.vad_model = vad
 
@@ -194,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
     def _forward(self, model_inputs):
         outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
         return {'text': outputs}
-    
+
     def postprocess(self, model_outputs):
         return model_outputs
 
@@ -218,7 +231,7 @@ class FasterWhisperPipeline(Pipeline):
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
-        
+
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
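
For reference, a minimal usage sketch of the new download_root argument. It assumes load_model and load_audio are re-exported at the whisperx package level, as in the repository this diff targets; the cache directory, audio file, and batch size below are illustrative and not part of the change.

    import whisperx

    # Cache the downloaded CTranslate2 model files under an explicit directory
    # instead of the default cache location (path is just an example).
    model = whisperx.load_model("large-v2",
                                device="cuda",
                                compute_type="float16",
                                download_root="/data/models/whisper")

    audio = whisperx.load_audio("example.wav")
    result = model.transcribe(audio, batch_size=16)
    print(result["segments"])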