diff --git a/whisperx/asr.py b/whisperx/asr.py
index 0ea03b6..95ba098 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -6,6 +6,9 @@ import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import (TranscriptionOptions,
+                                       get_ctranslate2_storage)
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 
@@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''
 
-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
@@ -81,7 +90,7 @@ class WhisperModel(faster_whisper.WhisperModel):
         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
-        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+        features = get_ctranslate2_storage(features)
 
         return self.model.encode(features, to_cpu=to_cpu)
 
@@ -193,17 +202,23 @@ class FasterWhisperPipeline(Pipeline):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
-
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )
+
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@@ -297,7 +312,7 @@ def load_model(whisper_arch,
                          local_files_only=local_files_only,
                          cpu_threads=threads)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
@@ -338,7 +353,7 @@ def load_model(whisper_arch,
     suppress_numerals = default_asr_options["suppress_numerals"]
     del default_asr_options["suppress_numerals"]
 
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = TranscriptionOptions(**default_asr_options)
 
     default_vad_options = {
         "vad_onset": 0.500,
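
Aside from wrapping long signatures and calls, the patch only swaps fully qualified faster_whisper references for direct imports. A minimal sanity-check sketch, assuming a faster-whisper version where these names live in faster_whisper.tokenizer and faster_whisper.transcribe (the same assumption the patch itself makes):

# The direct imports resolve to the exact same objects the old
# fully qualified references did, so the refactor preserves behavior.
import faster_whisper.tokenizer
import faster_whisper.transcribe
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import (TranscriptionOptions,
                                       get_ctranslate2_storage)

assert Tokenizer is faster_whisper.tokenizer.Tokenizer
assert TranscriptionOptions is faster_whisper.transcribe.TranscriptionOptions
assert get_ctranslate2_storage is faster_whisper.transcribe.get_ctranslate2_storage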