From 558d98053541d23d8ccb0ce7c1842c65087c2846 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Mon, 24 Apr 2023 21:08:43 +0100
Subject: [PATCH 01/20] v3 init
---
README.md | 89 ++--
requirements.txt | 16 +-
setup.py | 2 +-
whisperx/__init__.py | 4 +-
whisperx/alignment.py | 66 ++-
whisperx/asr.py | 789 +++++++++++++++-----------------
whisperx/assets/mel_filters.npz | Bin 0 -> 2048 bytes
whisperx/audio.py | 147 ++++++
whisperx/transcribe.py | 132 +++---
whisperx/utils.py | 616 ++++++++++++++-----------
whisperx/vad.py | 19 +-
11 files changed, 1034 insertions(+), 846 deletions(-)
create mode 100644 whisperx/assets/mel_filters.npz
create mode 100644 whisperx/audio.py
diff --git a/README.md b/README.md
index 2bffa43..b5b95c5 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@
-Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment.
+
Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
@@ -52,6 +52,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
New 🚨
+- v3 released: 70x speed-up now open-sourced. Uses batched Whisper with the [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop 🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference, resulting in large-v2 with *60-70x real-time speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on Whisper timestamps and to transcribe only audio segments containing speech. Adding the `--vad_filter True` flag increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs to wav2vec2)
@@ -60,7 +61,25 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
Setup ⚙️
-Install this package using
+Tested for PyTorch 1.11, Python 3.8 (use other versions at your own risk!)
+
+GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
+
+
+### 1. Create Python 3.8 environment
+
+`conda create --name whisperx python=3.8`
+
+`conda activate whisperx`
+
+
+### 2. Install PyTorch 1.11.0, e.g. for Linux and Windows (CUDA 11.3):
+
+`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch`
+
+See other installation methods [here](https://pytorch.org/get-started/previous-versions/#wheel-4).
+
+### 3. Install this repo
`pip install git+https://github.com/m-bain/whisperx.git`
@@ -78,13 +97,6 @@ $ pip install -e .
You may also need to install ffmpeg, rust, etc. Follow the OpenAI instructions here: https://github.com/openai/whisper#setup.
-### Setup not working???
-Safest to use install pytorch as follows (for gpu)
-
-`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -c pytorch
-`
-
-
### Speaker Diarization
To **enable Speaker Diarization**, include your Hugging Face access token (generate one [here](https://huggingface.co/settings/tokens)) after the `--hf_token` argument, and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
@@ -130,14 +142,15 @@ See more examples in other languages [here](EXAMPLES.md).
```python
import whisperx
-import whisper
device = "cuda"
audio_file = "audio.mp3"
# transcribe with original whisper
-model = whisper.load_model("large", device)
-result = model.transcribe(audio_file)
+model = whisperx.load_model("large-v2", device)
+
+audio = whisperx.load_audio(audio_file)
+result = model.transcribe(audio, batch_size=8)
print(result["segments"]) # before alignment
@@ -145,7 +158,7 @@ print(result["segments"]) # before alignment
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
# align whisper output
-result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)
+result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)
print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
@@ -186,9 +199,15 @@ The next major upgrade we are working on is whisper with speaker diarization, so
* [x] Incorporating speaker diarization
-* [ ] Automatic .wav conversion to make VAD compatible
+* [x] Model flush, for low gpu mem resources
-* [ ] Model flush, for low gpu mem resources
+* [x] Faster-whisper backend
+
+* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
+
+* [ ] Allow silero-vad as alternative VAD option
+
+* [ ] Add max-line options etc. (see OpenAI's whisper utils.py)
* [ ] Improve diarization (word level). *Harder than first thought...*
@@ -205,10 +224,13 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
-
Of course, this builds on [OpenAI's whisper](https://github.com/openai/whisper).
It also borrows important alignment code from the [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html).
+Valuable VAD & Diarization Models from [pyannote.audio](https://github.com/pyannote/pyannote-audio)
+
+Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
+
Citation
If you use this in your research, please cite the paper:
@@ -220,37 +242,4 @@ If you use this in your research, please cite the paper:
journal={arXiv preprint, arXiv:2303.00747},
year={2023}
}
-```
-
-as well the following works, used in each stage of the pipeline:
-
-```bibtex
-@article{radford2022robust,
- title={Robust speech recognition via large-scale weak supervision},
- author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
- journal={arXiv preprint arXiv:2212.04356},
- year={2022}
-}
-```
-
-```bibtex
-@article{baevski2020wav2vec,
- title={wav2vec 2.0: A framework for self-supervised learning of speech representations},
- author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael},
- journal={Advances in neural information processing systems},
- volume={33},
- pages={12449--12460},
- year={2020}
-}
-```
-
-```bibtex
-@inproceedings{bredin2020pyannote,
- title={Pyannote. audio: neural building blocks for speaker diarization},
- author={Bredin, Herv{\'e} and Yin, Ruiqing and Coria, Juan Manuel and Gelly, Gregory and Korshunov, Pavel and Lavechin, Marvin and Fustes, Diego and Titeux, Hadrien and Bouaziz, Wassim and Gill, Marie-Philippe},
- booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
- pages={7124--7128},
- year={2020},
- organization={IEEE}
-}
-```
+```
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 139ee56..2747e2d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,8 @@
-numpy
-pandas
-torch >=1.9
-torchaudio >=0.10,<1.0
-tqdm
-more-itertools
-transformers>=4.19.0
-ffmpeg-python==0.2.0
+torch==1.11.0
+torchaudio==0.11.0
pyannote.audio
-openai-whisper==20230314
+faster-whisper
+transformers
+ffmpeg-python==0.2.0
+pandas
+setuptools==65.6.3
\ No newline at end of file
diff --git a/setup.py b/setup.py
index d6472e1..2060d42 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
- version="2.0",
+ version="3.0.0",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index 985ed32..d0294b9 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -1,3 +1,3 @@
-from .transcribe import transcribe, transcribe_with_vad
+from .transcribe import load_model
from .alignment import load_align_model, align
-from .vad import load_vad_model
\ No newline at end of file
+from .audio import load_audio
\ No newline at end of file
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index c15310b..09f044f 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -2,16 +2,17 @@
Forced Alignment with Whisper
C. Max Bain
"""
+from dataclasses import dataclass
+from typing import Iterator, Union
+
import numpy as np
import pandas as pd
-from typing import List, Union, Iterator, TYPE_CHECKING
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-import torchaudio
import torch
-from dataclasses import dataclass
-from whisper.audio import SAMPLE_RATE, load_audio
-from .utils import interpolate_nans
+import torchaudio
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+from .audio import SAMPLE_RATE, load_audio
+from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -391,34 +392,42 @@ def align(
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
- aligned_segments.append(
- {
- "start": srow["start"],
- "end": srow["end"],
- "text": text,
- "word-segments": wseg,
- "char-segments": cseg
- }
- )
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
+ word_list = []
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
+ if not curr_text.startswith(" "):
+ curr_text = " " + curr_text
+
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
+ word_start = wseg.iloc[wdx]['start']
+ word_end = wseg.iloc[wdx]['end']
+
aligned_segments_word.append(
{
"text": curr_text.strip(),
- "start": wseg.iloc[wdx]["start"],
- "end": wseg.iloc[wdx]["end"],
+ "start": word_start,
+ "end": word_end
}
)
- curr_text = ""
- curr_text += " " + get_raw_text(wrow)
+
+ word_list.append(
+ {
+ "word": curr_text.rstrip(),
+ "start": word_start,
+ "end": word_end,
+ }
+ )
+
+ curr_text = " "
+ curr_text += get_raw_text(wrow) + " "
wdx += 1
+
aligned_segments_word.append(
{
"text": curr_text.strip(),
@@ -427,6 +436,25 @@ def align(
}
)
+        word_list.append(
+            {
+                "word": curr_text.rstrip(),
+                "start": wseg.iloc[wdx]["start"],
+                "end": wseg.iloc[wdx]["end"],
+            }
+        )
+
+ aligned_segments.append(
+ {
+ "start": srow["start"],
+ "end": srow["end"],
+ "text": text,
+ "words": word_list,
+ # "word-segments": wseg,
+ # "char-segments": cseg
+ }
+ )
+
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
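The alignment hunk above changes the output schema: each aligned segment now carries a nested `words` list of `{"word", "start", "end"}` entries in place of the old `word-segments`/`char-segments` DataFrames, while the flat `word_segments` list is still returned. A minimal sketch of consuming the new structure (illustrative only, not part of the patch):

```python
# Illustrative sketch: iterate the new nested word timings produced by align().
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.mp3")

model = whisperx.load_model("large-v2", device)
result = model.transcribe(audio, batch_size=8)

model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)

for segment in result_aligned["segments"]:
    print(f"[{segment['start']:.2f} --> {segment['end']:.2f}] {segment['text']}")
    for word in segment["words"]:  # new in this patch: per-word timings nested per segment
        print(f"  {word['start']:.2f}-{word['end']:.2f} {word['word']}")

# the flat word-level list keeps its previous {"text", "start", "end"} shape
print(result_aligned["word_segments"][:3])
```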
diff --git a/whisperx/asr.py b/whisperx/asr.py
index e78d77c..d9cbff0 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,433 +1,406 @@
+import os
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import List, Union
+
+import ctranslate2
+import faster_whisper
import numpy as np
import torch
-import tqdm
-import ffmpeg
-from whisper.audio import (
- FRAMES_PER_SECOND,
- HOP_LENGTH,
- N_FRAMES,
- N_SAMPLES,
- SAMPLE_RATE,
- CHUNK_LENGTH,
- log_mel_spectrogram,
- pad_or_trim,
- load_audio
-)
-from whisper.decoding import DecodingOptions, DecodingResult
-from whisper.timing import add_word_timestamps
-from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from whisper.utils import (
- exact_div,
- format_timestamp,
- make_safe,
-)
+from transformers import Pipeline
+from transformers.pipelines.pt_utils import PipelineIterator
-if TYPE_CHECKING:
- from whisper.model import Whisper
+from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from .vad import load_vad_model, merge_chunks
-from .vad import merge_chunks
-def transcribe(
- model: "Whisper",
- audio: Union[str, np.ndarray, torch.Tensor] = None,
- mel: np.ndarray = None,
- verbose: Optional[bool] = None,
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
- compression_ratio_threshold: Optional[float] = 2.4,
- logprob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- initial_prompt: Optional[str] = None,
- word_timestamps: bool = False,
-    prepend_punctuations: str = "\"'“¿([{-",
-    append_punctuations: str = "\"'.。,，!！?？:：”)]}、",
- **decode_options,
-):
- """
- Transcribe an audio file using Whisper.
- We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)
+def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
+ vad_options=None, model=None):
+ '''Load a Whisper model for inference.
+ Args:
+ whisper_arch: str - The name of the Whisper model to load.
+ device: str - The device to load the model on.
+ compute_type: str - The compute type to use for the model.
+        asr_options: dict - A dictionary of transcription options to override the defaults.
+        language: str - The language of the model. If None, the language is detected per audio file.
+ Returns:
+ A Whisper pipeline.
+ '''
- Parameters
- ----------
- model: Whisper
- The Whisper model instance
+ if whisper_arch.endswith(".en"):
+ language = "en"
- audio: Union[str, np.ndarray, torch.Tensor]
- The path to the audio file to open, or the audio waveform
-
- mel: np.ndarray
- Mel spectrogram of audio segment.
-
- verbose: bool
- Whether to display the text being decoded to the console. If True, displays all the details,
- If False, displays minimal details. If None, does not display anything
-
- temperature: Union[float, Tuple[float, ...]]
- Temperature for sampling. It can be a tuple of temperatures, which will be successively used
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
-
- compression_ratio_threshold: float
- If the gzip compression ratio is above this value, treat as failed
-
- logprob_threshold: float
- If the average log probability over sampled tokens is below this value, treat as failed
-
- no_speech_threshold: float
- If the no_speech probability is higher than this value AND the average log probability
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
-
- condition_on_previous_text: bool
- if True, the previous output of the model is provided as a prompt for the next window;
- disabling may make the text inconsistent across windows, but the model becomes less prone to
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-
- word_timestamps: bool
- Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
- and include the timestamps for each word in each segment.
-
- prepend_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the next word
-
- append_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the previous word
-
- initial_prompt: Optional[str]
- Optional text to provide as a prompt for the first window. This can be used to provide, or
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
- to make it more likely to predict those word correctly.
-
- decode_options: dict
- Keyword arguments to construct `DecodingOptions` instances
-
- Returns
- -------
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
- """
- dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
- if model.device == torch.device("cpu"):
- if torch.cuda.is_available():
- warnings.warn("Performing inference on CPU when CUDA is available")
- if dtype == torch.float16:
- warnings.warn("FP16 is not supported on CPU; using FP32 instead")
- dtype = torch.float32
-
- if dtype == torch.float32:
- decode_options["fp16"] = False
-
- # Pad 30-seconds of silence to the input audio, for slicing
- if mel is None:
- if audio is None:
- raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
- mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
- content_frames = mel.shape[-1] - N_FRAMES
-
- if decode_options.get("language", None) is None:
- if not model.is_multilingual:
- decode_options["language"] = "en"
- else:
- if verbose:
- print(
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
- )
- mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
- _, probs = model.detect_language(mel_segment)
- decode_options["language"] = max(probs, key=probs.get)
- if verbose is not None:
- print(
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
- )
-
- language: str = decode_options["language"]
- task: str = decode_options.get("task", "transcribe")
- tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
-
- if word_timestamps and task == "translate":
- warnings.warn("Word-level timestamps on translations may not be reliable.")
-
- def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
- temperatures = (
- [temperature] if isinstance(temperature, (int, float)) else temperature
- )
- decode_result = None
-
- for t in temperatures:
- kwargs = {**decode_options}
- if t > 0:
- # disable beam_size and patience when t > 0
- kwargs.pop("beam_size", None)
- kwargs.pop("patience", None)
- else:
- # disable best_of when t == 0
- kwargs.pop("best_of", None)
-
- options = DecodingOptions(**kwargs, temperature=t)
- decode_result = model.decode(segment, options)
-
- needs_fallback = False
- if (
- compression_ratio_threshold is not None
- and decode_result.compression_ratio > compression_ratio_threshold
- ):
- needs_fallback = True # too repetitive
- if (
- logprob_threshold is not None
- and decode_result.avg_logprob < logprob_threshold
- ):
- needs_fallback = True # average log probability is too low
-
- if not needs_fallback:
- break
-
- return decode_result
-
- seek = 0
- input_stride = exact_div(
- N_FRAMES, model.dims.n_audio_ctx
- ) # mel frames per output token: 2
- time_precision = (
- input_stride * HOP_LENGTH / SAMPLE_RATE
- ) # time per output token: 0.02 (seconds)
- all_tokens = []
- all_segments = []
- prompt_reset_since = 0
-
- if initial_prompt is not None:
- initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
- all_tokens.extend(initial_prompt_tokens)
+ model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
+ if language is not None:
+ tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
else:
- initial_prompt_tokens = []
+        print("No language specified, language will first be detected for each audio file (increases inference time).")
+ tokenizer = None
- def new_segment(
- *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+ default_asr_options = {
+ "beam_size": 5,
+ "best_of": 5,
+ "patience": 1,
+ "length_penalty": 1,
+ "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+ "compression_ratio_threshold": 2.4,
+ "log_prob_threshold": -1.0,
+ "no_speech_threshold": 0.6,
+ "condition_on_previous_text": False,
+ "initial_prompt": None,
+ "prefix": None,
+ "suppress_blank": True,
+ "suppress_tokens": [-1],
+ "without_timestamps": True,
+ "max_initial_timestamp": 0.0,
+ "word_timestamps": False,
+        "prepend_punctuations": "\"'“¿([{-",
+        "append_punctuations": "\"'.。,，!！?？:：”)]}、"
+ }
+
+ if asr_options is not None:
+ default_asr_options.update(asr_options)
+ default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+
+ default_vad_options = {
+ "vad_onset": 0.500,
+ "vad_offset": 0.363
+ }
+
+ if vad_options is not None:
+ default_vad_options.update(vad_options)
+
+ vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+
+ return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
+
+
+
+class WhisperModel(faster_whisper.WhisperModel):
+ '''
+    WhisperModel subclass that provides batched inference for faster-whisper.
+ Currently only works in non-timestamp mode.
+ '''
+
+ def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+ batch_size = features.shape[0]
+ all_tokens = []
+ prompt_reset_since = 0
+ if options.initial_prompt is not None:
+ initial_prompt = " " + options.initial_prompt.strip()
+ initial_prompt_tokens = tokenizer.encode(initial_prompt)
+ all_tokens.extend(initial_prompt_tokens)
+ previous_tokens = all_tokens[prompt_reset_since:]
+ prompt = self.get_prompt(
+ tokenizer,
+ previous_tokens,
+ without_timestamps=options.without_timestamps,
+ prefix=options.prefix,
+ )
+
+ encoder_output = self.encode(features)
+
+ max_initial_timestamp_index = int(
+ round(options.max_initial_timestamp / self.time_precision)
+ )
+
+ result = self.model.generate(
+ encoder_output,
+ [prompt] * batch_size,
+ # length_penalty=options.length_penalty,
+ # max_length=self.max_length,
+ # return_scores=True,
+ # return_no_speech_prob=True,
+ # suppress_blank=options.suppress_blank,
+ # suppress_tokens=options.suppress_tokens,
+ # max_initial_timestamp_index=max_initial_timestamp_index,
+ )
+
+ tokens_batch = [x.sequences_ids[0] for x in result]
+
+ def decode_batch(tokens: List[List[int]]) -> str:
+ res = []
+ for tk in tokens:
+ res.append([token for token in tk if token < tokenizer.eot])
+ # text_tokens = [token for token in tokens if token < self.eot]
+ return tokenizer.tokenizer.decode_batch(res)
+
+ text = decode_batch(tokens_batch)
+
+ return text
+
+ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+ # When the model is running on multiple GPUs, the encoder output should be moved
+ # to the CPU since we don't know which GPU will handle the next job.
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
+ # unsqueeze if batch size = 1
+ if len(features.shape) == 2:
+ features = np.expand_dims(features, 0)
+ features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+
+ return self.model.encode(features, to_cpu=to_cpu)
+
+class FasterWhisperPipeline(Pipeline):
+ def __init__(
+ self,
+ model,
+ vad,
+ options,
+ tokenizer=None,
+ device: Union[int, str, "torch.device"] = -1,
+ framework = "pt",
+ **kwargs
):
- tokens = tokens.tolist()
- text_tokens = [token for token in tokens if token < tokenizer.eot]
- return {
- "seek": seek,
- "start": start,
- "end": end,
- "text": tokenizer.decode(text_tokens),
- "tokens": tokens,
- "temperature": result.temperature,
- "avg_logprob": result.avg_logprob,
- "compression_ratio": result.compression_ratio,
- "no_speech_prob": result.no_speech_prob,
- }
-
-
- # show the progress bar when verbose is False (if True, transcribed text will be printed)
- with tqdm.tqdm(
- total=content_frames, unit="frames", disable=verbose is not False
- ) as pbar:
- while seek < content_frames:
- time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
- mel_segment = mel[:, seek : seek + N_FRAMES]
- segment_size = min(N_FRAMES, content_frames - seek)
- segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
- mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
-
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
- result: DecodingResult = decode_with_fallback(mel_segment)
- tokens = torch.tensor(result.tokens)
- if no_speech_threshold is not None:
- # no voice activity check
- should_skip = result.no_speech_prob > no_speech_threshold
- if (
- logprob_threshold is not None
- and result.avg_logprob > logprob_threshold
- ):
- # don't skip if the logprob is high enough, despite the no_speech_prob
- should_skip = False
-
- if should_skip:
- seek += segment_size # fast-forward to the next segment boundary
- continue
-
- previous_seek = seek
- current_segments = []
-
- timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
- single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
-
- consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
- consecutive.add_(1)
- if len(consecutive) > 0:
- # if the output contains two consecutive timestamp tokens
- slices = consecutive.tolist()
- if single_timestamp_ending:
- slices.append(len(tokens))
-
- last_slice = 0
- for current_slice in slices:
- sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_pos = (
- sliced_tokens[0].item() - tokenizer.timestamp_begin
- )
- end_timestamp_pos = (
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
- )
-
- # clamp end-time to at least be 1 frame after start-time
- end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
-
- current_segments.append(
- new_segment(
- start=time_offset + start_timestamp_pos * time_precision,
- end=time_offset + end_timestamp_pos * time_precision,
- tokens=sliced_tokens,
- result=result,
- )
- )
- last_slice = current_slice
-
- if single_timestamp_ending:
- # single timestamp at the end means no speech after the last timestamp.
- seek += segment_size
- else:
- # otherwise, ignore the unfinished segment and seek to the last timestamp
- last_timestamp_pos = (
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
- )
- seek += last_timestamp_pos * input_stride
+ self.model = model
+ self.tokenizer = tokenizer
+ self.options = options
+ self._batch_size = kwargs.pop("batch_size", None)
+ self._num_workers = 1
+ self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
+ self.call_count = 0
+ self.framework = framework
+ if self.framework == "pt":
+ if isinstance(device, torch.device):
+ self.device = device
+ elif isinstance(device, str):
+ self.device = torch.device(device)
+ elif device < 0:
+ self.device = torch.device("cpu")
else:
- duration = segment_duration
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
- if (
- len(timestamps) > 0
- and timestamps[-1].item() != tokenizer.timestamp_begin
- ):
- # no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = (
- timestamps[-1].item() - tokenizer.timestamp_begin
- )
- duration = last_timestamp_pos * time_precision
+ self.device = torch.device(f"cuda:{device}")
+ else:
+ self.device = device
+
+ super(Pipeline, self).__init__()
+ self.vad_model = vad
- current_segments.append(
- new_segment(
- start=time_offset,
- end=time_offset + duration,
- tokens=tokens,
- result=result,
- )
- )
- seek += segment_size
+ def _sanitize_parameters(self, **kwargs):
+ preprocess_kwargs = {}
+        if "maybe_arg" in kwargs:
+ preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+ return preprocess_kwargs, {}, {}
- if not condition_on_previous_text or result.temperature > 0.5:
- # do not feed the prompt tokens if a high temperature was used
- prompt_reset_since = len(all_tokens)
+ def preprocess(self, audio):
+ audio = audio['inputs']
+ features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
+ return {'inputs': features}
- if word_timestamps:
- add_word_timestamps(
- segments=current_segments,
- model=model,
- tokenizer=tokenizer,
- mel=mel_segment,
- num_frames=segment_size,
- prepend_punctuations=prepend_punctuations,
- append_punctuations=append_punctuations,
- )
- word_end_timestamps = [
- w["end"] for s in current_segments for w in s["words"]
- ]
- if not single_timestamp_ending and len(word_end_timestamps) > 0:
- seek_shift = round(
- (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
- )
- if seek_shift > 0:
- seek = previous_seek + seek_shift
+ def _forward(self, model_inputs):
+ outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
+ return {'text': outputs}
+
+ def postprocess(self, model_outputs):
+ return model_outputs
- if verbose:
- for segment in current_segments:
- start, end, text = segment["start"], segment["end"], segment["text"]
- line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
- print(make_safe(line))
+ def get_iterator(
+ self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
+ ):
+ dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # TODO hack by collating feature_extractor and image_processor
- # if a segment is instantaneous or does not contain text, clear it
- for i, segment in enumerate(current_segments):
- if segment["start"] == segment["end"] or segment["text"].strip() == "":
- segment["text"] = ""
- segment["tokens"] = []
- segment["words"] = []
+ def stack(items):
+ return {'inputs': torch.stack([x['inputs'] for x in items])}
+ dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
+ model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
+ final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
+ return final_iterator
- all_segments.extend(
- [
- {"id": i, **segment}
- for i, segment in enumerate(
- current_segments, start=len(all_segments)
- )
- ]
- )
- all_tokens.extend(
- [token for segment in current_segments for token in segment["tokens"]]
- )
+ def transcribe(
+ self, audio: Union[str, np.ndarray], batch_size=None
+ ):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+
+ def data(audio, segments):
+ for seg in segments:
+ f1 = int(seg['start'] * SAMPLE_RATE)
+ f2 = int(seg['end'] * SAMPLE_RATE)
+ # print(f2-f1)
+ yield {'inputs': audio[f1:f2]}
- # update progress bar
- pbar.update(min(content_frames, seek) - previous_seek)
+ vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+ vad_segments = merge_chunks(vad_segments, 30)
+ del_tokenizer = False
+ if self.tokenizer is None:
+ language = self.detect_language(audio)
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
+ del_tokenizer = True
+ else:
+ language = self.tokenizer.language_code
- return dict(
- text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
- segments=all_segments,
- language=language,
- )
-
-
-def transcribe_with_vad(
- model: "Whisper",
- audio: str,
- vad_pipeline,
- mel = None,
- verbose: Optional[bool] = None,
- **kwargs
-):
- """
- Transcribe per VAD segment
- """
-
- vad_segments = vad_pipeline(audio)
-
- # if not torch.is_tensor(audio):
- # if isinstance(audio, str):
- audio = load_audio(audio)
- audio = torch.from_numpy(audio)
-
- prev = 0
- output = {"segments": []}
-
- # merge segments to approx 30s inputs to make whisper most appropraite
- vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
- if len(vad_segments) == 0:
- return output
-
- print(">>Performing transcription...")
- for sdx, seg_t in enumerate(vad_segments):
- if verbose:
- print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
- seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
- local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
- audio = audio[local_f_start:] # seek forward
- seg_audio = audio[:local_f_end-local_f_start] # seek forward
- prev = seg_f_start
- local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
- # need to pad
-
- result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
- seg_t["text"] = result["text"]
- output["segments"].append(
- {
- "start": seg_t["start"],
- "end": seg_t["end"],
- "language": result["language"],
- "text": result["text"],
- "seg-text": [x["text"] for x in result["segments"]],
- "seg-start": [x["start"] for x in result["segments"]],
- "seg-end": [x["end"] for x in result["segments"]],
+ segments = []
+ batch_size = batch_size or self._batch_size
+ for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
+            text = out['text']
+            if batch_size in [0, 1, None]:
+                text = text[0]
+            segments.append(
+                {
+                    "text": text,
+ "start": round(vad_segments[idx]['start'], 3),
+ "end": round(vad_segments[idx]['end'], 3)
}
)
+
+ if del_tokenizer:
+ self.tokenizer = None
- output["language"] = output["segments"][0]["language"]
+ return {"segments": segments, "language": language}
- return output
+
+ def detect_language(self, audio: np.ndarray):
+ segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
+ encoder_output = self.model.encode(segment)
+ results = self.model.model.detect_language(encoder_output)
+ language_token, language_probability = results[0][0]
+ language = language_token[2:-2]
+ print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
+ return language
+
+if __name__ == "__main__":
+ main_type = "simple"
+ import time
+
+ import jiwer
+ from tqdm import tqdm
+ from whisper.normalizers import EnglishTextNormalizer
+
+ from benchmark.tedlium import parse_tedlium_annos
+
+ if main_type == "complex":
+ from faster_whisper.tokenizer import Tokenizer
+ from faster_whisper.transcribe import TranscriptionOptions
+ from faster_whisper.vad import (SpeechTimestampsMap,
+ get_speech_timestamps)
+
+ from whisperx.vad import load_vad_model, merge_chunks
+
+ from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
+ faster_t_options = TranscriptionOptions(
+ beam_size=5,
+ best_of=5,
+ patience=1,
+ length_penalty=1,
+ temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+ compression_ratio_threshold=2.4,
+ log_prob_threshold=-1.0,
+ no_speech_threshold=0.6,
+ condition_on_previous_text=False,
+ initial_prompt=None,
+ prefix=None,
+ suppress_blank=True,
+ suppress_tokens=[-1],
+ without_timestamps=True,
+ max_initial_timestamp=0.0,
+ word_timestamps=False,
+            prepend_punctuations="\"'“¿([{-",
+            append_punctuations="\"'.。,，!！?？:：”)]}、"
+ )
+ whisper_arch = "large-v2"
+ device = "cuda"
+ batch_size = 16
+ model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
+ tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
+ model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
+ fn = "DanielKahneman_2010.wav"
+ wav_dir = f"/tmp/test/wav/"
+ vad_model = load_vad_model("cuda", 0.6, 0.3)
+ audio = load_audio(os.path.join(wav_dir, fn))
+ vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+ vad_segments = merge_chunks(vad_segments, 30)
+
+ def data(audio, segments):
+ for seg in segments:
+ f1 = int(seg['start'] * SAMPLE_RATE)
+ f2 = int(seg['end'] * SAMPLE_RATE)
+ # print(f2-f1)
+ yield {'inputs': audio[f1:f2]}
+ vad_method="pyannote"
+
+ wav_dir = f"/tmp/test/wav/"
+ wer_li = []
+ time_li = []
+ for fn in os.listdir(wav_dir):
+ if fn == "RobertGupta_2010U.wav":
+ continue
+ base_fn = fn.split('.')[0]
+ audio_fp = os.path.join(wav_dir, fn)
+
+ audio = load_audio(audio_fp)
+ t1 = time.time()
+ if vad_method == "pyannote":
+ vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+ vad_segments = merge_chunks(vad_segments, 30)
+ elif vad_method == "silero":
+ vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
+ vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
+ new_segs = []
+ curr_start = vad_segments[0]['start']
+ curr_end = vad_segments[0]['end']
+ for seg in vad_segments[1:]:
+ if seg['end'] - curr_start > 30:
+ new_segs.append({"start": curr_start, "end": curr_end})
+ curr_start = seg['start']
+ curr_end = seg['end']
+ else:
+ curr_end = seg['end']
+ new_segs.append({"start": curr_start, "end": curr_end})
+ vad_segments = new_segs
+ text = []
+ # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
+ for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
+ text.append(out['text'])
+ t2 = time.time()
+ if batch_size == 1:
+ text = [x[0] for x in text]
+ text = " ".join(text)
+
+ normalizer = EnglishTextNormalizer()
+ text = normalizer(text)
+ gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
+
+ wer_result = jiwer.wer(gt_corpus, text)
+ print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
+
+ wer_li.append(wer_result)
+ time_li.append(t2-t1)
+ print("# Avg Mean...")
+ print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
+ print("Time: %.2f" % (sum(time_li)/len(time_li)))
+ elif main_type == "simple":
+ model = load_model(
+ "large-v2",
+ device="cuda",
+ language="en",
+ )
+
+ wav_dir = f"/tmp/test/wav/"
+ wer_li = []
+ time_li = []
+ for fn in os.listdir(wav_dir):
+ if fn == "RobertGupta_2010U.wav":
+ continue
+ # fn = "DanielKahneman_2010.wav"
+ base_fn = fn.split('.')[0]
+ audio_fp = os.path.join(wav_dir, fn)
+
+ audio = load_audio(audio_fp)
+ t1 = time.time()
+ out = model.transcribe(audio_fp, batch_size=8)["segments"]
+ t2 = time.time()
+
+ text = " ".join([x['text'] for x in out])
+ normalizer = EnglishTextNormalizer()
+ text = normalizer(text)
+ gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
+
+ wer_result = jiwer.wer(gt_corpus, text)
+ print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
+
+ wer_li.append(wer_result)
+ time_li.append(t2-t1)
+ print("# Avg Mean...")
+ print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
+ print("Time: %.2f" % (sum(time_li)/len(time_li)))
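The new `load_model` entry point above merges user-supplied `asr_options` and `vad_options` into the defaults before building a `FasterWhisperPipeline`. A hedged sketch of how those overrides are passed (illustrative only; keys must match the default dicts shown above):

```python
# Illustrative sketch: overriding the decoding and VAD defaults merged by load_model().
import whisperx

model = whisperx.load_model(
    "large-v2",
    "cuda",
    compute_type="float16",
    language="en",  # skip per-file language detection
    asr_options={"beam_size": 5, "temperatures": [0.0]},
    vad_options={"vad_onset": 0.500, "vad_offset": 0.363},
)

audio = whisperx.load_audio("audio.mp3")
result = model.transcribe(audio, batch_size=16)  # batched over ~30 s VAD-merged chunks
print(result["language"])
print(result["segments"][0])
```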
diff --git a/whisperx/assets/mel_filters.npz b/whisperx/assets/mel_filters.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1a7839244dfb6b1cc02e4f3cfe12e4817a073bc7
GIT binary patch
literal 2048
zcmZ{ldpr|t8^`A`DO<+otT3lH(m_`ITt
zB|;Oynx)HJT`(1EF>38(a0=aik_wy4J{lY+B#Bb-2ffpu*yJ{j!qm{t
z^xMtm2Y$T#4*xq|STK_Bn9hLgL(WcH=%@WYxHDaPwEEIhuyv@h^Db}I$OqvWOB2NL
z9X<7CRj&D~xo5)OmJz)I%U
zkqrV5CK)ikMW)(OKUdr{dmUe=eJMloLF>s8*;&VMlJ*DNvtmAby$w&owv{^-#*jy_
zl48d&xIAmG7BO0JHJ6_M%wWjAKQ9ypKNO3a<)zkCRVOz1iaFi!K=wf)9!9o+0|$1c
zLNd9&GovRa-1K^bDWjo5Dn_kBA`9tBg_Y}gVeCAO8KzK_Cv_D`o}#u!WyDmupof?h
zX9Ts&DRcf=GgUf)0;da`PprFOaNuv`+HBl{W46HWXpIW|5ardLmBl!}jWJ&>pXk2Z
zEXcf7f+vz9U|i^HtnIZjOcAXeE^!H3#J%LlrB?=Y^XYydKfT_mIl5a(WmffC`T5Dw
z2k*V9tMN!{aOt1-szXI+_ZYsEm?>V5f@jBm-_NiHO=A^|3m0jtgCZ>bRK;-Hk;KlD
zcgJfnB`#hSs5EYbzD;Ev$fCW`*lv@ZqJ#+ERE@NX!x(kdZWmCLuI?E30ZkF~di6`A
z;Hkg*C%W^V#PDvPF)R0N9m0G`^WH(Cc3P0|y*{1^LEHP0V}cnm96Vhuh1EJQOAK+j
z%!rc*`!I$SJZzYR1D7>NHr+P`2)Fn!UE|X0NsT|r9J&&sDb73hQzG;pXV~xMh(KQJ
zB+KRgBh!VXat#;hT54}r`bteh@?*pr{v2nwp5U1z*>Sq^9WQ^&nJum?5a<<8cne%o
z>rA+=7j|rtqsY4}(3x8JwKeyL$l+L72-7-I_K%j~D<2>TEAB~F!u;T#<;YR_@rOE(
z!=j~8p0Sn;{|^Dwe2>Ujx7c593PrgxbKban*_`uC=DE9>wa`J32eh!h-F8fg!Ee2Y
z_>tB#ZhIL-4ASQ
zB_h=;c&TlyPo>3%-QpA=6JHBuwY4MRQz}uMhe=kv`AZHbQd+mO+2;ko0==-HW)(FL
z@;D^Xl!8Ya?@Urcsb{>=3^w9Scux7j?UxMKY?yWnqk(t}OKJ^>|JSv2%{3x@zj0bebHU1c1LTFejo6z~93MmrCb
z57vI#AdVaN)R)8+_`EMG_l~ufd4pAKSx)+|`*l54nZDrk#2(pz@d%Y$hb~5sLAXbL
zbwW~-x8cuD@8P!$C_v&+8s?}fcHTTq>6I$vqaqE>0&SeTIR_TnBdH<9%IOy_(GK)D
zw%VM%TO7@|>|oq-j=ToC5${4Ms5;y{yfjmJY}L1?i*jOqu=6WF5)6;%fKc)@WtU5L
z%FFxRDH}ak_2iLEazKX~Crl-DcaIiaTaI3kE>dl|a^eBrPWrRwOjEDXZlH!Qn>%|4
zF7h5!&)J`B_%p=8`33K}xkH)DU8u^WF;B|X(~LE6*jx7f7RFZj_LCrndK%b>g}+8=
zSTWHyV%k)2YLpQ9|9RGd`ZLtXZwnLZ837ZVN&qWIuN0>Y7zK
zkL-B_&B2iQ8d^Q!^@RmAMMEO^VwiKi6n{=kvWcU;PDe+!o(5+gL$%o31
z5HGnVUbofaprVI<-&15~ZJ5PmVr$CDzsqWss5W2vzM>uKBj3HQT1AD_GmDL7n{@kWa8>c3U|
EFD(_kqyPW_
literal 0
HcmV?d00001
diff --git a/whisperx/audio.py b/whisperx/audio.py
new file mode 100644
index 0000000..513ab7c
--- /dev/null
+++ b/whisperx/audio.py
@@ -0,0 +1,147 @@
+import os
+from functools import lru_cache
+from typing import Optional, Union
+
+import ffmpeg
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .utils import exact_div
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+ """
+ Open an audio file and read as mono waveform, resampling as necessary
+
+ Parameters
+ ----------
+ file: str
+ The audio file to open
+
+ sr: int
+ The sample rate to resample the audio if necessary
+
+ Returns
+ -------
+ A NumPy array containing the audio waveform, in float32 dtype.
+ """
+ try:
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+ out, _ = (
+ ffmpeg.input(file, threads=0)
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+ )
+ except ffmpeg.Error as e:
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+ """
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+ """
+ if torch.is_tensor(array):
+ if array.shape[axis] > length:
+ array = array.index_select(
+ dim=axis, index=torch.arange(length, device=array.device)
+ )
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+ else:
+ if array.shape[axis] > length:
+ array = array.take(indices=range(length), axis=axis)
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = np.pad(array, pad_widths)
+
+ return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ )
+ """
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+ with np.load(
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+ ) as f:
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+ audio: Union[str, np.ndarray, torch.Tensor],
+ n_mels: int = N_MELS,
+ padding: int = 0,
+ device: Optional[Union[str, torch.device]] = None,
+):
+ """
+    Compute the log-Mel spectrogram of the input audio.
+
+ Parameters
+ ----------
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+
+ n_mels: int
+ The number of Mel-frequency filters, only 80 is supported
+
+ padding: int
+ Number of zero samples to pad to the right
+
+ device: Optional[Union[str, torch.device]]
+ If given, the audio tensor is moved to this device before STFT
+
+ Returns
+ -------
+ torch.Tensor, shape = (80, n_frames)
+ A Tensor that contains the Mel spectrogram
+ """
+ if not torch.is_tensor(audio):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio = torch.from_numpy(audio)
+
+ if device is not None:
+ audio = audio.to(device)
+ if padding > 0:
+ audio = F.pad(audio, (0, padding))
+ window = torch.hann_window(N_FFT).to(audio.device)
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+ magnitudes = stft[..., :-1].abs() ** 2
+
+ filters = mel_filters(audio.device, n_mels)
+ mel_spec = filters @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
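The new `whisperx/audio.py` vendors the audio helpers previously imported from `whisper.audio`, so the package no longer depends on `openai-whisper`. A short sketch of the helpers added above (illustrative only):

```python
# Illustrative sketch: the vendored audio helpers that replace `from whisper.audio import ...`.
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("audio.mp3")         # float32 mono waveform resampled to 16 kHz
chunk = pad_or_trim(audio, N_SAMPLES)   # exactly 30 s, i.e. 480000 samples
mel = log_mel_spectrogram(chunk)        # torch.Tensor of shape (80, 3000)
print(SAMPLE_RATE, chunk.shape, mel.shape)
```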
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index ed918e0..c3719ba 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -1,37 +1,30 @@
import argparse
-import os
import gc
+import os
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+
import numpy as np
import torch
-import tempfile
-import ffmpeg
-from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
-from whisper.audio import SAMPLE_RATE
-from whisper.utils import (
- optional_float,
- optional_int,
- str2bool,
-)
-from .alignment import load_align_model, align
-from .asr import transcribe, transcribe_with_vad
+from .alignment import align, load_align_model
+from .asr import load_model
+from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import get_writer
-from .vad import load_vad_model
+from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
+ optional_int, str2bool)
+
def cli():
- from whisper import available_models
-
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
- parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
+ parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--batch_size", default=8, type=int, help="batch size for batched inference")
+
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
+ parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -39,13 +32,10 @@ def cli():
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
- parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
- parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
# vad params
- parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
@@ -69,9 +59,14 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
- parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
-    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
-    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,，!！?？:：”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+
+ parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
+    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --max_line_width) the maximum number of lines in a segment")
+ parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+
+ # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+    # parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+    # parser.add_argument("--append_punctuations", type=str, default="\"\'.。,，!！?？:：”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
@@ -81,7 +76,7 @@ def cli():
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
- model_dir: str = args.pop("model_dir")
+ batch_size: int = args.pop("batch_size")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
@@ -93,13 +88,10 @@ def cli():
os.makedirs(tmp_dir, exist_ok=True)
align_model: str = args.pop("align_model")
- align_extend: float = args.pop("align_extend")
- align_from_prev: bool = args.pop("align_from_prev")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
hf_token: str = args.pop("hf_token")
- vad_filter: bool = args.pop("vad_filter")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
@@ -107,18 +99,7 @@ def cli():
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
- if vad_filter:
- from pyannote.audio import Pipeline
- from pyannote.audio import Model, Pipeline
- vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
- else:
- vad_model = None
-
- # if model_flush:
- # print(">>Model flushing activated... Only loading model after ASR stage")
- # del align_model
- # align_model = ""
-
+ # TODO: check model loading works.
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
@@ -136,39 +117,43 @@ def cli():
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
- from whisper import load_model
+ asr_options = {
+ "beam_size": args.pop("beam_size"),
+ "patience": args.pop("patience"),
+ "length_penalty": args.pop("length_penalty"),
+ "temperatures": temperature,
+ "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
+ "log_prob_threshold": args.pop("logprob_threshold"),
+ "no_speech_threshold": args.pop("no_speech_threshold"),
+ "condition_on_previous_text": False,
+ "initial_prompt": args.pop("initial_prompt"),
+ }
writer = get_writer(output_format, output_dir)
-
+ word_options = ["highlight_words", "max_line_count", "max_line_width"]
+ if no_align:
+ for option in word_options:
+ if args[option]:
+                parser.error(f"--{option} not possible with --no_align")
+ if args["max_line_count"] and not args["max_line_width"]:
+ warnings.warn("--max_line_count has no effect without --max_line_width")
+ writer_args = {arg: args.pop(arg) for arg in word_options}
+
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
- model = load_model(model_name, device=device, download_root=model_dir)
- for audio_path in args.pop("audio"):
- input_audio_path = audio_path
- tfile = None
+ # model = load_model(model_name, device=device, download_root=model_dir)
+ model = load_model(model_name, device=device, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
+ for audio_path in args.pop("audio"):
+ audio = load_audio(audio_path)
# >> VAD & ASR
- if vad_model is not None:
- if not audio_path.endswith(".wav"):
- print(">>VAD requires .wav format, converting to wav as a tempfile...")
- audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
- if tmp_dir is not None:
- input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
- else:
- input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
- ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
- print(">>Performing VAD...")
- result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
- else:
- print(">>Performing transcription...")
- result = transcribe(model, input_audio_path, temperature=temperature, **args)
-
- results.append((result, input_audio_path))
+ print(">>Performing transcription...")
+ result = model.transcribe(audio, batch_size=batch_size)
+ results.append((result, audio_path))
# Unload Whisper and VAD
del model
- del vad_model
gc.collect()
torch.cuda.empty_cache()
@@ -178,17 +163,22 @@ def cli():
results = []
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
- for result, input_audio_path in tmp_results:
+ for result, audio_path in tmp_results:
# >> Align
+ if len(tmp_results) > 1:
+ input_audio = audio_path
+ else:
+ # lazily load audio from part 1
+ input_audio = audio
+
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
- result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
- extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
- results.append((result, input_audio_path))
+ result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
+ results.append((result, audio_path))
# Unload align model
del align_model
@@ -210,11 +200,7 @@ def cli():
# >> Write
for result, audio_path in results:
- writer(result, audio_path)
-
- # cleanup
- if input_audio_path != audio_path:
- os.remove(input_audio_path)
+ writer(result, audio_path, writer_args)
if __name__ == "__main__":
cli()
\ No newline at end of file
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 14e298b..3401a84 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -1,280 +1,301 @@
+import json
import os
+import re
+import sys
import zlib
-from typing import Callable, TextIO, Iterator, Tuple
-import pandas as pd
-import numpy as np
+from typing import Callable, Optional, TextIO
-def interpolate_nans(x, method='nearest'):
- if x.notnull().sum() > 1:
- return x.interpolate(method=method).ffill().bfill()
+LANGUAGES = {
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "he": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+ **{language: code for code, language in LANGUAGES.items()},
+ "burmese": "my",
+ "valencian": "ca",
+ "flemish": "nl",
+ "haitian": "ht",
+ "letzeburgesch": "lb",
+ "pushto": "ps",
+ "panjabi": "pa",
+ "moldavian": "ro",
+ "moldovan": "ro",
+ "sinhalese": "si",
+ "castilian": "es",
+}
+
+
+system_encoding = sys.getdefaultencoding()
+
+if system_encoding != "utf-8":
+
+ def make_safe(string):
+ # replaces any character not representable using the system default encoding with an '?',
+ # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
+ return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
+else:
+
+ def make_safe(string):
+ # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
+ return string
+
+
+def exact_div(x, y):
+ assert x % y == 0
+ return x // y
+
+
+def str2bool(string):
+ str2val = {"True": True, "False": False}
+ if string in str2val:
+ return str2val[string]
else:
- return x.ffill().bfill()
-
-
-def write_txt(transcript: Iterator[dict], file: TextIO):
- for segment in transcript:
- print(segment['text'].strip(), file=file, flush=True)
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
-def write_vtt(transcript: Iterator[dict], file: TextIO):
- print("WEBVTT\n", file=file)
- for segment in transcript:
- print(
- f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
- f"{segment['text'].strip().replace('-->', '->')}\n",
- file=file,
- flush=True,
+def optional_int(string):
+ return None if string == "None" else int(string)
+
+
+def optional_float(string):
+ return None if string == "None" else float(string)
+
+
+def compression_ratio(text) -> float:
+ text_bytes = text.encode("utf-8")
+ return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def format_timestamp(
+ seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return (
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+ )
+
+
+class ResultWriter:
+ extension: str
+
+ def __init__(self, output_dir: str):
+ self.output_dir = output_dir
+
+ def __call__(self, result: dict, audio_path: str, options: dict):
+ audio_basename = os.path.basename(audio_path)
+ audio_basename = os.path.splitext(audio_basename)[0]
+ output_path = os.path.join(
+ self.output_dir, audio_basename + "." + self.extension
)
-def write_tsv(transcript: Iterator[dict], file: TextIO):
- print("start", "end", "text", sep="\t", file=file)
- for segment in transcript:
- print(segment['start'], file=file, end="\t")
- print(segment['end'], file=file, end="\t")
- print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+ with open(output_path, "w", encoding="utf-8") as f:
+ self.write_result(result, file=f, options=options)
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ raise NotImplementedError
-def write_srt(transcript: Iterator[dict], file: TextIO):
- """
- Write a transcript to a file in SRT format.
+class WriteTXT(ResultWriter):
+ extension: str = "txt"
- Example usage:
- from pathlib import Path
- from whisper.utils import write_srt
-
- result = transcribe(model, audio_path, temperature=temperature, **args)
-
- # save SRT
- audio_basename = Path(audio_path).stem
- with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
- write_srt(result["segments"], file=srt)
- """
- for i, segment in enumerate(transcript, start=1):
- # write srt lines
- print(
- f"{i}\n"
- f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
- f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
- f"{segment['text'].strip().replace('-->', '->')}\n",
- file=file,
- flush=True,
- )
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ for segment in result["segments"]:
+ print(segment["text"].strip(), file=file, flush=True)
-def write_ass(transcript: Iterator[dict],
- file: TextIO,
- resolution: str = "word",
- color: str = None, underline=True,
- prefmt: str = None, suffmt: str = None,
- font: str = None, font_size: int = 24,
- strip=True, **kwargs):
- """
- Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
- Generate Advanced SubStation Alpha (ass) file from results to
- display both phrase-level & word-level timestamp simultaneously by:
- -using segment-level timestamps display phrases as usual
- -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
- Note: ass file is used in the same way as srt, vtt, etc.
- Parameters
- ----------
- transcript: dict
- results from modified model
- file: TextIO
- file object to write to
- resolution: str
- "word" or "char", timestamp resolution to highlight.
- color: str
- color code for a word at its corresponding timestamp
- reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
- underline: bool
- whether to underline a word at its corresponding timestamp
- prefmt: str
-        used to specify format for word-level timestamps (must be used with 'suffmt' and overrides 'color'&'underline')
-        appears as such in the .ass file:
-            Hi, {<prefmt>}how{<suffmt>} are you?
- reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
- suffmt: str
-        used to specify format for word-level timestamps (must be used with 'prefmt' and overrides 'color'&'underline')
-        appears as such in the .ass file:
-            Hi, {<prefmt>}how{<suffmt>} are you?
- reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
- font: str
- word font (default: Arial)
- font_size: int
- word font size (default: 48)
- kwargs:
- used for format styles:
- 'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
- 'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
- 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
+class SubtitlesWriter(ResultWriter):
+ always_include_hours: bool
+ decimal_marker: str
- """
+ def iterate_result(self, result: dict, options: dict):
+ raw_max_line_width: Optional[int] = options["max_line_width"]
+ max_line_count: Optional[int] = options["max_line_count"]
+ highlight_words: bool = options["highlight_words"]
+ max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+ preserve_segments = max_line_count is None or raw_max_line_width is None
- fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
- 'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
- 'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
- 'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
- 'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
+ def iterate_subtitles():
+ line_len = 0
+ line_count = 1
+ # the next subtitle to yield (a list of word timings with whitespace)
+ subtitle: list[dict] = []
+ last = result["segments"][0]["words"][0]["start"]
+ for segment in result["segments"]:
+ for i, original_timing in enumerate(segment["words"]):
+ timing = original_timing.copy()
+ long_pause = not preserve_segments and timing["start"] - last > 3.0
+ has_room = line_len + len(timing["word"]) <= max_line_width
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+ if line_len > 0 and has_room and not long_pause and not seg_break:
+ # line continuation
+ line_len += len(timing["word"])
+ else:
+ # new line
+ timing["word"] = timing["word"].strip()
+ if (
+ len(subtitle) > 0
+ and max_line_count is not None
+ and (long_pause or line_count >= max_line_count)
+ or seg_break
+ ):
+ # subtitle break
+ yield subtitle
+ subtitle = []
+ line_count = 1
+ elif line_len > 0:
+ # line break
+ line_count += 1
+ timing["word"] = "\n" + timing["word"]
+ line_len = len(timing["word"].strip())
+ subtitle.append(timing)
+ last = timing["start"]
+ if len(subtitle) > 0:
+ yield subtitle
- for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
- kwargs[k] = f'&H{kwargs[k]}'
+ if "words" in result["segments"][0]:
+ for subtitle in iterate_subtitles():
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+ subtitle_text = "".join([word["word"] for word in subtitle])
+ if highlight_words:
+ last = subtitle_start
+ all_words = [timing["word"] for timing in subtitle]
+ for i, this_word in enumerate(subtitle):
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
- fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
-
- if font:
- fmt_style_dict.update(Fontname=font)
- if font_size:
- fmt_style_dict.update(Fontsize=font_size)
-
- fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
-
- styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
-
- ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
- f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
- f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
-
- if prefmt or suffmt:
- if suffmt:
- assert prefmt, 'prefmt must be used along with suffmt'
+ yield start, end, "".join(
+ [
+                        re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+ if j == i
+ else word
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
+ else:
+ yield subtitle_start, subtitle_end, subtitle_text
else:
- suffmt = r'\r'
- else:
- if not color:
- color = 'HFF00'
- underline_code = r'\u1' if underline else ''
-
- prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
- suffmt = r'{\r}'
-
- def secs_to_hhmmss(secs: Tuple[float, int]):
- mm, ss = divmod(secs, 60)
- hh, mm = divmod(mm, 60)
- return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
-
-
- def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
- if idx_0 == -1:
- text = chars
- else:
- text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
- return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
- f"Default,,0,0,0,,{text.strip() if strip else text}"
-
- if resolution == "word":
- resolution_key = "word-segments"
- elif resolution == "char":
- resolution_key = "char-segments"
- else:
- raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
-
- ass_arr = []
-
- for segment in transcript:
- # if "12" in segment['text']:
- # import pdb; pdb.set_trace()
- if resolution_key in segment:
- res_segs = pd.DataFrame(segment[resolution_key])
- prev = segment['start']
- if "speaker" in segment:
- speaker_str = f"[{segment['speaker']}]: "
- else:
- speaker_str = ""
- for cdx, crow in res_segs.iterrows():
- if not np.isnan(crow['start']):
- if resolution == "char":
- idx_0 = cdx
- idx_1 = cdx + 1
- elif resolution == "word":
- idx_0 = int(crow["segment-text-start"])
- idx_1 = int(crow["segment-text-end"])
- # fill gap
- if crow['start'] > prev:
- filler_ts = {
- "chars": speaker_str + segment['text'],
- "start": prev,
- "end": crow['start'],
- "idx_0": -1,
- "idx_1": -1
- }
-
- ass_arr.append(filler_ts)
- # highlight current word
- f_word_ts = {
- "chars": speaker_str + segment['text'],
- "start": crow['start'],
- "end": crow['end'],
- "idx_0": idx_0 + len(speaker_str),
- "idx_1": idx_1 + len(speaker_str)
- }
- ass_arr.append(f_word_ts)
- prev = crow['end']
-
- ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
-
- file.write(ass_str)
-
-
-from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
-
-class WriteASS(ResultWriter):
- extension: str = "ass"
-
- def write_result(self, result: dict, file: TextIO):
- write_ass(result["segments"], file, resolution="word")
-
-class WriteASSchar(ResultWriter):
- extension: str = "ass"
-
- def write_result(self, result: dict, file: TextIO):
- write_ass(result["segments"], file, resolution="char")
-
-class WritePickle(ResultWriter):
- extension: str = "ass"
-
- def write_result(self, result: dict, file: TextIO):
- pd.DataFrame(result["segments"]).to_pickle(file)
-
-class WriteSRTWord(ResultWriter):
- extension: str = "word.srt"
- always_include_hours: bool = True
- decimal_marker: str = ","
-
- def iterate_result(self, result: dict):
- for segment in result["word_segments"]:
- segment_start = self.format_timestamp(segment["start"])
- segment_end = self.format_timestamp(segment["end"])
- segment_text = segment["text"].strip().replace("-->", "->")
-
- if word_timings := segment.get("words", None):
- all_words = [timing["word"] for timing in word_timings]
- all_words[0] = all_words[0].strip() # remove the leading space, if any
- last = segment_start
- for i, this_word in enumerate(word_timings):
- start = self.format_timestamp(this_word["start"])
- end = self.format_timestamp(this_word["end"])
- if last != start:
- yield last, start, segment_text
-
- yield start, end, "".join(
- [
-                            f"<u>{word}</u>" if j == i else word
- for j, word in enumerate(all_words)
- ]
- )
- last = end
-
- if last != segment_end:
- yield last, segment_end, segment_text
- else:
+ for segment in result["segments"]:
+ segment_start = self.format_timestamp(segment["start"])
+ segment_end = self.format_timestamp(segment["end"])
+ segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text
- def write_result(self, result: dict, file: TextIO):
- if "word_segments" not in result:
- return
- for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
- print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
-
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
@@ -282,36 +303,81 @@ class WriteSRTWord(ResultWriter):
decimal_marker=self.decimal_marker,
)
-def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
+
+class WriteVTT(SubtitlesWriter):
+ extension: str = "vtt"
+ always_include_hours: bool = False
+ decimal_marker: str = "."
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ print("WEBVTT\n", file=file)
+ for start, end, text in self.iterate_result(result, options):
+ print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteSRT(SubtitlesWriter):
+ extension: str = "srt"
+ always_include_hours: bool = True
+ decimal_marker: str = ","
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options), start=1
+ ):
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteTSV(ResultWriter):
+ """
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+ Using integer milliseconds as start and end times means there's no chance of interference from
+ an environment setting a language encoding that causes the decimal in a floating point number
+ to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+ """
+
+ extension: str = "tsv"
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in result["segments"]:
+ print(round(1000 * segment["start"]), file=file, end="\t")
+ print(round(1000 * segment["end"]), file=file, end="\t")
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+ extension: str = "json"
+
+ def write_result(self, result: dict, file: TextIO, options: dict):
+ json.dump(result, file)
+
+
+def get_writer(
+ output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
- "ass": WriteASS,
- "srt-word": WriteSRTWord,
- # "ass-char": WriteASSchar,
- # "pickle": WritePickle,
- # "json": WriteJSON,
- }
-
- writers_other = {
- "pkl": WritePickle,
- "ass-char": WriteASSchar
+ "json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
- def write_all(result: dict, file: TextIO):
+ def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
- writer(result, file)
+ writer(result, file, options)
return write_all
- if output_format in writers:
- return writers[output_format](output_dir)
- elif output_format in writers_other:
- return writers_other[output_format](output_dir)
+ return writers[output_format](output_dir)
+
+def interpolate_nans(x, method='nearest'):
+ if x.notnull().sum() > 1:
+ return x.interpolate(method=method).ffill().bfill()
else:
- raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
+ return x.ffill().bfill()
\ No newline at end of file
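The refactor above changes the writer call path: `get_writer` now returns a callable taking `(result, audio_path, options)`, with the options dict carrying the subtitle-layout flags collected from the CLI earlier in this patch. A minimal usage sketch, assuming the patched `whisperx.utils` is importable; the transcript dict and file names are illustrative placeholders, not real model output:

```python
# Sketch only: exercises the new writer interface from this patch.
from whisperx.utils import get_writer

result = {
    "segments": [
        {
            "start": 0.0, "end": 1.2, "text": " hello world",
            "words": [
                {"word": "hello", "start": 0.0, "end": 0.5},
                {"word": " world", "start": 0.6, "end": 1.2},
            ],
        }
    ]
}

writer = get_writer("srt", ".")  # one of: txt, vtt, srt, tsv, json, all
writer_args = {"highlight_words": False, "max_line_count": None, "max_line_width": None}
writer(result, "example.wav", writer_args)  # writes ./example.srt
```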
diff --git a/whisperx/vad.py b/whisperx/vad.py
index 933d270..42b0bfb 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -1,22 +1,23 @@
+import hashlib
import os
import urllib
-import pandas as pd
+from typing import Callable, Optional, Text, Union
+
import numpy as np
+import pandas as pd
import torch
-import hashlib
-from tqdm import tqdm
-from typing import Optional, Callable, Union, Text
-from pyannote.audio.core.io import AudioFile
-from pyannote.core import Annotation, Segment, SlidingWindowFeature
-from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
+from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
+from pyannote.audio.pipelines.utils import PipelineModel
+from pyannote.core import Annotation, Segment, SlidingWindowFeature
+from tqdm import tqdm
+
from .diarize import Segment as SegmentX
-from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
-def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
+def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
From 2a29f0ec6a8df93152949053db2fe471ea6e16cc Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Mon, 24 Apr 2023 21:24:22 +0100
Subject: [PATCH 02/20] add compute types
---
whisperx/transcribe.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index c3719ba..0c6b803 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -22,6 +22,8 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
+ parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
+
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
From 0efad26066f2db7be1f23ebbcad990210027e0a2 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Mon, 24 Apr 2023 21:26:44 +0100
Subject: [PATCH 03/20] pass compute_type
---
whisperx/transcribe.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 0c6b803..fd6cf52 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -24,7 +24,6 @@ def cli():
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
-
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -82,6 +81,8 @@ def cli():
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
+ compute_type: str = args.pop("compute_type")
+
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
@@ -145,7 +146,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
- model = load_model(model_name, device=device, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
+ model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
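For orientation, a rough sketch of how the new `--compute_type` flag reaches the model, assuming the patched package is installed and exposes `load_model`/`load_audio` at the top level as used in these hunks; the model size, device, and audio path are placeholders:

```python
# Sketch only: mirrors the CLI wiring above; "large-v2", "cuda" and the file
# name are illustrative, and "float16" is one of the new --compute_type choices.
import whisperx

model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")
audio = whisperx.load_audio("example.wav")
result = model.transcribe(audio, batch_size=8)
print(result["segments"][0]["text"])
```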
From 25be8210e59e7fdceeedbda0b6c18c3b7c8cbe01 Mon Sep 17 00:00:00 2001
From: m-bain <36994049+m-bain@users.noreply.github.com>
Date: Tue, 25 Apr 2023 10:07:34 +0100
Subject: [PATCH 04/20] add v3 tag for install
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index b5b95c5..e34e533 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#whe
### 3. Install this repo
-`pip install git+https://github.com/m-bain/whisperx.git`
+`pip install git+https://github.com/m-bain/whisperx.git@v3`
If already installed, update package to most recent commit
@@ -242,4 +242,4 @@ If you use this in your research, please cite the paper:
journal={arXiv preprint, arXiv:2303.00747},
year={2023}
}
-```
\ No newline at end of file
+```
From db97f29678f4344496136283a08e75d5d8aba643 Mon Sep 17 00:00:00 2001
From: m-bain <36994049+m-bain@users.noreply.github.com>
Date: Tue, 25 Apr 2023 11:19:23 +0100
Subject: [PATCH 05/20] update pip install
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index e34e533..16f1e1c 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
-`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch`
+`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
From cc7e168d2b520b4a5e46d50b7c64cafac739db03 Mon Sep 17 00:00:00 2001
From: m-bain <36994049+m-bain@users.noreply.github.com>
Date: Tue, 25 Apr 2023 12:14:23 +0100
Subject: [PATCH 06/20] add checkout command
---
README.md | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 16f1e1c..c9951ce 100644
--- a/README.md
+++ b/README.md
@@ -85,12 +85,13 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#whe
If already installed, update package to most recent commit
-`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
+`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
If wishing to modify this package, clone and install in editable mode:
```
-$ git clone https://github.com/m-bain/whisperX.git
+$ git clone https://github.com/m-bain/whisperX.git@v3
$ cd whisperX
+$ git checkout v3
$ pip install -e .
```
From cb176a186ee6dab55f55c55562664bef583e3fd6 Mon Sep 17 00:00:00 2001
From: Thomas Mol
Date: Sat, 29 Apr 2023 19:51:05 +0200
Subject: [PATCH 07/20] added num_workers to fix pickling error
---
whisperx/asr.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/whisperx/asr.py b/whisperx/asr.py
index d9cbff0..b23a54d 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -207,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
- self, audio: Union[str, np.ndarray], batch_size=None
+ self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
):
if isinstance(audio, str):
audio = load_audio(audio)
@@ -232,7 +232,7 @@ class FasterWhisperPipeline(Pipeline):
segments = []
batch_size = batch_size or self._batch_size
- for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
+ for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
From b9c8c5072b57b3ed9ec1c121a39f56930e4b4bb9 Mon Sep 17 00:00:00 2001
From: Simon
Date: Sun, 30 Apr 2023 18:34:18 +0200
Subject: [PATCH 08/20] Pad language detection if audio is too short
---
whisperx/asr.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/whisperx/asr.py b/whisperx/asr.py
index b23a54d..ba6220b 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -251,7 +251,10 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray):
- segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
+ if audio.shape[0] < N_SAMPLES:
+ print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+ segment = log_mel_spectrogram(audio[: N_SAMPLES],
+ padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
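The padding arithmetic above only kicks in for clips shorter than one 30-second window; a standalone sketch of the same rule, assuming Whisper's usual constants (16 kHz sample rate, so N_SAMPLES = 480000):

```python
import numpy as np

SAMPLE_RATE = 16000
N_SAMPLES = 30 * SAMPLE_RATE  # 480000 samples, one 30 s window

def detection_padding(audio: np.ndarray) -> int:
    """Mirror of the patched rule: pad short clips up to a full window."""
    if audio.shape[0] < N_SAMPLES:
        print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
    return 0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]

print(detection_padding(np.zeros(8 * SAMPLE_RATE)))   # 352000 samples of padding
print(detection_padding(np.zeros(45 * SAMPLE_RATE)))  # 0; audio[:N_SAMPLES] already fills the window
```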
From 601c91140fe9056b8a40a10c923e0d09c51cd948 Mon Sep 17 00:00:00 2001
From: Prashanth Ellina
Date: Sun, 30 Apr 2023 17:33:24 +0000
Subject: [PATCH 09/20] references #202, attempt to fix speaker diarization
failing in v3
---
whisperx/alignment.py | 4 ++--
whisperx/transcribe.py | 6 ++++++
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 09f044f..e2c86f7 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -450,8 +450,8 @@ def align(
"end": srow["end"],
"text": text,
"words": word_list,
- # "word-segments": wseg,
- # "char-segments": cseg
+ "word-segments": wseg,
+ "char-segments": cseg
}
)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index fd6cf52..dab9e12 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -203,6 +203,12 @@ def cli():
# >> Write
for result, audio_path in results:
+ # Remove pandas dataframes from result so that
+ # we can serialize the result with json
+ for seg in result["segments"]:
+ seg.pop("word-segments", None)
+ seg.pop("char-segments", None)
+
writer(result, audio_path, writer_args)
if __name__ == "__main__":
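The reason these keys must be stripped is that the `json` writer added earlier in this series calls `json.dump`, which cannot serialize pandas objects. A small illustration under that assumption (the segment contents are made up):

```python
import json

import pandas as pd

segment = {
    "text": "hello world",                                        # made-up segment
    "word-segments": pd.DataFrame({"start": [0.0], "end": [0.4]}),
}

try:
    json.dumps(segment)
except TypeError as err:
    print(err)  # Object of type DataFrame is not JSON serializable

segment.pop("word-segments", None)
segment.pop("char-segments", None)  # absent here; pop(..., None) is a no-op
print(json.dumps(segment))          # {"text": "hello world"}
```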
From 5becc99e56e184f2103df2aaf71798626ebdaf46 Mon Sep 17 00:00:00 2001
From: Simon
Date: Mon, 1 May 2023 13:47:41 +0200
Subject: [PATCH 10/20] Version bump pyannote, pytorch
---
requirements.txt | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 2747e2d..d569c15 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
-torch==1.11.0
-torchaudio==0.11.0
-pyannote.audio
+torch==1.13.1
+torchaudio==0.13.1
+pyannote.audio==2.1.1
faster-whisper
transformers
ffmpeg-python==0.2.0
From 64ca208cc88d801d1c6295215b4f4f535dd806e3 Mon Sep 17 00:00:00 2001
From: Arnav Mehta <65492948+arnavmehta7@users.noreply.github.com>
Date: Tue, 2 May 2023 13:13:02 +0530
Subject: [PATCH 11/20] Fixed the bug where word_start was not initialized.
---
whisperx/alignment.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index e2c86f7..38c2f00 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -439,8 +439,8 @@ def align(
word_list.append(
{
"word": curr_text.rstrip(),
- "start": word_start,
- "end": word_end,
+ "start": wseg.iloc[wdx]['start'],
+ "end": wseg.iloc[wdx]['end'],
}
)
From 067189248f17b15d8af2cad3f38bcfe44911a8f2 Mon Sep 17 00:00:00 2001
From: Simon
Date: Tue, 2 May 2023 18:44:43 +0200
Subject: [PATCH 12/20] Use pyannote develop branch and torch version 2
---
requirements.txt | 5 ++---
setup.py | 2 +-
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index d569c15..f4f9c21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
-torch==1.13.1
-torchaudio==0.13.1
-pyannote.audio==2.1.1
+torch==2.0.0
+torchaudio==2.0.1
faster-whisper
transformers
ffmpeg-python==0.2.0
diff --git a/setup.py b/setup.py
index 2060d42..1eaaeb9 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
- ],
+ ] + ["pyannote.audio @ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip"],
entry_points = {
'console_scripts': ['whisperx=whisperx.transcribe:cli'],
},
From 2a6830492c7e9499a20e719ec5806b99b25af8ed Mon Sep 17 00:00:00 2001
From: Simon
Date: Tue, 2 May 2023 20:25:56 +0200
Subject: [PATCH 13/20] Fix pyannote to specific commit
---
setup.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/setup.py b/setup.py
index 1eaaeb9..859d171 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
- ] + ["pyannote.audio @ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip"],
+ ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
entry_points = {
'console_scripts': ['whisperx=whisperx.transcribe:cli'],
},
From cb53661070cab184ff2182e6498e715709e1a61e Mon Sep 17 00:00:00 2001
From: aramlang <100400031+aramlang@users.noreply.github.com>
Date: Wed, 3 May 2023 11:26:12 -0500
Subject: [PATCH 14/20] Enable Hebrew support
---
whisperx/alignment.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 38c2f00..2ae77f3 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -38,6 +38,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+ "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
}
@@ -231,8 +232,13 @@ def align(
emission = emissions[0].cpu().detach()
- trellis = get_trellis(emission, tokens)
- path = backtrack(trellis, emission, tokens)
+ blank_id = 0
+ for char, code in model_dictionary.items():
+ if char == '[pad]' or char == '':
+ blank_id = code
+
+ trellis = get_trellis(emission, tokens, blank_id)
+ path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
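For readers unfamiliar with CTC alignment, the scan above matters because some wav2vec2 vocabularies (such as the Hebrew model added here) do not keep the blank/pad token at index 0, which the previous code implicitly assumed. A toy sketch of the same lookup over a made-up dictionary:

```python
# Toy CTC vocabulary where the pad token is *not* at index 0 (illustrative only).
model_dictionary = {"a": 0, "b": 1, "|": 2, "[pad]": 3}

blank_id = 0
for char, code in model_dictionary.items():
    # Same rule as the patch: '[pad]' or the empty string marks the CTC blank.
    if char == "[pad]" or char == "":
        blank_id = code

print(blank_id)  # 3, passed to get_trellis()/backtrack() instead of a hard-coded 0
```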
From 2d59eb97260037113d85edcfe2713117987bd218 Mon Sep 17 00:00:00 2001
From: Simon
Date: Wed, 3 May 2023 23:17:44 +0200
Subject: [PATCH 15/20] Add torch compile to log mel spectrogram
---
whisperx/asr.py | 5 ++++-
whisperx/audio.py | 44 +++++++++++++++-----------------------------
2 files changed, 19 insertions(+), 30 deletions(-)
diff --git a/whisperx/asr.py b/whisperx/asr.py
index ba6220b..1ca12ce 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -181,6 +181,9 @@ class FasterWhisperPipeline(Pipeline):
def preprocess(self, audio):
audio = audio['inputs']
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio)
+
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features}
@@ -253,7 +256,7 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
- segment = log_mel_spectrogram(audio[: N_SAMPLES],
+ segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
diff --git a/whisperx/audio.py b/whisperx/audio.py
index 513ab7c..8ac0674 100644
--- a/whisperx/audio.py
+++ b/whisperx/audio.py
@@ -22,6 +22,12 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
+with np.load(
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+) as f:
+ MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])
+
+
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
@@ -79,27 +85,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
return array
-@lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
- """
- load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
- Allows decoupling librosa dependency; saved using:
-
- np.savez_compressed(
- "mel_filters.npz",
- mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
- )
- """
- assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
- with np.load(
- os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
- ) as f:
- return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
-
-
+@torch.compile(fullgraph=True)
def log_mel_spectrogram(
- audio: Union[str, np.ndarray, torch.Tensor],
- n_mels: int = N_MELS,
+ audio: torch.Tensor,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
@@ -108,7 +96,7 @@ def log_mel_spectrogram(
Parameters
----------
- audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+ audio: torch.Tensor, shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
@@ -125,21 +113,19 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
- if not torch.is_tensor(audio):
- if isinstance(audio, str):
- audio = load_audio(audio)
- audio = torch.from_numpy(audio)
+ global MEL_FILTERS
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
- stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
- magnitudes = stft[..., :-1].abs() ** 2
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
+ # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors
+ magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)
- filters = mel_filters(audio.device, n_mels)
- mel_spec = filters @ magnitudes
+ MEL_FILTERS = MEL_FILTERS.to(audio.device)
+ mel_spec = MEL_FILTERS @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
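The real-valued STFT in this patch presumably exists because `torch.compile` handled complex tensors poorly at the time; the hunk's comment claims the squared-and-summed components match the complex magnitude. A quick standalone check of that equivalence, using Whisper's N_FFT=400 and HOP_LENGTH=160 and `torch.view_as_real` in place of the deprecated `return_complex=False` path:

```python
import torch

N_FFT, HOP_LENGTH = 400, 160          # Whisper's spectrogram constants
audio = torch.randn(16000)            # 1 s of dummy audio at 16 kHz
window = torch.hann_window(N_FFT)

stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
mag_complex = stft[..., :-1].abs() ** 2                  # the formulation being replaced

stft_real = torch.view_as_real(stft)                     # (freq, frames, 2) for 1-D input
mag_real = (stft_real[:, :-1, :] ** 2).sum(dim=-1)       # the patch's real-valued version

print(torch.allclose(mag_complex, mag_real, atol=1e-6))  # True
```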
From d8f0ef4a19a7cd06bc6afb8ce97f575e51655b68 Mon Sep 17 00:00:00 2001
From: Simon
Date: Thu, 4 May 2023 16:25:34 +0200
Subject: [PATCH 16/20] Set diarization device manually
---
whisperx/diarize.py | 7 ++++++-
whisperx/transcribe.py | 3 ++-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 34dfc63..6f8c257 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -1,14 +1,19 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch
class DiarizationPipeline:
def __init__(
self,
model_name="pyannote/speaker-diarization@2.1",
use_auth_token=None,
+ device: Optional[Union[str, torch.device]] = "cpu",
):
- self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index dab9e12..e284e83 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -193,8 +193,9 @@ def cli():
if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results
+ print(">>Performing diarization...")
results = []
- diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+ diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
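A short usage sketch of the device plumbing above, assuming the `DiarizationPipeline` interface shown in these hunks; the token, audio path, and speaker bounds are placeholders, and the gated pyannote models must be accepted on Hugging Face for this to actually run:

```python
# Sketch only: "hf_xxx" and "example.wav" are placeholders.
import torch
from whisperx.diarize import DiarizationPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device=device)

# Returns speaker turns that the CLI then combines with the aligned segments
# via assign_word_speakers(), exactly as in the loop above.
diarize_segments = diarize_model("example.wav", min_speakers=1, max_speakers=2)
```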
From 4e2ac4e4e9da9b2b392c13f68f0718e86d293fb0 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Thu, 4 May 2023 20:38:13 +0100
Subject: [PATCH 17/20] torch 2.0, remove compile for now, round times to 3 decimals
---
README.md | 12 ++++++------
setup.py | 2 +-
whisperx/alignment.py | 4 ++--
whisperx/asr.py | 5 +----
whisperx/audio.py | 44 ++++++++++++++++++++++++++++--------------
whisperx/transcribe.py | 7 +------
6 files changed, 40 insertions(+), 34 deletions(-)
diff --git a/README.md b/README.md
index c9951ce..1f41bb9 100644
--- a/README.md
+++ b/README.md
@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
Setup βοΈ
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment
-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`
`conda activate whisperx`
-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch 2.0, e.g. for Linux and Windows CUDA 11.7:
-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
### 3. Install this repo
diff --git a/setup.py b/setup.py
index 859d171..66f22cd 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
- version="3.0.0",
+ version="3.0.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 2ae77f3..e63e6e5 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -268,8 +268,8 @@ def align(
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
- start = char_seg.start * ratio + t1
- end = char_seg.end * ratio + t1
+ start = round(char_seg.start * ratio + t1, 3)
+ end = round(char_seg.end * ratio + t1, 3)
score = char_seg.score
char_segments_arr["char"].append(char)
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 1ca12ce..ba6220b 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -181,9 +181,6 @@ class FasterWhisperPipeline(Pipeline):
def preprocess(self, audio):
audio = audio['inputs']
- if isinstance(audio, np.ndarray):
- audio = torch.from_numpy(audio)
-
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features}
@@ -256,7 +253,7 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
- segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
+ segment = log_mel_spectrogram(audio[: N_SAMPLES],
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
diff --git a/whisperx/audio.py b/whisperx/audio.py
index 8ac0674..513ab7c 100644
--- a/whisperx/audio.py
+++ b/whisperx/audio.py
@@ -22,12 +22,6 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
-with np.load(
- os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
-) as f:
- MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])
-
-
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
@@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
return array
-@torch.compile(fullgraph=True)
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ )
+ """
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+ with np.load(
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+ ) as f:
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
def log_mel_spectrogram(
- audio: torch.Tensor,
+ audio: Union[str, np.ndarray, torch.Tensor],
+ n_mels: int = N_MELS,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
@@ -96,7 +108,7 @@ def log_mel_spectrogram(
Parameters
----------
- audio: torch.Tensor, shape = (*)
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
@@ -113,19 +125,21 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
- global MEL_FILTERS
+ if not torch.is_tensor(audio):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
- stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
- # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors
- magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+ magnitudes = stft[..., :-1].abs() ** 2
- MEL_FILTERS = MEL_FILTERS.to(audio.device)
- mel_spec = MEL_FILTERS @ magnitudes
+ filters = mel_filters(audio.device, n_mels)
+ mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index e284e83..4b5a664 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -72,7 +72,6 @@ def cli():
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
- parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
# fmt: on
args = parser.parse_args().__dict__
@@ -86,10 +85,6 @@ def cli():
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
- tmp_dir: str = args.pop("tmp_dir")
- if tmp_dir is not None:
- os.makedirs(tmp_dir, exist_ok=True)
-
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
@@ -195,7 +190,7 @@ def cli():
tmp_results = results
print(">>Performing diarization...")
results = []
- diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+ diarize_model = DiarizationPipeline(use_auth_token=hf_token)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
From 07361ba1d7e10c218ef30dd465b92e89ddebb5c5 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Fri, 5 May 2023 11:53:51 +0100
Subject: [PATCH 18/20] add device to dia pipeline @sorgfresser
---
whisperx/diarize.py | 1 +
whisperx/transcribe.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 6f8c257..93ff41d 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -11,6 +11,7 @@ class DiarizationPipeline:
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
+ self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 4b5a664..f3f63fe 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -190,7 +190,7 @@ def cli():
tmp_results = results
print(">>Performing diarization...")
results = []
- diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+ diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
From 24008aa1ed67c4f75c90107b4937178a1452519d Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Sun, 7 May 2023 15:32:58 +0100
Subject: [PATCH 19/20] fix long segments, break into sentences using nltk,
improve align logic, improve diarize (sentence-based)
---
requirements.txt | 3 +-
whisperx/alignment.py | 519 ++++++++++++++---------------------------
whisperx/asr.py | 155 +-----------
whisperx/diarize.py | 82 +++----
whisperx/transcribe.py | 17 +-
whisperx/utils.py | 67 ++++--
6 files changed, 269 insertions(+), 574 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index f4f9c21..ec90a07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ faster-whisper
transformers
ffmpeg-python==0.2.0
pandas
-setuptools==65.6.3
\ No newline at end of file
+setuptools==65.6.3
+nltk
\ No newline at end of file
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index e63e6e5..2812c10 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
+import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@@ -84,386 +85,226 @@ def align(
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
- extend_duration: float = 0.0,
- start_from_previous: bool = True,
interpolate_method: str = "nearest",
+ return_char_alignments: bool = False,
):
"""
- Force align phoneme recognition predictions to known transcription
-
- Parameters
- ----------
- transcript: Iterator[dict]
- The Whisper model instance
-
- model: torch.nn.Module
- Alignment model (wav2vec2)
-
- audio: Union[str, np.ndarray, torch.Tensor]
- The path to the audio file to open, or the audio waveform
-
- device: str
- cuda device
-
- diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
- diarization segments with speaker labels.
-
- extend_duration: float
- Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
-
- If the gzip compression ratio is above this value, treat as failed
-
- interpolate_method: str ["nearest", "linear", "ignore"]
- Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
- "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
-
- Returns
- -------
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+ Align phoneme recognition predictions to known transcription.
"""
+
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
-
+
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
- aligned_segments = []
-
- prev_t2 = 0
-
- char_segments_arr = {
- "segment-idx": [],
- "subsegment-idx": [],
- "word-idx": [],
- "char": [],
- "start": [],
- "end": [],
- "score": [],
- }
-
+ # 1. Preprocess to keep only characters in dictionary
for sdx, segment in enumerate(transcript):
- while True:
- segment_align_success = False
+ # strip spaces at beginning / end, but keep track of the amount.
+ num_leading = len(segment["text"]) - len(segment["text"].lstrip())
+ num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
+ text = segment["text"]
- # strip spaces at beginning / end, but keep track of the amount.
- num_leading = len(segment["text"]) - len(segment["text"].lstrip())
- num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
- transcription = segment["text"]
+ # split into words
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
+ per_word = text.split(" ")
+ else:
+ per_word = text
- # TODO: convert number tokenizer / symbols to phonetic words for alignment.
- # e.g. "$300" -> "three hundred dollars"
- # currently "$300" is ignored since no characters present in the phonetic dictionary
-
- # split into words
+ clean_char, clean_cdx = [], []
+ for cdx, char in enumerate(text):
+ char_ = char.lower()
+ # wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
- per_word = transcription.split(" ")
- else:
- per_word = transcription
-
- # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
- clean_char, clean_cdx = [], []
- for cdx, char in enumerate(transcription):
- char_ = char.lower()
- # wav2vec2 models use "|" character to represent spaces
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
- char_ = char_.replace(" ", "|")
-
- # ignore whitespace at beginning and end of transcript
- if cdx < num_leading:
- pass
- elif cdx > len(transcription) - num_trailing - 1:
- pass
- elif char_ in model_dictionary.keys():
- clean_char.append(char_)
- clean_cdx.append(cdx)
-
- clean_wdx = []
- for wdx, wrd in enumerate(per_word):
- if any([c in model_dictionary.keys() for c in wrd]):
- clean_wdx.append(wdx)
-
- # if no characters are in the dictionary, then we skip this segment...
- if len(clean_char) == 0:
- print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
- break
-
- transcription_cleaned = "".join(clean_char)
- tokens = [model_dictionary[c] for c in transcription_cleaned]
-
- # we only pad if not using VAD filtering
- if "seg_text" not in segment:
- # pad according original timestamps
- t1 = max(segment["start"] - extend_duration, 0)
- t2 = min(segment["end"] + extend_duration, MAX_DURATION)
-
- # use prev_t2 as current t1 if it"s later
- if start_from_previous and t1 < prev_t2:
- t1 = prev_t2
-
- # check if timestamp range is still valid
- if t1 >= MAX_DURATION:
- print("Failed to align segment: original start time longer than audio duration, skipping...")
- break
- if t2 - t1 < 0.02:
- print("Failed to align segment: duration smaller than 0.02s time precision")
- break
-
- f1 = int(t1 * SAMPLE_RATE)
- f2 = int(t2 * SAMPLE_RATE)
-
- waveform_segment = audio[:, f1:f2]
-
- with torch.inference_mode():
- if model_type == "torchaudio":
- emissions, _ = model(waveform_segment.to(device))
- elif model_type == "huggingface":
- emissions = model(waveform_segment.to(device)).logits
- else:
- raise NotImplementedError(f"Align model of type {model_type} not supported.")
- emissions = torch.log_softmax(emissions, dim=-1)
-
- emission = emissions[0].cpu().detach()
-
- blank_id = 0
- for char, code in model_dictionary.items():
- if char == '[pad]' or char == '':
- blank_id = code
-
- trellis = get_trellis(emission, tokens, blank_id)
- path = backtrack(trellis, emission, tokens, blank_id)
- if path is None:
- print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
- break
- char_segments = merge_repeats(path, transcription_cleaned)
- # word_segments = merge_words(char_segments)
+ char_ = char_.replace(" ", "|")
+ # ignore whitespace at beginning and end of transcript
+ if cdx < num_leading:
+ pass
+ elif cdx > len(text) - num_trailing - 1:
+ pass
+ elif char_ in model_dictionary.keys():
+ clean_char.append(char_)
+ clean_cdx.append(cdx)
- # sub-segments
- if "seg-text" not in segment:
- segment["seg-text"] = [transcription]
-
- seg_lens = [0] + [len(x) for x in segment["seg-text"]]
- seg_lens_cumsum = list(np.cumsum(seg_lens))
- sub_seg_idx = 0
-
- wdx = 0
- duration = t2 - t1
- ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
- for cdx, char in enumerate(transcription + " "):
- is_last = False
- if cdx == len(transcription):
- break
- elif cdx+1 == len(transcription):
- is_last = True
-
-
- start, end, score = None, None, None
- if cdx in clean_cdx:
- char_seg = char_segments[clean_cdx.index(cdx)]
- start = round(char_seg.start * ratio + t1, 3)
- end = round(char_seg.end * ratio + t1, 3)
- score = char_seg.score
-
- char_segments_arr["char"].append(char)
- char_segments_arr["start"].append(start)
- char_segments_arr["end"].append(end)
- char_segments_arr["score"].append(score)
- char_segments_arr["word-idx"].append(wdx)
- char_segments_arr["segment-idx"].append(sdx)
- char_segments_arr["subsegment-idx"].append(sub_seg_idx)
-
- # word-level info
- if model_lang in LANGUAGES_WITHOUT_SPACES:
- # character == word
- wdx += 1
- elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
- wdx += 1
-
- if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
- wdx = 0
- sub_seg_idx += 1
-
- prev_t2 = segment["end"]
-
- segment_align_success = True
- # end while True loop
- break
-
- # reset prev_t2 due to drifting issues
- if not segment_align_success:
- prev_t2 = 0
-
- char_segments_arr = pd.DataFrame(char_segments_arr)
- not_space = char_segments_arr["char"] != " "
-
- per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
- char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
- per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
- per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
- per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
- char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
- per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
-
- word_segments_arr = {}
-
- # start of word is first char with a timestamp
- word_segments_arr["start"] = per_word_grp["start"].min().values
- # end of word is last char with a timestamp
- word_segments_arr["end"] = per_word_grp["end"].max().values
- # score of word is mean (excluding nan)
- word_segments_arr["score"] = per_word_grp["score"].mean().values
-
- word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
- word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
- word_segments_arr = pd.DataFrame(word_segments_arr)
-
- word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
- segments_arr = {}
- segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
- segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
- segments_arr = pd.DataFrame(segments_arr)
- segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
- segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
-
- # interpolate missing words / sub-segments
- if interpolate_method != "ignore":
- wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
- wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
- # we still know which word timestamps are interpolated because their score == nan
- word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
- word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
- word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
- word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
- sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
- segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
- segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-
- # merge words & subsegments which are missing times
- word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
-
- word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
- word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
- word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
-
- seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
- segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
- segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
- segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
- else:
- word_segments_arr.dropna(inplace=True)
- segments_arr.dropna(inplace=True)
-
- # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
- segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
- segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
- segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
- segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
+ clean_wdx = []
+ for wdx, wrd in enumerate(per_word):
+ if any([c in model_dictionary.keys() for c in wrd]):
+ clean_wdx.append(wdx)
+ sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
+ segment["clean_char"] = clean_char
+ segment["clean_cdx"] = clean_cdx
+ segment["clean_wdx"] = clean_wdx
+ segment["sentence_spans"] = sentence_spans
+
aligned_segments = []
- aligned_segments_word = []
- word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
- char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
+ # 2. Get prediction matrix from alignment model & align
+ for sdx, segment in enumerate(transcript):
+ t1 = segment["start"]
+ t2 = segment["end"]
+ text = segment["text"]
- for sdx, srow in segments_arr.iterrows():
+ aligned_seg = {
+ "start": t1,
+ "end": t2,
+ "text": text,
+ "words": [],
+ }
- seg_idx = int(srow["segment-idx"])
- sub_start = int(srow["subsegment-idx-start"])
- sub_end = int(srow["subsegment-idx-end"])
+ if return_char_alignments:
+ aligned_seg["chars"] = []
- seg = transcript[seg_idx]
- text = "".join(seg["seg-text"][sub_start:sub_end])
+ # check we can align
+ if len(segment["clean_char"]) == 0:
+ print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
+ aligned_segments.append(aligned_seg)
+ continue
- wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
- wseg["start"].fillna(srow["start"], inplace=True)
- wseg["end"].fillna(srow["end"], inplace=True)
- wseg["segment-text-start"].fillna(0, inplace=True)
- wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
+ if t1 >= MAX_DURATION or t2 - t1 < 0.02:
+            print("Failed to align segment: original start time exceeds audio duration or segment is shorter than 0.02s, skipping...")
+ aligned_segments.append(aligned_seg)
+ continue
- cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
- # fixes bug for single segment in transcript
- cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
- cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
- if 'level_1' in cseg: del cseg['level_1']
- if 'level_0' in cseg: del cseg['level_0']
- cseg.reset_index(inplace=True)
+ text_clean = "".join(segment["clean_char"])
+ tokens = [model_dictionary[c] for c in text_clean]
- def get_raw_text(word_row):
- return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
+ f1 = int(t1 * SAMPLE_RATE)
+ f2 = int(t2 * SAMPLE_RATE)
- word_list = []
- wdx = 0
- curr_text = get_raw_text(wseg.iloc[wdx])
- if not curr_text.startswith(" "):
- curr_text = " " + curr_text
+ # TODO: Probably can get some speedup gain with batched inference here
+ waveform_segment = audio[:, f1:f2]
+
+ with torch.inference_mode():
+ if model_type == "torchaudio":
+ emissions, _ = model(waveform_segment.to(device))
+ elif model_type == "huggingface":
+ emissions = model(waveform_segment.to(device)).logits
+ else:
+ raise NotImplementedError(f"Align model of type {model_type} not supported.")
+ emissions = torch.log_softmax(emissions, dim=-1)
+
+ emission = emissions[0].cpu().detach()
+
+ blank_id = 0
+ for char, code in model_dictionary.items():
+        if char == '[pad]' or char == '<pad>':
+ blank_id = code
+
+ trellis = get_trellis(emission, tokens, blank_id)
+ path = backtrack(trellis, emission, tokens, blank_id)
+
+ if path is None:
+ print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
+ aligned_segments.append(aligned_seg)
+ continue
+
+ char_segments = merge_repeats(path, text_clean)
+
+    duration = t2 - t1
+ ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
+
+ # assign timestamps to aligned characters
+ char_segments_arr = []
+ word_idx = 0
+ for cdx, char in enumerate(text):
+ start, end, score = None, None, None
+ if cdx in segment["clean_cdx"]:
+ char_seg = char_segments[segment["clean_cdx"].index(cdx)]
+ start = round(char_seg.start * ratio + t1, 3)
+ end = round(char_seg.end * ratio + t1, 3)
+ score = round(char_seg.score, 3)
+
+ char_segments_arr.append(
+ {
+ "char": char,
+ "start": start,
+ "end": end,
+ "score": score,
+ "word-idx": word_idx,
+ }
+ )
+
+            # increment word_idx; nltk word tokenization would probably be more robust here, but use spaces for now...
+ if model_lang in LANGUAGES_WITHOUT_SPACES:
+ word_idx += 1
+ elif cdx == len(text) - 1 or text[cdx+1] == " ":
+ word_idx += 1
- if len(wseg) > 1:
- for _, wrow in wseg.iloc[1:].iterrows():
- if wrow['start'] != wseg.iloc[wdx]['start']:
- word_start = wseg.iloc[wdx]['start']
- word_end = wseg.iloc[wdx]['end']
+ char_segments_arr = pd.DataFrame(char_segments_arr)
- aligned_segments_word.append(
- {
- "text": curr_text.strip(),
- "start": word_start,
- "end": word_end
- }
- )
+ aligned_subsegments = []
+ # assign sentence_idx to each character index
+ char_segments_arr["sentence-idx"] = None
+ for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+ curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
+ char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
+
+ sentence_text = text[sstart:send]
+ sentence_start = curr_chars["start"].min()
+ sentence_end = curr_chars["end"].max()
+ sentence_words = []
- word_list.append(
- {
- "word": curr_text.rstrip(),
- "start": word_start,
- "end": word_end,
- }
- )
+ for word_idx in curr_chars["word-idx"].unique():
+ word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
+ word_text = "".join(word_chars["char"].tolist()).strip()
+ if len(word_text) == 0:
+ continue
+ word_start = word_chars["start"].min()
+ word_end = word_chars["end"].max()
+ word_score = round(word_chars["score"].mean(), 3)
- curr_text = " "
- curr_text += get_raw_text(wrow) + " "
- wdx += 1
+                # words with no aligned characters have NaN timings; only attach the fields that were aligned
+ word_segment = {"word": word_text}
- aligned_segments_word.append(
- {
- "text": curr_text.strip(),
- "start": wseg.iloc[wdx]["start"],
- "end": wseg.iloc[wdx]["end"]
- }
- )
+ if not np.isnan(word_start):
+ word_segment["start"] = word_start
+ if not np.isnan(word_end):
+ word_segment["end"] = word_end
+ if not np.isnan(word_score):
+ word_segment["score"] = word_score
- word_list.append(
- {
- "word": curr_text.rstrip(),
- "start": wseg.iloc[wdx]['start'],
- "end": wseg.iloc[wdx]['end'],
- }
- )
+ sentence_words.append(word_segment)
+
+ aligned_subsegments.append({
+ "text": sentence_text,
+ "start": sentence_start,
+ "end": sentence_end,
+ "words": sentence_words,
+ })
- aligned_segments.append(
- {
- "start": srow["start"],
- "end": srow["end"],
- "text": text,
- "words": word_list,
- "word-segments": wseg,
- "char-segments": cseg
- }
- )
-
-
- return {"segments": aligned_segments, "word_segments": aligned_segments_word}
+ if return_char_alignments:
+ curr_chars = curr_chars[["char", "start", "end", "score"]]
+ curr_chars.fillna(-1, inplace=True)
+ curr_chars = curr_chars.to_dict("records")
+ curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
+ aligned_subsegments = pd.DataFrame(aligned_subsegments)
+ aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
+ aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
+ # concatenate sentences with same timestamps
+ agg_dict = {"text": " ".join, "words": "sum"}
+ if return_char_alignments:
+ agg_dict["chars"] = "sum"
+    aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
+ aligned_subsegments = aligned_subsegments.to_dict('records')
+ aligned_segments += aligned_subsegments
+
+ # create word_segments list
+ word_segments = []
+ for segment in aligned_segments:
+ word_segments += segment["words"]
+
+ return {"segments": aligned_segments, "word_segments": word_segments}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
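For orientation, the dictionary returned by the reworked `align()` above has roughly the following shape (text, timestamps and scores are illustrative, not real output; a `"chars"` list only appears with `return_char_alignments=True`):

```python
{
    "segments": [  # one entry per sentence-level sub-segment
        {
            "start": 0.52, "end": 2.31, "text": "Hello there.",
            "words": [
                {"word": "Hello", "start": 0.52, "end": 0.98, "score": 0.87},
                {"word": "there.", "start": 1.02, "end": 2.31, "score": 0.91},
            ],
        },
    ],
    "word_segments": [  # flat list of all word entries above
        {"word": "Hello", "start": 0.52, "end": 0.98, "score": 0.87},
        {"word": "there.", "start": 1.02, "end": 2.31, "score": 0.91},
    ],
}
```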
diff --git a/whisperx/asr.py b/whisperx/asr.py
index ba6220b..f2c54f6 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
- Currently only works in non-timestamp mode.
+    Currently only works in non-timestamp mode, with a fixed prompt shared across all samples in a batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
+ """
+ Huggingface Pipeline wrapper for FasterWhisperModel.
+ """
+ # TODO:
+ # - add support for timestamp mode
+ # - add support for custom inference kwargs
+
def __init__(
self,
model,
@@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
-
-if __name__ == "__main__":
- main_type = "simple"
- import time
-
- import jiwer
- from tqdm import tqdm
- from whisper.normalizers import EnglishTextNormalizer
-
- from benchmark.tedlium import parse_tedlium_annos
-
- if main_type == "complex":
- from faster_whisper.tokenizer import Tokenizer
- from faster_whisper.transcribe import TranscriptionOptions
- from faster_whisper.vad import (SpeechTimestampsMap,
- get_speech_timestamps)
-
- from whisperx.vad import load_vad_model, merge_chunks
-
- from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
- faster_t_options = TranscriptionOptions(
- beam_size=5,
- best_of=5,
- patience=1,
- length_penalty=1,
- temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
- compression_ratio_threshold=2.4,
- log_prob_threshold=-1.0,
- no_speech_threshold=0.6,
- condition_on_previous_text=False,
- initial_prompt=None,
- prefix=None,
- suppress_blank=True,
- suppress_tokens=[-1],
- without_timestamps=True,
- max_initial_timestamp=0.0,
- word_timestamps=False,
- prepend_punctuations="\"'βΒΏ([{-",
- append_punctuations="\"'.γ,οΌ!οΌ?οΌ:οΌβ)]}γ"
- )
- whisper_arch = "large-v2"
- device = "cuda"
- batch_size = 16
- model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
- tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
- model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
- fn = "DanielKahneman_2010.wav"
- wav_dir = f"/tmp/test/wav/"
- vad_model = load_vad_model("cuda", 0.6, 0.3)
- audio = load_audio(os.path.join(wav_dir, fn))
- vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
- vad_segments = merge_chunks(vad_segments, 30)
-
- def data(audio, segments):
- for seg in segments:
- f1 = int(seg['start'] * SAMPLE_RATE)
- f2 = int(seg['end'] * SAMPLE_RATE)
- # print(f2-f1)
- yield {'inputs': audio[f1:f2]}
- vad_method="pyannote"
-
- wav_dir = f"/tmp/test/wav/"
- wer_li = []
- time_li = []
- for fn in os.listdir(wav_dir):
- if fn == "RobertGupta_2010U.wav":
- continue
- base_fn = fn.split('.')[0]
- audio_fp = os.path.join(wav_dir, fn)
-
- audio = load_audio(audio_fp)
- t1 = time.time()
- if vad_method == "pyannote":
- vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
- vad_segments = merge_chunks(vad_segments, 30)
- elif vad_method == "silero":
- vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
- vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
- new_segs = []
- curr_start = vad_segments[0]['start']
- curr_end = vad_segments[0]['end']
- for seg in vad_segments[1:]:
- if seg['end'] - curr_start > 30:
- new_segs.append({"start": curr_start, "end": curr_end})
- curr_start = seg['start']
- curr_end = seg['end']
- else:
- curr_end = seg['end']
- new_segs.append({"start": curr_start, "end": curr_end})
- vad_segments = new_segs
- text = []
- # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
- for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
- text.append(out['text'])
- t2 = time.time()
- if batch_size == 1:
- text = [x[0] for x in text]
- text = " ".join(text)
-
- normalizer = EnglishTextNormalizer()
- text = normalizer(text)
- gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
- wer_result = jiwer.wer(gt_corpus, text)
- print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
- wer_li.append(wer_result)
- time_li.append(t2-t1)
- print("# Avg Mean...")
- print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
- print("Time: %.2f" % (sum(time_li)/len(time_li)))
- elif main_type == "simple":
- model = load_model(
- "large-v2",
- device="cuda",
- language="en",
- )
-
- wav_dir = f"/tmp/test/wav/"
- wer_li = []
- time_li = []
- for fn in os.listdir(wav_dir):
- if fn == "RobertGupta_2010U.wav":
- continue
- # fn = "DanielKahneman_2010.wav"
- base_fn = fn.split('.')[0]
- audio_fp = os.path.join(wav_dir, fn)
-
- audio = load_audio(audio_fp)
- t1 = time.time()
- out = model.transcribe(audio_fp, batch_size=8)["segments"]
- t2 = time.time()
-
- text = " ".join([x['text'] for x in out])
- normalizer = EnglishTextNormalizer()
- text = normalizer(text)
- gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
-
- wer_result = jiwer.wer(gt_corpus, text)
- print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
-
- wer_li.append(wer_result)
- time_li.append(t2-t1)
- print("# Avg Mean...")
- print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
- print("Time: %.2f" % (sum(time_li)/len(time_li)))
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 93ff41d..320d2a4 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -11,7 +11,6 @@ class DiarizationPipeline:
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
- self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
@@ -21,59 +20,44 @@ class DiarizationPipeline:
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
+ diarize_df.rename(columns={2: "speaker"}, inplace=True)
return diarize_df
-def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
- for seg in result_segments:
- wdf = seg['word-segments']
- if len(wdf['start'].dropna()) == 0:
- wdf['start'] = seg['start']
- wdf['end'] = seg['end']
- speakers = []
- for wdx, wrow in wdf.iterrows():
- if not np.isnan(wrow['start']):
- diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
- diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
- # remove no hit
- if not fill_nearest:
- dia_tmp = diarize_df[diarize_df['intersection'] > 0]
- else:
- dia_tmp = diarize_df
- if len(dia_tmp) == 0:
- speaker = None
- else:
- speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
- else:
- speaker = None
- speakers.append(speaker)
- seg['word-segments']['speaker'] = speakers
- speaker_count = pd.Series(speakers).value_counts()
- if len(speaker_count) == 0:
- seg["speaker"]= "UNKNOWN"
+def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+ transcript_segments = transcript_result["segments"]
+ for seg in transcript_segments:
+ # assign speaker to segment (if any)
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
+ diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+ # remove no hit, otherwise we look for closest (even negative intersection...)
+ if not fill_nearest:
+ dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
- seg["speaker"] = speaker_count.index[0]
+ dia_tmp = diarize_df
+ if len(dia_tmp) > 0:
+ # sum over speakers
+ speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+ seg["speaker"] = speaker
+
+ # assign speaker to words
+ if 'words' in seg:
+ for word in seg['words']:
+ if 'start' in word:
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
+ diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
+ # remove no hit
+ if not fill_nearest:
+ dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+ else:
+ dia_tmp = diarize_df
+ if len(dia_tmp) > 0:
+ # sum over speakers
+ speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+ word["speaker"] = speaker
+
+ return transcript_result
- # create word level segments for .srt
- word_seg = []
- for seg in result_segments:
- wseg = pd.DataFrame(seg["word-segments"])
- for wdx, wrow in wseg.iterrows():
- if wrow["start"] is not None:
- speaker = wrow['speaker']
- if speaker is None or speaker == np.nan:
- speaker = "UNKNOWN"
- word_seg.append(
- {
- "start": wrow["start"],
- "end": wrow["end"],
- "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
- }
- )
-
- # TODO: create segments but split words on new speaker
-
- return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):
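The reworked `assign_word_speakers` above assigns each segment (and each word that has timestamps) to the speaker whose diarization turns overlap it most in time. A toy sketch of that overlap rule, with made-up turns that are not part of the patch:

```python
import numpy as np
import pandas as pd

# hypothetical diarization turns and one transcript segment
diarize_df = pd.DataFrame({"start": [0.0, 2.0], "end": [2.0, 5.0],
                           "speaker": ["SPEAKER_00", "SPEAKER_01"]})
seg = {"start": 1.5, "end": 3.0}

# temporal intersection of the segment with every diarization turn
diarize_df["intersection"] = np.minimum(diarize_df["end"], seg["end"]) - np.maximum(diarize_df["start"], seg["start"])
overlapping = diarize_df[diarize_df["intersection"] > 0]

# the speaker with the largest total overlap wins
speaker = overlapping.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
print(speaker)  # SPEAKER_01 (1.0 s overlap vs 0.5 s for SPEAKER_00)
```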
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index f3f63fe..b89a545 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -64,14 +64,11 @@ def cli():
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) whether aligned output segments are split at sentence or chunk (VAD segment) resolution")
- # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
- # parser.add_argument("--prepend_punctuations", type=str, default="\"\'βΒΏ([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
- # parser.add_argument("--append_punctuations", type=str, default="\"\'.γ,οΌ!οΌ?οΌ:οΌβ)]}γ", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
- # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
# fmt: on
args = parser.parse_args().__dict__
@@ -97,7 +94,6 @@ def cli():
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
- # TODO: check model loading works.
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
@@ -176,6 +172,7 @@ def cli():
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
+
results.append((result, audio_path))
# Unload align model
@@ -193,18 +190,10 @@ def cli():
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
- results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
- result = {"segments": results_segments, "word_segments": word_segments}
+ result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
-
# >> Write
for result, audio_path in results:
- # Remove pandas dataframes from result so that
- # we can serialize the result with json
- for seg in result["segments"]:
- seg.pop("word-segments", None)
- seg.pop("char-segments", None)
-
writer(result, audio_path, writer_args)
if __name__ == "__main__":
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 3401a84..d042bb7 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
- last = result["segments"][0]["words"][0]["start"]
+ times = []
+ last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
- long_pause = not preserve_segments and timing["start"] - last > 3.0
+ long_pause = not preserve_segments
+ if "start" in timing:
+ long_pause = long_pause and timing["start"] - last > 3.0
+ else:
+ long_pause = False
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
@@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
or seg_break
):
# subtitle break
- yield subtitle
+ yield subtitle, times
subtitle = []
+ times = []
line_count = 1
elif line_len > 0:
# line break
@@ -260,40 +266,53 @@ class SubtitlesWriter(ResultWriter):
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
- last = timing["start"]
+ times.append((segment["start"], segment["end"], segment.get("speaker")))
+ if "start" in timing:
+ last = timing["start"]
if len(subtitle) > 0:
- yield subtitle
+ yield subtitle, times
if "words" in result["segments"][0]:
- for subtitle in iterate_subtitles():
- subtitle_start = self.format_timestamp(subtitle[0]["start"])
- subtitle_end = self.format_timestamp(subtitle[-1]["end"])
- subtitle_text = "".join([word["word"] for word in subtitle])
- if highlight_words:
+ for subtitle, _ in iterate_subtitles():
+ sstart, ssend, speaker = _[0]
+ subtitle_start = self.format_timestamp(sstart)
+ subtitle_end = self.format_timestamp(ssend)
+ subtitle_text = " ".join([word["word"] for word in subtitle])
+ has_timing = any(["start" in word for word in subtitle])
+
+ # add [$SPEAKER_ID]: to each subtitle if speaker is available
+ prefix = ""
+ if speaker is not None:
+ prefix = f"[{speaker}]: "
+
+ if highlight_words and has_timing:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
- start = self.format_timestamp(this_word["start"])
- end = self.format_timestamp(this_word["end"])
- if last != start:
- yield last, start, subtitle_text
+ if "start" in this_word:
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
- yield start, end, "".join(
- [
- re.sub(r"^(\s*)(.*)$", r"\1\2", word)
- if j == i
- else word
- for j, word in enumerate(all_words)
- ]
- )
- last = end
+ yield start, end, prefix + " ".join(
+ [
+ re.sub(r"^(\s*)(.*)$", r"\1\2", word)
+ if j == i
+ else word
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
else:
- yield subtitle_start, subtitle_end, subtitle_text
+ yield subtitle_start, subtitle_end, prefix + subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
+ if "speaker" in segment:
+ segment_text = f"[{segment['speaker']}]: {segment_text}"
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
From 4603f010a5cdb93717e747ec092e0b7fa38877d2 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Sun, 7 May 2023 20:28:33 +0100
Subject: [PATCH 20/20] update readme, setup, add option to return
char_timestamps
---
README.md | 154 +++++++++++++++++++++++++----------------
setup.py | 2 +-
whisperx/__init__.py | 3 +-
whisperx/alignment.py | 1 +
whisperx/transcribe.py | 8 ++-
5 files changed, 103 insertions(+), 65 deletions(-)
diff --git a/README.md b/README.md
index 1f41bb9..bccbca8 100644
--- a/README.md
+++ b/README.md
@@ -13,36 +13,36 @@
+
+
+
-
- What is it β’
- Setup β’
- Usage β’
- Multilingual β’
- Contribute β’
- More examples β’
- Paper
-
-
-Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
-
-
+
-What is it π
-
-This repository refines the timestamps of openAI's Whisper model via forced aligment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocesssing, multilingual use-case.
+
-**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds.
+This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
+
+- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
+- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
+- 🎯 Accurate word-level timestamps using wav2vec2 alignment
+- 👯‍♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID)
+- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
+
+
+
+**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
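As a rough illustration (not part of this repo's API), the per-frame character emissions that such a phoneme/character model produces are what the trellis/backtrack alignment step consumes; the torchaudio bundle and file name below are placeholders:

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H   # illustrative choice of alignment model
model = bundle.get_model()
labels = bundle.get_labels()                            # character vocabulary, incl. the "|" word-boundary token

waveform, sr = torchaudio.load("speech.wav")            # placeholder audio file
waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)

with torch.inference_mode():
    emissions, _ = model(waveform)                      # (batch, frames, len(labels))
    emissions = torch.log_softmax(emissions, dim=-1)    # log-probabilities per frame and character
```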
@@ -50,15 +50,15 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
+**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
+
+
Newπ¨
+- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization (see the sketch below)
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
-- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
-- Paper dropππ¨βπ«! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
-- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
-- Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarize`)
-
+- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
+- Paper drop🎓👨‍🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with 60-70x REAL TIME speed.
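A minimal sketch of the sentence splitting that the v3 bullet above refers to (the aligner keeps the character spans, not just the sentence strings); assumes `nltk` is installed:

```python
import nltk

text = "Hello there. This is a second sentence."
spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
for start, end in spans:
    # each (start, end) character span becomes its own transcript sub-segment
    print((start, end), text[start:end])
```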
Setup βοΈ
Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
@@ -89,15 +89,13 @@ If already installed, update package to most recent commit
If wishing to modify this package, clone and install in editable mode:
```
-$ git clone https://github.com/m-bain/whisperX.git@v3
+$ git clone https://github.com/m-bain/whisperX.git
$ cd whisperX
-$ git checkout v3
$ pip install -e .
```
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
-
### Speaker Diarization
To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
@@ -106,15 +104,11 @@ To **enable Speaker. Diarization**, include your Hugging Face access token that
### English
-Run whisper on example segment (using default params)
+Run whisper on an example segment (using default params, whisper small). Add `--highlight_words True` to visualise word timings in the .srt file.
whisperx examples/sample01.wav
-For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
-
- whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
-
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
@@ -123,6 +117,16 @@ Compare this to original whisper out the box, where many transcriptions are out
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
+
+For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
+
+ whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
+
+
+To label the transcript with speaker IDs (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
+
+ whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
+
### Other languages
The phoneme ASR alignment model is *language-specific*; for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
@@ -132,7 +136,7 @@ Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`
#### E.g. German
- whisperx --model large --language de examples/sample_de_01.wav
+ whisperx --model large-v2 --language de examples/sample_de_01.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
@@ -143,79 +147,107 @@ See more examples in other languages [here](EXAMPLES.md).
```python
import whisperx
+import gc
device = "cuda"
audio_file = "audio.mp3"
+batch_size = 16 # reduce if low on GPU mem
+compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
-# transcribe with original whisper
-model = whisperx.load_model("large-v2", device)
+# 1. Transcribe with original whisper (batched)
+model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
-result = model.transcribe(audio, batch_size=8)
-
+result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
-# load alignment model and metadata
+# delete model if low on GPU resources
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model
+
+# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
-# align whisper output
-result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)
+print(result["segments"]) # after alignment
-print(result_aligned["segments"]) # after alignment
-print(result_aligned["word_segments"]) # after alignment
+# delete model if low on GPU resources
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+
+# 3. Assign speaker labels
+diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
+
+# add min/max number of speakers if known
+diarize_segments = diarize_model(audio_file)
+# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
+
+result = whisperx.assign_word_speakers(diarize_segments, result)
+print(diarize_segments)
+print(result["segments"]) # segments are now assigned speaker IDs
```
-Whisper Modifications
+Technical Details 👷‍♂️
-In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
+For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
-1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
+To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality); a rough Python-API equivalent is sketched after this list:
+1. reduce batch size, e.g. `--batch_size 4`
+2. use a smaller ASR model `--model base`
+3. Use lighter compute type `--compute_type int8`
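A rough Python-API equivalent of the flags above (model choice and values are illustrative):

```python
import whisperx

# lower-memory configuration: smaller model, int8 compute type, small batch size
model = whisperx.load_model("base", "cuda", compute_type="int8")
audio = whisperx.load_audio("audio.mp3")
result = model.transcribe(audio, batch_size=4)
```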
+
+Transcription differences from openai's whisper:
+1. Transcription without timestamps. To enable single pass batching, whisper inference is performed with `--without_timestamps True`, ensuring one forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output.
+2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference.
+3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
Limitations ⚠️
-- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
-- If setting `--vad_filter False`, then whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
+- Transcript words which do not contain characters in the alignment model's dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
- Overlapping speech is not handled particularly well by whisper nor whisperx
-- Diariazation is far from perfect.
+- Diarization is far from perfect (working on this with a custom model for v4 -- see the contact note below).
+- A language-specific wav2vec2 model is needed for alignment
Contribute 🧑‍🏫
-If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a merge request and some examples showing its success.
+If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
-The next major upgrade we are working on is whisper with speaker diarization, so if you have any experience on this please share.
+Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
-Coming Soon π
+TODO 🗓
* [x] Multilingual init
-* [x] Subtitle .ass output
-
* [x] Automatic align model selection based on language detection
* [x] Python usage
-* [x] Character level timestamps
-
* [x] Incorporating speaker diarization
* [x] Model flush, for low gpu mem resources
* [x] Faster-whisper backend
+* [x] Add max-line etc. see (openai's whisper utils.py)
+
+* [x] Sentence-level segments (nltk toolbox)
+
+* [x] Improve alignment logic
+
+* [ ] update examples with diarization and word highlighting
+
+* [ ] Subtitle .ass output <- bring this back (removed in v3)
+
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
* [ ] Allow silero-vad as alternative VAD option
-* [ ] Add max-line etc. see (openai's whisper utils.py)
-
* [ ] Improve diarization (word level). *Harder than first thought...*
-Contact maxhbain@gmail.com for queries and licensing / early access to a model API with batched inference (transcribe 1hr audio in under 1min).
+Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with significantly improved diarization. To support v4 and get early access, get in touch.
@@ -224,14 +256,16 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
-
Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
-Valuable VAD & Diarization Models from (pyannote.audio)[https://github.com/pyannote/pyannote-audio]
+Valuable VAD & Diarization Models from [pyannote.audio](https://github.com/pyannote/pyannote-audio)
-Great backend from (faster-whisper)[https://github.com/guillaumekln/faster-whisper] and (CTranslate2)[https://github.com/OpenNMT/CTranslate2]
+Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
+Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
+
+Finally, thanks to the open-source [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
Citation
If you use this in your research, please cite the paper:
diff --git a/setup.py b/setup.py
index 66f22cd..eea26ad 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
- version="3.0.2",
+ version="3.1.0",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index d0294b9..20abaae 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -1,3 +1,4 @@
from .transcribe import load_model
from .alignment import load_align_model, align
-from .audio import load_audio
\ No newline at end of file
+from .audio import load_audio
+from .diarize import assign_word_speakers, DiarizationPipeline
\ No newline at end of file
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 2812c10..b873475 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -287,6 +287,7 @@ def align(
curr_chars.fillna(-1, inplace=True)
curr_chars = curr_chars.to_dict("records")
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
+ aligned_subsegments[-1]["chars"] = curr_chars
aligned_subsegments = pd.DataFrame(aligned_subsegments)
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index b89a545..d09c5f6 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -35,6 +35,7 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
+ parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
@@ -42,8 +43,8 @@ def cli():
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
- parser.add_argument("--min_speakers", default=None, type=int)
- parser.add_argument("--max_speakers", default=None, type=int)
+    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers in the audio file")
+    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers in the audio file")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -85,6 +86,7 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
+ return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_onset: float = args.pop("vad_onset")
@@ -171,7 +173,7 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
- result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
+ result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
results.append((result, audio_path))