import os
from dataclasses import replace
import warnings
from typing import List, Union, Optional, NamedTuple, Callable
from enum import Enum

import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator

from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import SingleSegment, TranscriptionResult
from whisperx.vads import Vad, Silero, Pyannote


def find_numeral_symbol_tokens(tokenizer):
    """Collect token ids whose decoded text contains a digit or a currency/percent symbol."""
    numeral_symbol_tokens = []
    for i in range(tokenizer.eot):
        token = tokenizer.decode([i]).removeprefix(" ")
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(i)
    return numeral_symbol_tokens


class WhisperModel(faster_whisper.WhisperModel):
    '''
    FasterWhisperModel provides batched inference for faster-whisper.
    Currently it only works in non-timestamp mode, with a fixed prompt for all samples in a batch.
    '''

    def generate_segment_batched(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
        encoder_output=None,
    ):
        batch_size = features.shape[0]
        all_tokens = []
        prompt_reset_since = 0
        if options.initial_prompt is not None:
            initial_prompt = " " + options.initial_prompt.strip()
            initial_prompt_tokens = tokenizer.encode(initial_prompt)
            all_tokens.extend(initial_prompt_tokens)
        previous_tokens = all_tokens[prompt_reset_since:]
        prompt = self.get_prompt(
            tokenizer,
            previous_tokens,
            without_timestamps=options.without_timestamps,
            prefix=options.prefix,
            hotwords=options.hotwords,
        )

        encoder_output = self.encode(features)

        max_initial_timestamp_index = int(
            round(options.max_initial_timestamp / self.time_precision)
        )

        result = self.model.generate(
            encoder_output,
            [prompt] * batch_size,
            beam_size=options.beam_size,
            patience=options.patience,
            length_penalty=options.length_penalty,
            max_length=self.max_length,
            suppress_blank=options.suppress_blank,
            suppress_tokens=options.suppress_tokens,
        )

        tokens_batch = [x.sequences_ids[0] for x in result]

        def decode_batch(tokens: List[List[int]]) -> str:
            res = []
            for tk in tokens:
                res.append([token for token in tk if token < tokenizer.eot])
            # text_tokens = [token for token in tokens if token < self.eot]
            return tokenizer.tokenizer.decode_batch(res)

        text = decode_batch(tokens_batch)

        return text

    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
        # unsqueeze if batch size = 1
        if len(features.shape) == 2:
            features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)


class FasterWhisperPipeline(Pipeline):
    """
    Huggingface Pipeline wrapper for FasterWhisperModel.
    """
    # TODO:
    # - add support for timestamp mode
    # - add support for custom inference kwargs

    class TranscriptionState(Enum):
        LOADING_AUDIO = "loading_audio"
        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
        TRANSCRIBING = "transcribing"
        FINISHED = "finished"

    def __init__(
        self,
        model: WhisperModel,
        vad,
        vad_params: dict,
        options: TranscriptionOptions,
        tokenizer: Optional[Tokenizer] = None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.options = options
        self.preset_language = language
        self.suppress_numerals = suppress_numerals
        self._batch_size = kwargs.pop("batch_size", None)
        self._num_workers = 1
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
        self.call_count = 0
        self.framework = framework
        if self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
            elif device < 0:
                self.device = torch.device("cpu")
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device

        super(Pipeline, self).__init__()
        self.vad_model = vad
        self._vad_params = vad_params

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "tokenizer" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, audio):
        audio = audio['inputs']
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        features = log_mel_spectrogram(
            audio,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=N_SAMPLES - audio.shape[0],
        )
        return {'inputs': features}

    def _forward(self, model_inputs):
        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
        return {'text': outputs}

    def postprocess(self, model_outputs):
        return model_outputs

    def get_iterator(
        self,
        inputs,
        num_workers: int,
        batch_size: int,
        preprocess_params: dict,
        forward_params: dict,
        postprocess_params: dict,
    ):
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
        if "TOKENIZERS_PARALLELISM" not in os.environ:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # TODO hack by collating feature_extractor and image_processor
        def stack(items):
            return {'inputs': torch.stack([x['inputs'] for x in items])}

        dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
        model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
        final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
        return final_iterator

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        batch_size: Optional[int] = None,
        num_workers=0,
        language: Optional[str] = None,
        task: Optional[str] = None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
        verbose=False,
        on_progress: Optional[Callable[[TranscriptionState, Optional[int], Optional[int]], None]] = None,
    ) -> TranscriptionResult:
        if isinstance(audio, str):
            if on_progress:
                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
            audio = load_audio(audio)

        def data(audio, segments):
            for seg in segments:
                f1 = int(seg['start'] * SAMPLE_RATE)
                f2 = int(seg['end'] * SAMPLE_RATE)
                # print(f2-f1)
                yield {'inputs': audio[f1:f2]}

        # Pre-process audio and merge chunks as defined by the respective VAD child class.
        # In case vad_model is manually assigned (see 'load_model'), follow the functionality of the pyannote toolkit.
        if issubclass(type(self.vad_model), Vad):
            waveform = self.vad_model.preprocess_audio(audio)
            merge_chunks = self.vad_model.merge_chunks
        else:
            waveform = Pyannote.preprocess_audio(audio)
            merge_chunks = Pyannote.merge_chunks

        if on_progress:
            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)

        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(
            vad_segments,
            chunk_size,
            onset=self._vad_params["vad_onset"],
            offset=self._vad_params["vad_offset"],
        )

        if self.tokenizer is None:
            language = language or self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
            print("Suppressing numeral and symbol tokens")
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
        total_segments = len(vad_segments)
        if on_progress:
            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
            if print_progress:
                base_progress = ((idx + 1) / total_segments) * 100
                percent_complete = base_progress / 2 if combined_progress else base_progress
                print(f"Progress: {percent_complete:.2f}%...")
            if on_progress:
                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
            text = out['text']
            if batch_size in [0, 1, None]:
                text = text[0]
            segments.append(
                {
                    "text": text,
                    "start": round(vad_segments[idx]['start'], 3),
                    "end": round(vad_segments[idx]['end'], 3),
                }
            )

        if on_progress:
            on_progress(self.__class__.TranscriptionState.FINISHED)

        # revert the tokenizer if multilingual inference is enabled
        if self.preset_language is None:
            self.tokenizer = None

        # revert suppressed tokens if suppress_numerals is enabled
        if self.suppress_numerals:
            self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)

        return {"segments": segments, "language": language}

    def detect_language(self, audio: np.ndarray) -> str:
        if audio.shape[0] < N_SAMPLES:
            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        segment = log_mel_spectrogram(
            audio[:N_SAMPLES],
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0],
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        language = language_token[2:-2]
        print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
        return language


def load_model(
    whisper_arch: str,
    device: str,
    device_index=0,
    compute_type="float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model: Optional[Vad] = None,
    vad_method: Optional[str] = "pyannote",
    vad_options: Optional[dict] = None,
    model: Optional[WhisperModel] = None,
    task="transcribe",
    download_root: Optional[str] = None,
    local_files_only=False,
    threads=4,
) -> FasterWhisperPipeline:
    """Load a Whisper model for inference.
    Args:
        whisper_arch - The name of the Whisper model to load.
        device - The device to load the model on.
        compute_type - The compute type to use for the model.
        vad_method - The VAD method to use. A manually assigned vad_model takes priority if it is not None.
        options - A dictionary of options to use for the model.
        language - The language of the model. (use English for now)
        model - The WhisperModel instance to use.
        download_root - The root directory to download the model to.
        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        threads - The number of CPU threads to use per worker; this will be multiplied by the number of workers.
    Returns:
        A Whisper pipeline.
    """

    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(whisper_arch,
                                  device=device,
                                  device_index=device_index,
                                  compute_type=compute_type,
                                  download_root=download_root,
                                  local_files_only=local_files_only,
                                  cpu_threads=threads)
    if language is not None:
        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        print("No language specified; the language will first be detected for each audio file (increases inference time).")
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,，!！?？:：”)]}、",
        "multilingual": model.model.is_multilingual,
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
        "hotwords": None,
    }

    if asr_options is not None:
        default_asr_options.update(asr_options)

    suppress_numerals = default_asr_options["suppress_numerals"]
    del default_asr_options["suppress_numerals"]

    default_asr_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363,
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    # Note: a manually assigned vad_model has higher priority than vad_method!
    if vad_model is not None:
        print("Using manually assigned vad_model; vad_method is ignored.")
    else:
        if vad_method == "silero":
            vad_model = Silero(**default_vad_options)
        elif vad_method == "pyannote":
            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
        else:
            raise ValueError(f"Invalid vad_method: {vad_method}")

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
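

# ---------------------------------------------------------------------------
# Example usage: a minimal, illustrative sketch only, not part of the library
# API. The model name ("small.en"), audio path ("audio.wav"), and batch size
# are placeholder assumptions; substitute your own values. Silero VAD is used
# here simply because it does not require the pyannote segmentation model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load a CPU-friendly model; int8 avoids running float16 on CPU.
    pipeline = load_model(
        "small.en",
        device="cpu",
        compute_type="int8",
        vad_method="silero",
    )

    # Transcribe a local file and print the resulting segments.
    result = pipeline.transcribe("audio.wav", batch_size=8, print_progress=True)
    print(f"Language: {result['language']}")
    for seg in result["segments"]:
        print(f"[{seg['start']:7.2f} -> {seg['end']:7.2f}] {seg['text']}")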