diff --git a/README.md b/README.md
index 2bffa43..b5b95c5 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ whisperx-arch

-Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment.
+Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.

@@ -52,6 +52,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig

New🚨

+- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. Adding the `--vad_filter True` flag increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2).
@@ -60,7 +61,25 @@ This repository refines the timestamps of openAI's Whisper model via forced alig

Setup ⚙️

-Install this package using
+Tested for PyTorch 1.11, Python 3.8 (use other versions at your own risk!)
+
+GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
+
+### 1. Create Python 3.8 environment
+
+`conda create --name whisperx python=3.8`
+
+`conda activate whisperx`
+
+### 2. Install PyTorch 1.11.0, e.g. for Linux and Windows:
+
+`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch`
+
+See other methods [here](https://pytorch.org/get-started/previous-versions/#wheel-4).
+
+### 3. Install this repo

`pip install git+https://github.com/m-bain/whisperx.git`

@@ -78,13 +97,6 @@ $ pip install -e .

You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

-### Setup not working???
-Safest to use install pytorch as follows (for gpu)
-
-`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -c pytorch`
-
-
### Speaker Diarization
To **enable Speaker Diarization**, include your Hugging Face access token (generate one [here](https://huggingface.co/settings/tokens)) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)

@@ -130,14 +142,15 @@ See more examples in other languages [here](EXAMPLES.md).

```python
import whisperx
-import whisper

device = "cuda"
audio_file = "audio.mp3"

# transcribe with original whisper
-model = whisper.load_model("large", device)
-result = model.transcribe(audio_file)
+model = whisperx.load_model("large-v2", device)
+
+audio = whisperx.load_audio(audio_file)
+result = model.transcribe(audio, batch_size=8)

print(result["segments"]) # before alignment

@@ -145,7 +158,7 @@ print(result["segments"]) # before alignment
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# align whisper output
-result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
+result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)

print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment

@@ -186,9 +199,15 @@ The next major upgrade we are working on is whisper with speaker diarization, so
* [x] Incorporating speaker diarization

-* [ ] Automatic .wav conversion to make VAD compatible
+* [x] Model flush, for low gpu mem resources

-* [ ] Model flush, for low gpu mem resources
+* [x] Faster-whisper backend
+
+* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
+
+* [ ] Allow silero-vad as alternative VAD option
+
+* [ ] Add max-line-width / max-line-count options (see openai's whisper utils.py)

* [ ] Improve diarization (word level). *Harder than first thought...*

@@ -205,10 +224,13 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A

This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.

-
Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+
+Valuable VAD & Diarization Models from [pyannote.audio](https://github.com/pyannote/pyannote-audio)
+
+Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
+
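As of v3, each aligned segment also carries a `words` list with per-word `start` and `end` times. A minimal usage sketch, continuing the Python example above (the file name and batch size are just placeholders), that prints word-level timings:

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.mp3")  # placeholder path

# batched transcription (faster-whisper backend)
model = whisperx.load_model("large-v2", device)
result = model.transcribe(audio, batch_size=8)

# forced phoneme alignment
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)

# each aligned segment carries a "words" list with per-word timings
for segment in aligned["segments"]:
    for word in segment["words"]:
        print(f"[{word['start']:.2f} --> {word['end']:.2f}] {word['word']}")
```

The flat `word_segments` list returned alongside `segments` by `whisperx.align` carries the same timings without the segment grouping.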

Citation

If you use this in your research, please cite the paper: @@ -220,37 +242,4 @@ If you use this in your research, please cite the paper: journal={arXiv preprint, arXiv:2303.00747}, year={2023} } -``` - -as well the following works, used in each stage of the pipeline: - -```bibtex -@article{radford2022robust, - title={Robust speech recognition via large-scale weak supervision}, - author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya}, - journal={arXiv preprint arXiv:2212.04356}, - year={2022} -} -``` - -```bibtex -@article{baevski2020wav2vec, - title={wav2vec 2.0: A framework for self-supervised learning of speech representations}, - author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael}, - journal={Advances in neural information processing systems}, - volume={33}, - pages={12449--12460}, - year={2020} -} -``` - -```bibtex -@inproceedings{bredin2020pyannote, - title={Pyannote. audio: neural building blocks for speaker diarization}, - author={Bredin, Herv{\'e} and Yin, Ruiqing and Coria, Juan Manuel and Gelly, Gregory and Korshunov, Pavel and Lavechin, Marvin and Fustes, Diego and Titeux, Hadrien and Bouaziz, Wassim and Gill, Marie-Philippe}, - booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, - pages={7124--7128}, - year={2020}, - organization={IEEE} -} -``` +``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 139ee56..2747e2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ -numpy -pandas -torch >=1.9 -torchaudio >=0.10,<1.0 -tqdm -more-itertools -transformers>=4.19.0 -ffmpeg-python==0.2.0 +torch==1.11.0 +torchaudio==0.11.0 pyannote.audio -openai-whisper==20230314 +faster-whisper +transformers +ffmpeg-python==0.2.0 +pandas +setuptools==65.6.3 \ No newline at end of file diff --git a/setup.py b/setup.py index d6472e1..2060d42 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ from setuptools import setup, find_packages setup( name="whisperx", py_modules=["whisperx"], - version="2.0", + version="3.0.0", description="Time-Accurate Automatic Speech Recognition using Whisper.", readme="README.md", python_requires=">=3.8", diff --git a/whisperx/__init__.py b/whisperx/__init__.py index 985ed32..d0294b9 100644 --- a/whisperx/__init__.py +++ b/whisperx/__init__.py @@ -1,3 +1,3 @@ -from .transcribe import transcribe, transcribe_with_vad +from .transcribe import load_model from .alignment import load_align_model, align -from .vad import load_vad_model \ No newline at end of file +from .audio import load_audio \ No newline at end of file diff --git a/whisperx/alignment.py b/whisperx/alignment.py index c15310b..09f044f 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,16 +2,17 @@ Forced Alignment with Whisper C. 
Max Bain """ +from dataclasses import dataclass +from typing import Iterator, Union + import numpy as np import pandas as pd -from typing import List, Union, Iterator, TYPE_CHECKING -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor -import torchaudio import torch -from dataclasses import dataclass -from whisper.audio import SAMPLE_RATE, load_audio -from .utils import interpolate_nans +import torchaudio +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor +from .audio import SAMPLE_RATE, load_audio +from .utils import interpolate_nans LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -391,34 +392,42 @@ def align( if 'level_1' in cseg: del cseg['level_1'] if 'level_0' in cseg: del cseg['level_0'] cseg.reset_index(inplace=True) - aligned_segments.append( - { - "start": srow["start"], - "end": srow["end"], - "text": text, - "word-segments": wseg, - "char-segments": cseg - } - ) def get_raw_text(word_row): return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1] + word_list = [] wdx = 0 curr_text = get_raw_text(wseg.iloc[wdx]) + if not curr_text.startswith(" "): + curr_text = " " + curr_text + if len(wseg) > 1: for _, wrow in wseg.iloc[1:].iterrows(): if wrow['start'] != wseg.iloc[wdx]['start']: + word_start = wseg.iloc[wdx]['start'] + word_end = wseg.iloc[wdx]['end'] + aligned_segments_word.append( { "text": curr_text.strip(), - "start": wseg.iloc[wdx]["start"], - "end": wseg.iloc[wdx]["end"], + "start": word_start, + "end": word_end } ) - curr_text = "" - curr_text += " " + get_raw_text(wrow) + + word_list.append( + { + "word": curr_text.rstrip(), + "start": word_start, + "end": word_end, + } + ) + + curr_text = " " + curr_text += get_raw_text(wrow) + " " wdx += 1 + aligned_segments_word.append( { "text": curr_text.strip(), @@ -427,6 +436,25 @@ def align( } ) + word_list.append( + { + "word": curr_text.rstrip(), + "start": word_start, + "end": word_end, + } + ) + + aligned_segments.append( + { + "start": srow["start"], + "end": srow["end"], + "text": text, + "words": word_list, + # "word-segments": wseg, + # "char-segments": cseg + } + ) + return {"segments": aligned_segments, "word_segments": aligned_segments_word} diff --git a/whisperx/asr.py b/whisperx/asr.py index e78d77c..d9cbff0 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -1,433 +1,406 @@ +import os import warnings -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import List, Union + +import ctranslate2 +import faster_whisper import numpy as np import torch -import tqdm -import ffmpeg -from whisper.audio import ( - FRAMES_PER_SECOND, - HOP_LENGTH, - N_FRAMES, - N_SAMPLES, - SAMPLE_RATE, - CHUNK_LENGTH, - log_mel_spectrogram, - pad_or_trim, - load_audio -) -from whisper.decoding import DecodingOptions, DecodingResult -from whisper.timing import add_word_timestamps -from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from whisper.utils import ( - exact_div, - format_timestamp, - make_safe, -) +from transformers import Pipeline +from transformers.pipelines.pt_utils import PipelineIterator -if TYPE_CHECKING: - from whisper.model import Whisper +from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram +from .vad import load_vad_model, merge_chunks -from .vad import merge_chunks -def transcribe( - model: "Whisper", - audio: Union[str, np.ndarray, torch.Tensor] = None, - mel: np.ndarray = None, - verbose: Optional[bool] = None, - temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), - 
compression_ratio_threshold: Optional[float] = 2.4, - logprob_threshold: Optional[float] = -1.0, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - **decode_options, -): - """ - Transcribe an audio file using Whisper. - We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio) +def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, + vad_options=None, model=None): + '''Load a Whisper model for inference. + Args: + whisper_arch: str - The name of the Whisper model to load. + device: str - The device to load the model on. + compute_type: str - The compute type to use for the model. + options: dict - A dictionary of options to use for the model. + language: str - The language of the model. (use English for now) + Returns: + A Whisper pipeline. + ''' - Parameters - ---------- - model: Whisper - The Whisper model instance + if whisper_arch.endswith(".en"): + language = "en" - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform - - mel: np.ndarray - Mel spectrogram of audio segment. - - verbose: bool - Whether to display the text being decoded to the console. If True, displays all the details, - If False, displays minimal details. If None, does not display anything - - temperature: Union[float, Tuple[float, ...]] - Temperature for sampling. It can be a tuple of temperatures, which will be successively used - upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. - - compression_ratio_threshold: float - If the gzip compression ratio is above this value, treat as failed - - logprob_threshold: float - If the average log probability over sampled tokens is below this value, treat as failed - - no_speech_threshold: float - If the no_speech probability is higher than this value AND the average log probability - over sampled tokens is below `logprob_threshold`, consider the segment as silent - - condition_on_previous_text: bool - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - word_timestamps: bool - Extract word-level timestamps using the cross-attention pattern and dynamic time warping, - and include the timestamps for each word in each segment. - - prepend_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the next word - - append_punctuations: str - If word_timestamps is True, merge these punctuation symbols with the previous word - - initial_prompt: Optional[str] - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. - - decode_options: dict - Keyword arguments to construct `DecodingOptions` instances - - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. 
- """ - dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 - if model.device == torch.device("cpu"): - if torch.cuda.is_available(): - warnings.warn("Performing inference on CPU when CUDA is available") - if dtype == torch.float16: - warnings.warn("FP16 is not supported on CPU; using FP32 instead") - dtype = torch.float32 - - if dtype == torch.float32: - decode_options["fp16"] = False - - # Pad 30-seconds of silence to the input audio, for slicing - if mel is None: - if audio is None: - raise ValueError("Transcribe needs either audio or mel as input, currently both are none.") - mel = log_mel_spectrogram(audio, padding=N_SAMPLES) - content_frames = mel.shape[-1] - N_FRAMES - - if decode_options.get("language", None) is None: - if not model.is_multilingual: - decode_options["language"] = "en" - else: - if verbose: - print( - "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" - ) - mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) - _, probs = model.detect_language(mel_segment) - decode_options["language"] = max(probs, key=probs.get) - if verbose is not None: - print( - f"Detected language: {LANGUAGES[decode_options['language']].title()}" - ) - - language: str = decode_options["language"] - task: str = decode_options.get("task", "transcribe") - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) - - if word_timestamps and task == "translate": - warnings.warn("Word-level timestamps on translations may not be reliable.") - - def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: - temperatures = ( - [temperature] if isinstance(temperature, (int, float)) else temperature - ) - decode_result = None - - for t in temperatures: - kwargs = {**decode_options} - if t > 0: - # disable beam_size and patience when t > 0 - kwargs.pop("beam_size", None) - kwargs.pop("patience", None) - else: - # disable best_of when t == 0 - kwargs.pop("best_of", None) - - options = DecodingOptions(**kwargs, temperature=t) - decode_result = model.decode(segment, options) - - needs_fallback = False - if ( - compression_ratio_threshold is not None - and decode_result.compression_ratio > compression_ratio_threshold - ): - needs_fallback = True # too repetitive - if ( - logprob_threshold is not None - and decode_result.avg_logprob < logprob_threshold - ): - needs_fallback = True # average log probability is too low - - if not needs_fallback: - break - - return decode_result - - seek = 0 - input_stride = exact_div( - N_FRAMES, model.dims.n_audio_ctx - ) # mel frames per output token: 2 - time_precision = ( - input_stride * HOP_LENGTH / SAMPLE_RATE - ) # time per output token: 0.02 (seconds) - all_tokens = [] - all_segments = [] - prompt_reset_since = 0 - - if initial_prompt is not None: - initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) - all_tokens.extend(initial_prompt_tokens) + model = WhisperModel(whisper_arch, device=device, compute_type=compute_type) + if language is not None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language) else: - initial_prompt_tokens = [] + print("No language specified, language will be first be detected for each audio file (increases inference time).") + tokenizer = None - def new_segment( - *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult + default_asr_options = { + "beam_size": 5, + "best_of": 5, + "patience": 1, + "length_penalty": 1, + 
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": False, + "initial_prompt": None, + "prefix": None, + "suppress_blank": True, + "suppress_tokens": [-1], + "without_timestamps": True, + "max_initial_timestamp": 0.0, + "word_timestamps": False, + "prepend_punctuations": "\"'“¿([{-", + "append_punctuations": "\"'.。,,!!??::”)]}、" + } + + if asr_options is not None: + default_asr_options.update(asr_options) + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) + + default_vad_options = { + "vad_onset": 0.500, + "vad_offset": 0.363 + } + + if vad_options is not None: + default_vad_options.update(vad_options) + + vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) + + return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer) + + + +class WhisperModel(faster_whisper.WhisperModel): + ''' + FasterWhisperModel provides batched inference for faster-whisper. + Currently only works in non-timestamp mode. + ''' + + def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None): + batch_size = features.shape[0] + all_tokens = [] + prompt_reset_since = 0 + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix, + ) + + encoder_output = self.encode(features) + + max_initial_timestamp_index = int( + round(options.max_initial_timestamp / self.time_precision) + ) + + result = self.model.generate( + encoder_output, + [prompt] * batch_size, + # length_penalty=options.length_penalty, + # max_length=self.max_length, + # return_scores=True, + # return_no_speech_prob=True, + # suppress_blank=options.suppress_blank, + # suppress_tokens=options.suppress_tokens, + # max_initial_timestamp_index=max_initial_timestamp_index, + ) + + tokens_batch = [x.sequences_ids[0] for x in result] + + def decode_batch(tokens: List[List[int]]) -> str: + res = [] + for tk in tokens: + res.append([token for token in tk if token < tokenizer.eot]) + # text_tokens = [token for token in tokens if token < self.eot] + return tokenizer.tokenizer.decode_batch(res) + + text = decode_batch(tokens_batch) + + return text + + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + # When the model is running on multiple GPUs, the encoder output should be moved + # to the CPU since we don't know which GPU will handle the next job. 
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + # unsqueeze if batch size = 1 + if len(features.shape) == 2: + features = np.expand_dims(features, 0) + features = faster_whisper.transcribe.get_ctranslate2_storage(features) + + return self.model.encode(features, to_cpu=to_cpu) + +class FasterWhisperPipeline(Pipeline): + def __init__( + self, + model, + vad, + options, + tokenizer=None, + device: Union[int, str, "torch.device"] = -1, + framework = "pt", + **kwargs ): - tokens = tokens.tolist() - text_tokens = [token for token in tokens if token < tokenizer.eot] - return { - "seek": seek, - "start": start, - "end": end, - "text": tokenizer.decode(text_tokens), - "tokens": tokens, - "temperature": result.temperature, - "avg_logprob": result.avg_logprob, - "compression_ratio": result.compression_ratio, - "no_speech_prob": result.no_speech_prob, - } - - - # show the progress bar when verbose is False (if True, transcribed text will be printed) - with tqdm.tqdm( - total=content_frames, unit="frames", disable=verbose is not False - ) as pbar: - while seek < content_frames: - time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - mel_segment = mel[:, seek : seek + N_FRAMES] - segment_size = min(N_FRAMES, content_frames - seek) - segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE - mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(mel_segment) - tokens = torch.tensor(result.tokens) - if no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > no_speech_threshold - if ( - logprob_threshold is not None - and result.avg_logprob > logprob_threshold - ): - # don't skip if the logprob is high enough, despite the no_speech_prob - should_skip = False - - if should_skip: - seek += segment_size # fast-forward to the next segment boundary - continue - - previous_seek = seek - current_segments = [] - - timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] - - consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - consecutive.add_(1) - if len(consecutive) > 0: - # if the output contains two consecutive timestamp tokens - slices = consecutive.tolist() - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = ( - sliced_tokens[0].item() - tokenizer.timestamp_begin - ) - end_timestamp_pos = ( - sliced_tokens[-1].item() - tokenizer.timestamp_begin - ) - - # clamp end-time to at least be 1 frame after start-time - end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision) - - current_segments.append( - new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, - result=result, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. 
- seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin - ) - seek += last_timestamp_pos * input_stride + self.model = model + self.tokenizer = tokenizer + self.options = options + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = 1 + self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) + self.call_count = 0 + self.framework = framework + if self.framework == "pt": + if isinstance(device, torch.device): + self.device = device + elif isinstance(device, str): + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") else: - duration = segment_duration - timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if ( - len(timestamps) > 0 - and timestamps[-1].item() != tokenizer.timestamp_begin - ): - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = ( - timestamps[-1].item() - tokenizer.timestamp_begin - ) - duration = last_timestamp_pos * time_precision + self.device = torch.device(f"cuda:{device}") + else: + self.device = device + + super(Pipeline, self).__init__() + self.vad_model = vad - current_segments.append( - new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - result=result, - ) - ) - seek += segment_size + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + if "tokenizer" in kwargs: + preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] + return preprocess_kwargs, {}, {} - if not condition_on_previous_text or result.temperature > 0.5: - # do not feed the prompt tokens if a high temperature was used - prompt_reset_since = len(all_tokens) + def preprocess(self, audio): + audio = audio['inputs'] + features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0]) + return {'inputs': features} - if word_timestamps: - add_word_timestamps( - segments=current_segments, - model=model, - tokenizer=tokenizer, - mel=mel_segment, - num_frames=segment_size, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - ) - word_end_timestamps = [ - w["end"] for s in current_segments for w in s["words"] - ] - if not single_timestamp_ending and len(word_end_timestamps) > 0: - seek_shift = round( - (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND - ) - if seek_shift > 0: - seek = previous_seek + seek_shift + def _forward(self, model_inputs): + outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) + return {'text': outputs} + + def postprocess(self, model_outputs): + return model_outputs - if verbose: - for segment in current_segments: - start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" - print(make_safe(line)) + def get_iterator( + self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params + ): + dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) + if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + # TODO hack by collating feature_extractor and image_processor - # if a segment is instantaneous or does not contain text, clear it - for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or segment["text"].strip() == "": - segment["text"] = "" - 
segment["tokens"] = [] - segment["words"] = [] + def stack(items): + return {'inputs': torch.stack([x['inputs'] for x in items])} + dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack) + model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) + final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) + return final_iterator - all_segments.extend( - [ - {"id": i, **segment} - for i, segment in enumerate( - current_segments, start=len(all_segments) - ) - ] - ) - all_tokens.extend( - [token for segment in current_segments for token in segment["tokens"]] - ) + def transcribe( + self, audio: Union[str, np.ndarray], batch_size=None + ): + if isinstance(audio, str): + audio = load_audio(audio) + + def data(audio, segments): + for seg in segments: + f1 = int(seg['start'] * SAMPLE_RATE) + f2 = int(seg['end'] * SAMPLE_RATE) + # print(f2-f1) + yield {'inputs': audio[f1:f2]} - # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) + vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks(vad_segments, 30) + del_tokenizer = False + if self.tokenizer is None: + language = self.detect_language(audio) + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language) + del_tokenizer = True + else: + language = self.tokenizer.language_code - return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), - segments=all_segments, - language=language, - ) - - -def transcribe_with_vad( - model: "Whisper", - audio: str, - vad_pipeline, - mel = None, - verbose: Optional[bool] = None, - **kwargs -): - """ - Transcribe per VAD segment - """ - - vad_segments = vad_pipeline(audio) - - # if not torch.is_tensor(audio): - # if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) - - prev = 0 - output = {"segments": []} - - # merge segments to approx 30s inputs to make whisper most appropraite - vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH) - if len(vad_segments) == 0: - return output - - print(">>Performing transcription...") - for sdx, seg_t in enumerate(vad_segments): - if verbose: - print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~") - seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE) - local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev - audio = audio[local_f_start:] # seek forward - seg_audio = audio[:local_f_end-local_f_start] # seek forward - prev = seg_f_start - local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES) - # need to pad - - result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs) - seg_t["text"] = result["text"] - output["segments"].append( - { - "start": seg_t["start"], - "end": seg_t["end"], - "language": result["language"], - "text": result["text"], - "seg-text": [x["text"] for x in result["segments"]], - "seg-start": [x["start"] for x in result["segments"]], - "seg-end": [x["end"] for x in result["segments"]], + segments = [] + batch_size = batch_size or self._batch_size + for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)): + text = out['text'] + if batch_size in [0, 1, None]: + text = text[0] + segments.append( + 
{ + "text": out['text'], + "start": round(vad_segments[idx]['start'], 3), + "end": round(vad_segments[idx]['end'], 3) } ) + + if del_tokenizer: + self.tokenizer = None - output["language"] = output["segments"][0]["language"] + return {"segments": segments, "language": language} - return output + + def detect_language(self, audio: np.ndarray): + segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0) + encoder_output = self.model.encode(segment) + results = self.model.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") + return language + +if __name__ == "__main__": + main_type = "simple" + import time + + import jiwer + from tqdm import tqdm + from whisper.normalizers import EnglishTextNormalizer + + from benchmark.tedlium import parse_tedlium_annos + + if main_type == "complex": + from faster_whisper.tokenizer import Tokenizer + from faster_whisper.transcribe import TranscriptionOptions + from faster_whisper.vad import (SpeechTimestampsMap, + get_speech_timestamps) + + from whisperx.vad import load_vad_model, merge_chunks + + from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram + faster_t_options = TranscriptionOptions( + beam_size=5, + best_of=5, + patience=1, + length_penalty=1, + temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + compression_ratio_threshold=2.4, + log_prob_threshold=-1.0, + no_speech_threshold=0.6, + condition_on_previous_text=False, + initial_prompt=None, + prefix=None, + suppress_blank=True, + suppress_tokens=[-1], + without_timestamps=True, + max_initial_timestamp=0.0, + word_timestamps=False, + prepend_punctuations="\"'“¿([{-", + append_punctuations="\"'.。,,!!??::”)]}、" + ) + whisper_arch = "large-v2" + device = "cuda" + batch_size = 16 + model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",) + tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en") + model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1) + fn = "DanielKahneman_2010.wav" + wav_dir = f"/tmp/test/wav/" + vad_model = load_vad_model("cuda", 0.6, 0.3) + audio = load_audio(os.path.join(wav_dir, fn)) + vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks(vad_segments, 30) + + def data(audio, segments): + for seg in segments: + f1 = int(seg['start'] * SAMPLE_RATE) + f2 = int(seg['end'] * SAMPLE_RATE) + # print(f2-f1) + yield {'inputs': audio[f1:f2]} + vad_method="pyannote" + + wav_dir = f"/tmp/test/wav/" + wer_li = [] + time_li = [] + for fn in os.listdir(wav_dir): + if fn == "RobertGupta_2010U.wav": + continue + base_fn = fn.split('.')[0] + audio_fp = os.path.join(wav_dir, fn) + + audio = load_audio(audio_fp) + t1 = time.time() + if vad_method == "pyannote": + vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks(vad_segments, 30) + elif vad_method == "silero": + vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30) + vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments] + new_segs = [] + curr_start = vad_segments[0]['start'] + curr_end = vad_segments[0]['end'] + for seg in vad_segments[1:]: + if seg['end'] - curr_start > 30: + new_segs.append({"start": curr_start, "end": curr_end}) + curr_start 
= seg['start'] + curr_end = seg['end'] + else: + curr_end = seg['end'] + new_segs.append({"start": curr_start, "end": curr_end}) + vad_segments = new_segs + text = [] + # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)): + for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)): + text.append(out['text']) + t2 = time.time() + if batch_size == 1: + text = [x[0] for x in text] + text = " ".join(text) + + normalizer = EnglishTextNormalizer() + text = normalizer(text) + gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/")) + + wer_result = jiwer.wer(gt_corpus, text) + print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn)) + + wer_li.append(wer_result) + time_li.append(t2-t1) + print("# Avg Mean...") + print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li))) + print("Time: %.2f" % (sum(time_li)/len(time_li))) + elif main_type == "simple": + model = load_model( + "large-v2", + device="cuda", + language="en", + ) + + wav_dir = f"/tmp/test/wav/" + wer_li = [] + time_li = [] + for fn in os.listdir(wav_dir): + if fn == "RobertGupta_2010U.wav": + continue + # fn = "DanielKahneman_2010.wav" + base_fn = fn.split('.')[0] + audio_fp = os.path.join(wav_dir, fn) + + audio = load_audio(audio_fp) + t1 = time.time() + out = model.transcribe(audio_fp, batch_size=8)["segments"] + t2 = time.time() + + text = " ".join([x['text'] for x in out]) + normalizer = EnglishTextNormalizer() + text = normalizer(text) + gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/")) + + wer_result = jiwer.wer(gt_corpus, text) + print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn)) + + wer_li.append(wer_result) + time_li.append(t2-t1) + print("# Avg Mean...") + print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li))) + print("Time: %.2f" % (sum(time_li)/len(time_li))) diff --git a/whisperx/assets/mel_filters.npz b/whisperx/assets/mel_filters.npz new file mode 100644 index 0000000..1a78392 Binary files /dev/null and b/whisperx/assets/mel_filters.npz differ diff --git a/whisperx/audio.py b/whisperx/audio.py new file mode 100644 index 0000000..513ab7c --- /dev/null +++ b/whisperx/audio.py @@ -0,0 +1,147 @@ +import os +from functools import lru_cache +from typing import Optional, Union + +import ffmpeg +import numpy as np +import torch +import torch.nn.functional as F + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 
+ out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int = N_MELS, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index ed918e0..c3719ba 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -1,37 +1,30 @@ import argparse -import os import gc +import os import warnings -from typing import TYPE_CHECKING, Optional, Tuple, Union + import numpy as np import torch -import tempfile -import ffmpeg -from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE 
-from whisper.audio import SAMPLE_RATE -from whisper.utils import ( - optional_float, - optional_int, - str2bool, -) -from .alignment import load_align_model, align -from .asr import transcribe, transcribe_with_vad +from .alignment import align, load_align_model +from .asr import load_model +from .audio import load_audio from .diarize import DiarizationPipeline, assign_word_speakers -from .utils import get_writer -from .vad import load_vad_model +from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float, + optional_int, str2bool) + def cli(): - from whisper import available_models - # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") - parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") + parser.add_argument("--model", default="small", help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") + parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") @@ -39,13 +32,10 @@ def cli(): # alignment params parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") - parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).") - parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)") parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment") # vad params - parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! 
Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747") parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") @@ -69,9 +59,14 @@ def cli(): parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") - parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") - parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") - parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") + + parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") + parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment") + parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") + + # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them") + # parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") + # parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") @@ -81,7 +76,7 @@ def cli(): args = parser.parse_args().__dict__ model_name: str = args.pop("model") - model_dir: str = args.pop("model_dir") + batch_size: int = args.pop("batch_size") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") device: str = args.pop("device") @@ -93,13 +88,10 @@ def cli(): os.makedirs(tmp_dir, exist_ok=True) align_model: str = args.pop("align_model") - align_extend: float = args.pop("align_extend") - align_from_prev: bool = args.pop("align_from_prev") interpolate_method: str = args.pop("interpolate_method") no_align: bool = args.pop("no_align") hf_token: str = args.pop("hf_token") - vad_filter: bool = 
args.pop("vad_filter") vad_onset: float = args.pop("vad_onset") vad_offset: float = args.pop("vad_offset") @@ -107,18 +99,7 @@ def cli(): min_speakers: int = args.pop("min_speakers") max_speakers: int = args.pop("max_speakers") - if vad_filter: - from pyannote.audio import Pipeline - from pyannote.audio import Model, Pipeline - vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token) - else: - vad_model = None - - # if model_flush: - # print(">>Model flushing activated... Only loading model after ASR stage") - # del align_model - # align_model = "" - + # TODO: check model loading works. if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if args["language"] is not None: @@ -136,39 +117,43 @@ def cli(): if (threads := args.pop("threads")) > 0: torch.set_num_threads(threads) - from whisper import load_model + asr_options = { + "beam_size": args.pop("beam_size"), + "patience": args.pop("patience"), + "length_penalty": args.pop("length_penalty"), + "temperatures": temperature, + "compression_ratio_threshold": args.pop("compression_ratio_threshold"), + "log_prob_threshold": args.pop("logprob_threshold"), + "no_speech_threshold": args.pop("no_speech_threshold"), + "condition_on_previous_text": False, + "initial_prompt": args.pop("initial_prompt"), + } writer = get_writer(output_format, output_dir) - + word_options = ["highlight_words", "max_line_count", "max_line_width"] + if no_align: + for option in word_options: + if args[option]: + parser.error(f"--{option} requires --word_timestamps True") + if args["max_line_count"] and not args["max_line_width"]: + warnings.warn("--max_line_count has no effect without --max_line_width") + writer_args = {arg: args.pop(arg) for arg in word_options} + # Part 1: VAD & ASR Loop results = [] tmp_results = [] - model = load_model(model_name, device=device, download_root=model_dir) - for audio_path in args.pop("audio"): - input_audio_path = audio_path - tfile = None + # model = load_model(model_name, device=device, download_root=model_dir) + model = load_model(model_name, device=device, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},) + for audio_path in args.pop("audio"): + audio = load_audio(audio_path) # >> VAD & ASR - if vad_model is not None: - if not audio_path.endswith(".wav"): - print(">>VAD requires .wav format, converting to wav as a tempfile...") - audio_basename = os.path.splitext(os.path.basename(audio_path))[0] - if tmp_dir is not None: - input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav") - else: - input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav") - ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"]) - print(">>Performing VAD...") - result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args) - else: - print(">>Performing transcription...") - result = transcribe(model, input_audio_path, temperature=temperature, **args) - - results.append((result, input_audio_path)) + print(">>Performing transcription...") + result = model.transcribe(audio, batch_size=batch_size) + results.append((result, audio_path)) # Unload Whisper and VAD del model - del vad_model gc.collect() torch.cuda.empty_cache() @@ -178,17 +163,22 @@ def cli(): results = [] align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified align_model, align_metadata = 
load_align_model(align_language, device, model_name=align_model) - for result, input_audio_path in tmp_results: + for result, audio_path in tmp_results: # >> Align + if len(tmp_results) > 1: + input_audio = audio_path + else: + # lazily load audio from part 1 + input_audio = audio + if align_model is not None and len(result["segments"]) > 0: if result.get("language", "en") != align_metadata["language"]: # load new language print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...") align_model, align_metadata = load_align_model(result["language"], device) print(">>Performing alignment...") - result = align(result["segments"], align_model, align_metadata, input_audio_path, device, - extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method) - results.append((result, input_audio_path)) + result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method) + results.append((result, audio_path)) # Unload align model del align_model @@ -210,11 +200,7 @@ def cli(): # >> Write for result, audio_path in results: - writer(result, audio_path) - - # cleanup - if input_audio_path != audio_path: - os.remove(input_audio_path) + writer(result, audio_path, writer_args) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/whisperx/utils.py b/whisperx/utils.py index 14e298b..3401a84 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -1,280 +1,301 @@ +import json import os +import re +import sys import zlib -from typing import Callable, TextIO, Iterator, Tuple -import pandas as pd -import numpy as np +from typing import Callable, Optional, TextIO -def interpolate_nans(x, method='nearest'): - if x.notnull().sum() > 1: - return x.interpolate(method=method).ffill().bfill() +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + 
"lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). + return string.encode(system_encoding, errors="replace").decode(system_encoding) + +else: + + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] else: - return x.ffill().bfill() - - -def write_txt(transcript: Iterator[dict], file: TextIO): - for segment in transcript: - print(segment['text'].strip(), file=file, flush=True) + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") -def write_vtt(transcript: Iterator[dict], file: TextIO): - print("WEBVTT\n", file=file) - for segment in transcript: - print( - f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__(self, result: dict, audio_path: str, options: dict): + audio_basename = os.path.basename(audio_path) + audio_basename = os.path.splitext(audio_basename)[0] + output_path = os.path.join( + self.output_dir, audio_basename + "." 
+ self.extension ) -def write_tsv(transcript: Iterator[dict], file: TextIO): - print("start", "end", "text", sep="\t", file=file) - for segment in transcript: - print(segment['start'], file=file, end="\t") - print(segment['end'], file=file, end="\t") - print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f, options=options) + + def write_result(self, result: dict, file: TextIO, options: dict): + raise NotImplementedError -def write_srt(transcript: Iterator[dict], file: TextIO): - """ - Write a transcript to a file in SRT format. +class WriteTXT(ResultWriter): + extension: str = "txt" - Example usage: - from pathlib import Path - from whisper.utils import write_srt - - result = transcribe(model, audio_path, temperature=temperature, **args) - - # save SRT - audio_basename = Path(audio_path).stem - with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: - write_srt(result["segments"], file=srt) - """ - for i, segment in enumerate(transcript, start=1): - # write srt lines - print( - f"{i}\n" - f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip().replace('-->', '->')}\n", - file=file, - flush=True, - ) + def write_result(self, result: dict, file: TextIO, options: dict): + for segment in result["segments"]: + print(segment["text"].strip(), file=file, flush=True) -def write_ass(transcript: Iterator[dict], - file: TextIO, - resolution: str = "word", - color: str = None, underline=True, - prefmt: str = None, suffmt: str = None, - font: str = None, font_size: int = 24, - strip=True, **kwargs): - """ - Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py - Generate Advanced SubStation Alpha (ass) file from results to - display both phrase-level & word-level timestamp simultaneously by: - -using segment-level timestamps display phrases as usual - -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment - Note: ass file is used in the same way as srt, vtt, etc. - Parameters - ---------- - transcript: dict - results from modified model - file: TextIO - file object to write to - resolution: str - "word" or "char", timestamp resolution to highlight. - color: str - color code for a word at its corresponding timestamp - reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00) - underline: bool - whether to underline a word at its corresponding timestamp - prefmt: str - used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline') - appears as such in the .ass file: - Hi, {}how{} are you? - reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm - suffmt: str - used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline') - appears as such in the .ass file: - Hi, {}how{} are you? 
- reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm - font: str - word font (default: Arial) - font_size: int - word font size (default: 48) - kwargs: - used for format styles: - 'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold', - 'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline', - 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding' +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str - """ + def iterate_result(self, result: dict, options: dict): + raw_max_line_width: Optional[int] = options["max_line_width"] + max_line_count: Optional[int] = options["max_line_count"] + highlight_words: bool = options["highlight_words"] + max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width + preserve_segments = max_line_count is None or raw_max_line_width is None - fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff', - 'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0', - 'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100', - 'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0', - 'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'} + def iterate_subtitles(): + line_len = 0 + line_count = 1 + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: list[dict] = [] + last = result["segments"][0]["words"][0]["start"] + for segment in result["segments"]: + for i, original_timing in enumerate(segment["words"]): + timing = original_timing.copy() + long_pause = not preserve_segments and timing["start"] - last > 3.0 + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if line_len > 0 and has_room and not long_pause and not seg_break: + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + if len(subtitle) > 0: + yield subtitle - for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()): - kwargs[k] = f'&H{kwargs[k]}' + if "words" in result["segments"][0]: + for subtitle in iterate_subtitles(): + subtitle_start = self.format_timestamp(subtitle[0]["start"]) + subtitle_end = self.format_timestamp(subtitle[-1]["end"]) + subtitle_text = "".join([word["word"] for word in subtitle]) + if highlight_words: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text - fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict) - - if font: - fmt_style_dict.update(Fontname=font) - if font_size: - fmt_style_dict.update(Fontsize=font_size) - - fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}' - - 
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}' - - ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \ - f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \ - f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n' - - if prefmt or suffmt: - if suffmt: - assert prefmt, 'prefmt must be used along with suffmt' + yield start, end, "".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end + else: + yield subtitle_start, subtitle_end, subtitle_text else: - suffmt = r'\r' - else: - if not color: - color = 'HFF00' - underline_code = r'\u1' if underline else '' - - prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}' - suffmt = r'{\r}' - - def secs_to_hhmmss(secs: Tuple[float, int]): - mm, ss = divmod(secs, 60) - hh, mm = divmod(mm, 60) - return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}' - - - def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str: - if idx_0 == -1: - text = chars - else: - text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}' - return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \ - f"Default,,0,0,0,,{text.strip() if strip else text}" - - if resolution == "word": - resolution_key = "word-segments" - elif resolution == "char": - resolution_key = "char-segments" - else: - raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution) - - ass_arr = [] - - for segment in transcript: - # if "12" in segment['text']: - # import pdb; pdb.set_trace() - if resolution_key in segment: - res_segs = pd.DataFrame(segment[resolution_key]) - prev = segment['start'] - if "speaker" in segment: - speaker_str = f"[{segment['speaker']}]: " - else: - speaker_str = "" - for cdx, crow in res_segs.iterrows(): - if not np.isnan(crow['start']): - if resolution == "char": - idx_0 = cdx - idx_1 = cdx + 1 - elif resolution == "word": - idx_0 = int(crow["segment-text-start"]) - idx_1 = int(crow["segment-text-end"]) - # fill gap - if crow['start'] > prev: - filler_ts = { - "chars": speaker_str + segment['text'], - "start": prev, - "end": crow['start'], - "idx_0": -1, - "idx_1": -1 - } - - ass_arr.append(filler_ts) - # highlight current word - f_word_ts = { - "chars": speaker_str + segment['text'], - "start": crow['start'], - "end": crow['end'], - "idx_0": idx_0 + len(speaker_str), - "idx_1": idx_1 + len(speaker_str) - } - ass_arr.append(f_word_ts) - prev = crow['end'] - - ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr)) - - file.write(ass_str) - - -from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp - -class WriteASS(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - write_ass(result["segments"], file, resolution="word") - -class WriteASSchar(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - write_ass(result["segments"], file, resolution="char") - -class WritePickle(ResultWriter): - extension: str = "ass" - - def write_result(self, result: dict, file: TextIO): - pd.DataFrame(result["segments"]).to_pickle(file) - -class WriteSRTWord(ResultWriter): - extension: str = "word.srt" - always_include_hours: bool = True - decimal_marker: str = "," - - def iterate_result(self, result: dict): - for segment in result["word_segments"]: - segment_start = 
self.format_timestamp(segment["start"]) - segment_end = self.format_timestamp(segment["end"]) - segment_text = segment["text"].strip().replace("-->", "->") - - if word_timings := segment.get("words", None): - all_words = [timing["word"] for timing in word_timings] - all_words[0] = all_words[0].strip() # remove the leading space, if any - last = segment_start - for i, this_word in enumerate(word_timings): - start = self.format_timestamp(this_word["start"]) - end = self.format_timestamp(this_word["end"]) - if last != start: - yield last, start, segment_text - - yield start, end, "".join( - [ - f"{word}" if j == i else word - for j, word in enumerate(all_words) - ] - ) - last = end - - if last != segment_end: - yield last, segment_end, segment_text - else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") yield segment_start, segment_end, segment_text - def write_result(self, result: dict, file: TextIO): - if "word_segments" not in result: - return - for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): - print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) - def format_timestamp(self, seconds: float): return format_timestamp( seconds=seconds, @@ -282,36 +303,81 @@ class WriteSRTWord(ResultWriter): decimal_marker=self.decimal_marker, ) -def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: + +class WriteVTT(SubtitlesWriter): + extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = "." + + def write_result(self, result: dict, file: TextIO, options: dict): + print("WEBVTT\n", file=file) + for start, end, text in self.iterate_result(result, options): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteSRT(SubtitlesWriter): + extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = "," + + def write_result(self, result: dict, file: TextIO, options: dict): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options), start=1 + ): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 
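+    (i.e. each line is: <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>)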
+ """ + + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO, options: dict): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO, options: dict): + json.dump(result, file) + + +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: writers = { "txt": WriteTXT, "vtt": WriteVTT, "srt": WriteSRT, "tsv": WriteTSV, - "ass": WriteASS, - "srt-word": WriteSRTWord, - # "ass-char": WriteASSchar, - # "pickle": WritePickle, - # "json": WriteJSON, - } - - writers_other = { - "pkl": WritePickle, - "ass-char": WriteASSchar + "json": WriteJSON, } if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO): + def write_all(result: dict, file: TextIO, options: dict): for writer in all_writers: - writer(result, file) + writer(result, file, options) return write_all - if output_format in writers: - return writers[output_format](output_dir) - elif output_format in writers_other: - return writers_other[output_format](output_dir) + return writers[output_format](output_dir) + +def interpolate_nans(x, method='nearest'): + if x.notnull().sum() > 1: + return x.interpolate(method=method).ffill().bfill() else: - raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}") + return x.ffill().bfill() \ No newline at end of file diff --git a/whisperx/vad.py b/whisperx/vad.py index 933d270..42b0bfb 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -1,22 +1,23 @@ +import hashlib import os import urllib -import pandas as pd +from typing import Callable, Optional, Text, Union + import numpy as np +import pandas as pd import torch -import hashlib -from tqdm import tqdm -from typing import Optional, Callable, Union, Text -from pyannote.audio.core.io import AudioFile -from pyannote.core import Annotation, Segment, SlidingWindowFeature -from pyannote.audio.pipelines.utils import PipelineModel from pyannote.audio import Model +from pyannote.audio.core.io import AudioFile from pyannote.audio.pipelines import VoiceActivityDetection +from pyannote.audio.pipelines.utils import PipelineModel +from pyannote.core import Annotation, Segment, SlidingWindowFeature +from tqdm import tqdm + from .diarize import Segment as SegmentX -from typing import List, Tuple, Optional VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" -def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None): +def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): model_dir = torch.hub._get_torch_home() os.makedirs(model_dir, exist_ok = True) if model_fp is None:
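+        # presumably: fall back to a cached copy of the VAD segmentation checkpoint under the torch hub model dir, fetched from VAD_SEGMENTATION_URL when missing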