This commit is contained in:
Max Bain
2023-04-24 21:08:43 +01:00
parent da458863d7
commit 558d980535
11 changed files with 1034 additions and 846 deletions

View File

@ -32,7 +32,7 @@
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png"> <img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy using forced alignment. <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
</p> </p>
@ -52,6 +52,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
<h2 align="left", id="highlights">New🚨</h2> <h2 align="left", id="highlights">New🚨</h2>
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [arXiv preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed* (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. Adding the `--vad_filter True` flag increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs in wav2vec2).
@ -60,7 +61,25 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
<h2 align="left" id="setup">Setup ⚙️</h2> <h2 align="left" id="setup">Setup ⚙️</h2>
Tested for PyTorch 1.11, Python 3.8 (use other versions at your own risk!)
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
### 1. Create Python 3.8 environment
`conda create --name whisperx python=3.8`
`conda activate whisperx`
### 2. Install PyTorch 1.11.0, e.g. for Linux and Windows:
`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch`
See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
### 3. Install this repo
`pip install git+https://github.com/m-bain/whisperx.git`
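To sanity-check the environment before running WhisperX, the short optional snippet below (assuming `torch` and `ctranslate2` were installed by the steps above) confirms that CUDA is visible to both PyTorch and the CTranslate2 backend used by faster-whisper:

```python
# Optional environment check -- not part of WhisperX itself.
# Assumes torch and ctranslate2 are installed per the steps above.
import torch
import ctranslate2

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("CTranslate2 CUDA devices:", ctranslate2.get_cuda_device_count())
```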
@ -78,13 +97,6 @@ $ pip install -e .
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
### Setup not working???
Safest to install PyTorch as follows (for GPU):
`conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -c pytorch`
### Speaker Diarization
To **enable Speaker Diarization**, include the Hugging Face access token (generate one [here](https://huggingface.co/settings/tokens)) after the `--hf_token` argument, and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization).
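For Python usage, a rough sketch of the same flow is shown below. It assumes the `DiarizationPipeline` and `assign_word_speakers` helpers that `whisperx/transcribe.py` imports from `whisperx/diarize.py`; the exact constructor arguments and return shape may differ from what is sketched here, and `result_aligned` refers to the aligned output produced in the Python example further down.

```python
# Hedged sketch: attach speaker labels to an aligned transcript.
# Argument names below are assumptions; check whisperx/diarize.py for the real signatures.
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

hf_token = "YOUR_HF_TOKEN"  # placeholder, generate your own token
diarize_model = DiarizationPipeline(use_auth_token=hf_token)

# diarize the same audio, optionally bounding the speaker count (mirrors --min_speakers/--max_speakers)
diarize_segments = diarize_model("audio.mp3", min_speakers=1, max_speakers=3)

# merge speaker labels into the aligned segments/words
result_with_speakers = assign_word_speakers(diarize_segments, result_aligned)
```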
@ -130,14 +142,15 @@ See more examples in other languages [here](EXAMPLES.md).
```python
import whisperx

device = "cuda"
audio_file = "audio.mp3"

# transcribe with whisper (batched)
model = whisperx.load_model("large-v2", device)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=8)

print(result["segments"]) # before alignment
@ -145,7 +158,7 @@ print(result["segments"]) # before alignment
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)

print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
@ -186,9 +199,15 @@ The next major upgrade we are working on is whisper with speaker diarization, so
* [x] Incorporating speaker diarization
* [x] Model flush, for low gpu mem resources
* [x] Faster-whisper backend
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
* [ ] Allow silero-vad as alternative VAD option
* [ ] Add max-line etc. (see openai's whisper utils.py)
* [ ] Improve diarization (word level). *Harder than first thought...*
@ -205,10 +224,13 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
Valuable VAD & Diarization Models from [pyannote.audio](https://github.com/pyannote/pyannote-audio)
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
<h2 align="left" id="cite">Citation</h2> <h2 align="left" id="cite">Citation</h2>
If you use this in your research, please cite the paper: If you use this in your research, please cite the paper:
@ -220,37 +242,4 @@ If you use this in your research, please cite the paper:
journal={arXiv preprint, arXiv:2303.00747},
year={2023}
}
```
as well the following works, used in each stage of the pipeline:
```bibtex
@article{radford2022robust,
title={Robust speech recognition via large-scale weak supervision},
author={Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya},
journal={arXiv preprint arXiv:2212.04356},
year={2022}
}
```
```bibtex
@article{baevski2020wav2vec,
title={wav2vec 2.0: A framework for self-supervised learning of speech representations},
author={Baevski, Alexei and Zhou, Yuhao and Mohamed, Abdelrahman and Auli, Michael},
journal={Advances in neural information processing systems},
volume={33},
pages={12449--12460},
year={2020}
}
```
```bibtex
@inproceedings{bredin2020pyannote,
title={Pyannote. audio: neural building blocks for speaker diarization},
author={Bredin, Herv{\'e} and Yin, Ruiqing and Coria, Juan Manuel and Gelly, Gregory and Korshunov, Pavel and Lavechin, Marvin and Fustes, Diego and Titeux, Hadrien and Bouaziz, Wassim and Gill, Marie-Philippe},
booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={7124--7128},
year={2020},
organization={IEEE}
}
```

View File

@ -1,10 +1,8 @@
numpy torch==1.11.0
pandas torchaudio==0.11.0
torch >=1.9
torchaudio >=0.10,<1.0
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
openai-whisper==20230314 faster-whisper
transformers
ffmpeg-python==0.2.0
pandas
setuptools==65.6.3

View File

@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
version="2.0", version="3.0.0",
description="Time-Accurate Automatic Speech Recognition using Whisper.", description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md", readme="README.md",
python_requires=">=3.8", python_requires=">=3.8",

View File

@ -1,3 +1,3 @@
from .transcribe import transcribe, transcribe_with_vad from .transcribe import load_model
from .alignment import load_align_model, align
from .vad import load_vad_model from .audio import load_audio

View File

@ -2,16 +2,17 @@
Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass import torchaudio
from whisper.audio import SAMPLE_RATE, load_audio from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .utils import interpolate_nans
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -391,34 +392,42 @@ def align(
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"word-segments": wseg,
"char-segments": cseg
}
)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
word_list = []
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if not curr_text.startswith(" "):
curr_text = " " + curr_text
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
word_start = wseg.iloc[wdx]['start']
word_end = wseg.iloc[wdx]['end']
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"], "start": word_start,
"end": wseg.iloc[wdx]["end"], "end": word_end
}
)
curr_text = ""
curr_text += " " + get_raw_text(wrow) word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
}
)
curr_text = " "
curr_text += get_raw_text(wrow) + " "
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
@ -427,6 +436,25 @@ def align(
}
)
word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
}
)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"words": word_list,
# "word-segments": wseg,
# "char-segments": cseg
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word} return {"segments": aligned_segments, "word_segments": aligned_segments_word}

View File

@ -1,433 +1,406 @@
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union from typing import List, Union
import ctranslate2
import faster_whisper
import numpy as np
import torch
import tqdm from transformers import Pipeline
import ffmpeg from transformers.pipelines.pt_utils import PipelineIterator
from whisper.audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
CHUNK_LENGTH,
log_mel_spectrogram,
pad_or_trim,
load_audio
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
exact_div,
format_timestamp,
make_safe,
)
if TYPE_CHECKING: from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisper.model import Whisper from .vad import load_vad_model, merge_chunks
from .vad import merge_chunks
def transcribe( def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
model: "Whisper", vad_options=None, model=None):
audio: Union[str, np.ndarray, torch.Tensor] = None, '''Load a Whisper model for inference.
mel: np.ndarray = None, Args:
verbose: Optional[bool] = None, whisper_arch: str - The name of the Whisper model to load.
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), device: str - The device to load the model on.
compression_ratio_threshold: Optional[float] = 2.4, compute_type: str - The compute type to use for the model.
logprob_threshold: Optional[float] = -1.0, options: dict - A dictionary of options to use for the model.
no_speech_threshold: Optional[float] = 0.6, language: str - The language of the model. (use English for now)
condition_on_previous_text: bool = True, Returns:
initial_prompt: Optional[str] = None, A Whisper pipeline.
word_timestamps: bool = False, '''
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper.
We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)
Parameters if whisper_arch.endswith(".en"):
---------- language = "en"
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor] model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
The path to the audio file to open, or the audio waveform if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
mel: np.ndarray
Mel spectrogram of audio segment.
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
if mel is None:
if audio is None:
raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
return decode_result
seek = 0
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else: else:
initial_prompt_tokens = [] print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
def new_segment( default_asr_options = {
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult "beam_size": 5,
"best_of": 5,
"patience": 1,
"length_penalty": 1,
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
"compression_ratio_threshold": 2.4,
"log_prob_threshold": -1.0,
"no_speech_threshold": 0.6,
"condition_on_previous_text": False,
"initial_prompt": None,
"prefix": None,
"suppress_blank": True,
"suppress_tokens": [-1],
"without_timestamps": True,
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、"
}
if asr_options is not None:
default_asr_options.update(asr_options)
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_vad_options = {
"vad_onset": 0.500,
"vad_offset": 0.363
}
if vad_options is not None:
default_vad_options.update(vad_options)
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
return FasterWhisperPipeline(model, vad_model, default_asr_options, tokenizer)
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix,
)
encoder_output = self.encode(features)
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
# length_penalty=options.length_penalty,
# max_length=self.max_length,
# return_scores=True,
# return_no_speech_prob=True,
# suppress_blank=options.suppress_blank,
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
)
tokens_batch = [x.sequences_ids[0] for x in result]
def decode_batch(tokens: List[List[int]]) -> str:
res = []
for tk in tokens:
res.append([token for token in tk if token < tokenizer.eot])
# text_tokens = [token for token in tokens if token < self.eot]
return tokenizer.tokenizer.decode_batch(res)
text = decode_batch(tokens_batch)
return text
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
def __init__(
self,
model,
vad,
options,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
framework = "pt",
**kwargs
): ):
tokens = tokens.tolist() self.model = model
text_tokens = [token for token in tokens if token < tokenizer.eot] self.tokenizer = tokenizer
return { self.options = options
"seek": seek, self._batch_size = kwargs.pop("batch_size", None)
"start": start, self._num_workers = 1
"end": end, self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
"text": tokenizer.decode(text_tokens), self.call_count = 0
"tokens": tokens, self.framework = framework
"temperature": result.temperature, if self.framework == "pt":
"avg_logprob": result.avg_logprob, if isinstance(device, torch.device):
"compression_ratio": result.compression_ratio, self.device = device
"no_speech_prob": result.no_speech_prob, elif isinstance(device, str):
} self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
while seek < content_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[:, seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# clamp end-time to at least be 1 frame after start-time
end_timestamp_pos = max(end_timestamp_pos, start_timestamp_pos + time_precision)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else: else:
duration = segment_duration self.device = torch.device(f"cuda:{device}")
timestamps = tokens[timestamp_tokens.nonzero().flatten()] else:
if ( self.device = device
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin super(Pipeline, self).__init__()
): self.vad_model = vad
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append( def _sanitize_parameters(self, **kwargs):
new_segment( preprocess_kwargs = {}
start=time_offset, if "tokenizer" in kwargs:
end=time_offset + duration, preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
tokens=tokens, return preprocess_kwargs, {}, {}
result=result,
)
)
seek += segment_size
if not condition_on_previous_text or result.temperature > 0.5: def preprocess(self, audio):
# do not feed the prompt tokens if a high temperature was used audio = audio['inputs']
prompt_reset_since = len(all_tokens) features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features}
if word_timestamps: def _forward(self, model_inputs):
add_word_timestamps( outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
segments=current_segments, return {'text': outputs}
model=model,
tokenizer=tokenizer, def postprocess(self, model_outputs):
mel=mel_segment, return model_outputs
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
)
if seek_shift > 0:
seek = previous_seek + seek_shift
if verbose: def get_iterator(
for segment in current_segments: self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
start, end, text = segment["start"], segment["end"], segment["text"] ):
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}" dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
print(make_safe(line)) if "TOKENIZERS_PARALLELISM" not in os.environ:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO hack by collating feature_extractor and image_processor
# if a segment is instantaneous or does not contain text, clear it def stack(items):
for i, segment in enumerate(current_segments): return {'inputs': torch.stack([x['inputs'] for x in items])}
if segment["start"] == segment["end"] or segment["text"].strip() == "": dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
segment["text"] = "" model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
segment["tokens"] = [] final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
segment["words"] = [] return final_iterator
all_segments.extend( def transcribe(
[ self, audio: Union[str, np.ndarray], batch_size=None
{"id": i, **segment} ):
for i, segment in enumerate( if isinstance(audio, str):
current_segments, start=len(all_segments) audio = load_audio(audio)
)
] def data(audio, segments):
) for seg in segments:
all_tokens.extend( f1 = int(seg['start'] * SAMPLE_RATE)
[token for segment in current_segments for token in segment["tokens"]] f2 = int(seg['end'] * SAMPLE_RATE)
) # print(f2-f1)
yield {'inputs': audio[f1:f2]}
# update progress bar vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
pbar.update(min(content_frames, seek) - previous_seek) vad_segments = merge_chunks(vad_segments, 30)
del_tokenizer = False
if self.tokenizer is None:
language = self.detect_language(audio)
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
del_tokenizer = True
else:
language = self.tokenizer.language_code
return dict( segments = []
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), batch_size = batch_size or self._batch_size
segments=all_segments, for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size)):
language=language, text = out['text']
) if batch_size in [0, 1, None]:
text = text[0]
segments.append(
def transcribe_with_vad( {
model: "Whisper", "text": out['text'],
audio: str, "start": round(vad_segments[idx]['start'], 3),
vad_pipeline, "end": round(vad_segments[idx]['end'], 3)
mel = None,
verbose: Optional[bool] = None,
**kwargs
):
"""
Transcribe per VAD segment
"""
vad_segments = vad_pipeline(audio)
# if not torch.is_tensor(audio):
# if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
prev = 0
output = {"segments": []}
# merge segments to approx 30s inputs to make whisper most appropriate
vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
if len(vad_segments) == 0:
return output
print(">>Performing transcription...")
for sdx, seg_t in enumerate(vad_segments):
if verbose:
print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
audio = audio[local_f_start:] # seek forward
seg_audio = audio[:local_f_end-local_f_start] # seek forward
prev = seg_f_start
local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
# need to pad
result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
seg_t["text"] = result["text"]
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
} }
) )
if del_tokenizer:
self.tokenizer = None
output["language"] = output["segments"][0]["language"] return {"segments": segments, "language": language}
return output
def detect_language(self, audio: np.ndarray):
segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0)
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
if __name__ == "__main__":
main_type = "simple"
import time
import jiwer
from tqdm import tqdm
from whisper.normalizers import EnglishTextNormalizer
from benchmark.tedlium import parse_tedlium_annos
if main_type == "complex":
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions
from faster_whisper.vad import (SpeechTimestampsMap,
get_speech_timestamps)
from whisperx.vad import load_vad_model, merge_chunks
from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
faster_t_options = TranscriptionOptions(
beam_size=5,
best_of=5,
patience=1,
length_penalty=1,
temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=False,
initial_prompt=None,
prefix=None,
suppress_blank=True,
suppress_tokens=[-1],
without_timestamps=True,
max_initial_timestamp=0.0,
word_timestamps=False,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、"
)
whisper_arch = "large-v2"
device = "cuda"
batch_size = 16
model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
fn = "DanielKahneman_2010.wav"
wav_dir = f"/tmp/test/wav/"
vad_model = load_vad_model("cuda", 0.6, 0.3)
audio = load_audio(os.path.join(wav_dir, fn))
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
vad_method="pyannote"
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
if vad_method == "pyannote":
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
elif vad_method == "silero":
vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
new_segs = []
curr_start = vad_segments[0]['start']
curr_end = vad_segments[0]['end']
for seg in vad_segments[1:]:
if seg['end'] - curr_start > 30:
new_segs.append({"start": curr_start, "end": curr_end})
curr_start = seg['start']
curr_end = seg['end']
else:
curr_end = seg['end']
new_segs.append({"start": curr_start, "end": curr_end})
vad_segments = new_segs
text = []
# for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
text.append(out['text'])
t2 = time.time()
if batch_size == 1:
text = [x[0] for x in text]
text = " ".join(text)
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))
elif main_type == "simple":
model = load_model(
"large-v2",
device="cuda",
language="en",
)
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
# fn = "DanielKahneman_2010.wav"
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
out = model.transcribe(audio_fp, batch_size=8)["segments"]
t2 = time.time()
text = " ".join([x['text'] for x in out])
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))

Binary file not shown.

147
whisperx/audio.py Normal file
View File

@ -0,0 +1,147 @@
import os
from functools import lru_cache
from typing import Optional, Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolution has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of an audio waveform.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
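# Hedged usage example (not part of the module): compute features for a local
# file with the helpers above; "audio.mp3" is a placeholder path.
if __name__ == "__main__":
    wav = load_audio("audio.mp3")                # float32 mono at 16 kHz
    mel = log_mel_spectrogram(pad_or_trim(wav))  # shape (80, 3000) for a 30-second window
    print(wav.shape, mel.shape)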

View File

@ -1,37 +1,30 @@
import argparse import argparse
import os
import gc import gc
import os
import warnings import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import tempfile
import ffmpeg
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.audio import SAMPLE_RATE
from whisper.utils import (
optional_float,
optional_int,
str2bool,
)
from .alignment import load_align_model, align from .alignment import align, load_align_model
from .asr import transcribe, transcribe_with_vad from .asr import load_model
from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import get_writer from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
from .vad import load_vad_model optional_int, str2bool)
def cli(): def cli():
from whisper import available_models
# fmt: off # fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@ -39,13 +32,10 @@ def cli():
# alignment params
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment") parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
# vad params # vad params
parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
@ -69,9 +59,14 @@ def cli():
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
# parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
# parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
# parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
@ -81,7 +76,7 @@ def cli():
args = parser.parse_args().__dict__
model_name: str = args.pop("model") model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir") batch_size: int = args.pop("batch_size")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
@ -93,13 +88,10 @@ def cli():
os.makedirs(tmp_dir, exist_ok=True)
align_model: str = args.pop("align_model")
align_extend: float = args.pop("align_extend")
align_from_prev: bool = args.pop("align_from_prev")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
@ -107,18 +99,7 @@ def cli():
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
if vad_filter: # TODO: check model loading works.
from pyannote.audio import Pipeline
from pyannote.audio import Model, Pipeline
vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
else:
vad_model = None
# if model_flush:
# print(">>Model flushing activated... Only loading model after ASR stage")
# del align_model
# align_model = ""
if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None: if args["language"] is not None:
@ -136,39 +117,43 @@ def cli():
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from whisper import load_model asr_options = {
"beam_size": args.pop("beam_size"),
"patience": args.pop("patience"),
"length_penalty": args.pop("length_penalty"),
"temperatures": temperature,
"compression_ratio_threshold": args.pop("compression_ratio_threshold"),
"log_prob_threshold": args.pop("logprob_threshold"),
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
}
writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if no_align:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
# Part 1: VAD & ASR Loop
results = []
tmp_results = []
model = load_model(model_name, device=device, download_root=model_dir) # model = load_model(model_name, device=device, download_root=model_dir)
for audio_path in args.pop("audio"): model = load_model(model_name, device=device, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
input_audio_path = audio_path
tfile = None
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
# >> VAD & ASR
if vad_model is not None: print(">>Performing transcription...")
if not audio_path.endswith(".wav"): result = model.transcribe(audio, batch_size=batch_size)
print(">>VAD requires .wav format, converting to wav as a tempfile...") results.append((result, audio_path))
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
if tmp_dir is not None:
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
else:
input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
print(">>Performing VAD...")
result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
else:
print(">>Performing transcription...")
result = transcribe(model, input_audio_path, temperature=temperature, **args)
results.append((result, input_audio_path))
# Unload Whisper and VAD
del model
del vad_model
gc.collect()
torch.cuda.empty_cache()
@@ -178,17 +163,22 @@ def cli():
results = []
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for result, input_audio_path in tmp_results:
for result, audio_path in tmp_results:
# >> Align
if len(tmp_results) > 1:
input_audio = audio_path
else:
# lazily load audio from part 1
input_audio = audio
if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]:
# load new language
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
results.append((result, audio_path))
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
results.append((result, input_audio_path))
# Unload align model
del align_model
@@ -210,11 +200,7 @@ def cli():
# >> Write
for result, audio_path in results:
writer(result, audio_path, writer_args)
writer(result, audio_path)
# cleanup
if input_audio_path != audio_path:
os.remove(input_audio_path)
if __name__ == "__main__":
cli()
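Taken together, the `cli()` changes above swap the old `transcribe`/`transcribe_with_vad` path (with its temporary `.wav` conversion) for a batched faster-whisper pipeline followed by forced alignment. Below is a minimal sketch of the same flow outside the CLI; the model size, file name and batch size are placeholders, and the package-level `whisperx.*` exports are assumed from this diff rather than from a documented API:

```python
import whisperx

device = "cuda"
audio_file = "audio.mp3"   # placeholder input path

# Part 1 of cli(): batched Whisper ASR (faster-whisper backend) with built-in VAD options
model = whisperx.load_model("large-v2", device=device)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=16)

# Part 2 of cli(): phoneme-level forced alignment of the returned segments
align_model, metadata = whisperx.load_align_model(result["language"], device)
result = whisperx.align(result["segments"], align_model, metadata, audio, device)

print(result["segments"])  # segments now carry refined word-level timestamps
```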

View File

@@ -1,280 +1,301 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, TextIO, Iterator, Tuple
from typing import Callable, Optional, TextIO
import pandas as pd
import numpy as np
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
else:
return x.ffill().bfill()
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options)
def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
print(segment['start'], file=file, end="\t")
print(segment['end'], file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def write_ass(transcript: Iterator[dict],
file: TextIO,
resolution: str = "word",
color: str = None, underline=True,
prefmt: str = None, suffmt: str = None,
font: str = None, font_size: int = 24,
strip=True, **kwargs):
"""
Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
Generate Advanced SubStation Alpha (ass) file from results to
display both phrase-level & word-level timestamp simultaneously by:
-using segment-level timestamps display phrases as usual
-using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
Note: ass file is used in the same way as srt, vtt, etc.
Parameters
----------
transcript: dict
results from modified model
file: TextIO
file object to write to
resolution: str
"word" or "char", timestamp resolution to highlight.
color: str
color code for a word at its corresponding timestamp
<bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
underline: bool
whether to underline a word at its corresponding timestamp
prefmt: str
used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
suffmt: str
used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline')
appears as such in the .ass file:
Hi, {<prefmt>}how{<suffmt>} are you?
reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
font: str
word font (default: Arial)
font_size: int
word font size (default: 48)
kwargs:
used for format styles:
'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
""" def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None
fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff', def iterate_subtitles():
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0', line_len = 0
'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100', line_count = 1
'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0', # the next subtitle to yield (a list of word timings with whitespace)
'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'} subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
if len(subtitle) > 0:
yield subtitle
for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
kwargs[k] = f'&H{kwargs[k]}'
if "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
if font:
fmt_style_dict.update(Fontname=font)
if font_size:
fmt_style_dict.update(Fontsize=font_size)
fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
if prefmt or suffmt:
if suffmt:
assert prefmt, 'prefmt must be used along with suffmt'
else:
suffmt = r'\r'
else:
if not color:
color = 'HFF00'
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
underline_code = r'\u1' if underline else ''
prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
suffmt = r'{\r}'
def secs_to_hhmmss(secs: Tuple[float, int]):
mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60)
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
if idx_0 == -1:
text = chars
else:
text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
f"Default,,0,0,0,,{text.strip() if strip else text}"
if resolution == "word":
resolution_key = "word-segments"
elif resolution == "char":
resolution_key = "char-segments"
else:
raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
ass_arr = []
for segment in transcript:
# if "12" in segment['text']:
# import pdb; pdb.set_trace()
if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start']
if "speaker" in segment:
speaker_str = f"[{segment['speaker']}]: "
else:
speaker_str = ""
for cdx, crow in res_segs.iterrows():
if not np.isnan(crow['start']):
if resolution == "char":
idx_0 = cdx
idx_1 = cdx + 1
elif resolution == "word":
idx_0 = int(crow["segment-text-start"])
idx_1 = int(crow["segment-text-end"])
# fill gap
if crow['start'] > prev:
filler_ts = {
"chars": speaker_str + segment['text'],
"start": prev,
"end": crow['start'],
"idx_0": -1,
"idx_1": -1
}
ass_arr.append(filler_ts)
# highlight current word
f_word_ts = {
"chars": speaker_str + segment['text'],
"start": crow['start'],
"end": crow['end'],
"idx_0": idx_0 + len(speaker_str),
"idx_1": idx_1 + len(speaker_str)
}
ass_arr.append(f_word_ts)
prev = crow['end']
ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
file.write(ass_str)
from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
class WriteASS(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resolution="word")
class WriteASSchar(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
write_ass(result["segments"], file, resolution="char")
class WritePickle(ResultWriter):
extension: str = "ass"
def write_result(self, result: dict, file: TextIO):
pd.DataFrame(result["segments"]).to_pickle(file)
class WriteSRTWord(ResultWriter):
extension: str = "word.srt"
always_include_hours: bool = True
decimal_marker: str = ","
def iterate_result(self, result: dict):
for segment in result["word_segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text
yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end
if last != segment_end:
yield last, segment_end, segment_text
else:
yield segment_start, segment_end, segment_text
def write_result(self, result: dict, file: TextIO):
if "word_segments" not in result:
return
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
@@ -282,36 +303,81 @@ class WriteSRTWord(ResultWriter):
decimal_marker=self.decimal_marker,
)
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
"ass": WriteASS,
"srt-word": WriteSRTWord,
# "ass-char": WriteASSchar,
# "pickle": WritePickle,
# "json": WriteJSON,
}
writers_other = {
"pkl": WritePickle,
"ass-char": WriteASSchar
} }
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO): def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers: for writer in all_writers:
writer(result, file) writer(result, file, options)
return write_all return write_all
if output_format in writers: return writers[output_format](output_dir)
return writers[output_format](output_dir)
elif output_format in writers_other:
return writers_other[output_format](output_dir)
else:
raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
else:
return x.ffill().bfill()

View File

@@ -1,22 +1,23 @@
import hashlib
import os
import urllib
import pandas as pd
from typing import Callable, Optional, Text, Union
import numpy as np
import pandas as pd
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from tqdm import tqdm
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok = True)
if model_fp is None: