24 Commits

Author SHA1 Message Date
d8a2b4ffc9 Merge pull request #246 from m-bain/v3
V3
2023-05-13 12:18:09 +01:00
9ffb7e7a23 Merge branch 'v3' of https://github.com/m-bain/whisperX into v3
Conflicts:
	setup.py
2023-05-13 12:16:33 +01:00
fd8f1003cf add translate, fix word_timestamp error 2023-05-13 12:14:06 +01:00
46b416296f Merge pull request #123 from koldbrandt/danish_alignment
Danish alignment model
2023-05-09 23:10:24 +01:00
7642390d0a Merge branch 'main' into danish_alignment 2023-05-09 23:10:13 +01:00
8b05ad4dae Merge pull request #235 from sorgfresser/main
Add custom typing for results
2023-05-09 23:05:02 +01:00
5421f1d7ca remove v3 tag on pip install 2023-05-09 13:42:50 +01:00
91e959ec4f Merge branch 'm-bain:main' into main 2023-05-08 20:46:25 +02:00
eabf35dff0 Custom result types 2023-05-08 20:45:34 +02:00
4919ad21fc Merge pull request #233 from sorgfresser/main
Fix tuple unpacking
2023-05-08 19:05:47 +01:00
b50aafb17b Fix tuple unpacking 2023-05-08 20:03:42 +02:00
2efa136114 update python usage example 2023-05-08 17:20:38 +01:00
0b839f3f01 Update README.md 2023-05-07 20:36:08 +01:00
1caddfb564 Merge pull request #225 from m-bain/v3
V3
2023-05-07 20:31:16 +01:00
7ad554c64f Merge branch 'main' into v3 2023-05-07 20:30:57 +01:00
4603f010a5 update readme, setup, add option to return char_timestamps 2023-05-07 20:28:33 +01:00
24008aa1ed fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based) 2023-05-07 15:32:58 +01:00
07361ba1d7 add device to dia pipeline @sorgfresser 2023-05-05 11:53:51 +01:00
b666523004 add v3 pre-release comment, and v4 progress update 2023-05-02 15:10:40 +01:00
69e038cbc4 Merge pull request #209 from SohaibAnwaar/feat-dockerfile
feat: adding the docker file
2023-05-02 14:55:30 +01:00
a693a779fa feat: adding the docker file 2023-05-02 13:28:20 +05:00
5b85c5433f Update setup.py 2023-04-28 16:47:04 +01:00
d31f6e0b8a Merge branch 'm-bain:main' into danish_alignment 2023-03-06 10:52:47 +01:00
c8404d9805 added a danish alignment model 2023-03-04 13:20:40 +01:00
11 changed files with 458 additions and 652 deletions

3
.gitignore vendored
View File

@ -1,2 +1,3 @@
whisperx.egg-info/
**/__pycache__/
**/__pycache__/
.ipynb_checkpoints

166
README.md
View File

@ -13,36 +13,36 @@
<img src="https://img.shields.io/github/license/m-bain/whisperX.svg"
alt="GitHub license">
</a>
<a href="https://arxiv.org/abs/2303.00747">
<img src="http://img.shields.io/badge/Arxiv-2303.00747-B31B1B.svg"
alt="ArXiv paper">
</a>
<a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2Fm-bain%2FwhisperX">
<img src="https://img.shields.io/twitter/url/https/github.com/m-bain/whisperX.svg?style=social" alt="Twitter">
</a>
</p>
<p align="center">
<a href="#what-is-it">What is it</a>
<a href="#setup">Setup</a>
<a href="#example">Usage</a>
<a href="#other-languages">Multilingual</a>
<a href="#contribute">Contribute</a>
<a href="EXAMPLES.md">More examples</a>
<a href="https://arxiv.org/abs/2303.00747">Paper</a>
</p>
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
</p>
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
<h2 align="left", id="what-is-it">What is it 🔎</h2>
This repository refines the timestamps of openAI's Whisper model via forced aligment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocesssing, multilingual use-case.
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds.
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
- 👯 Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
- 🗣 VAD preprocessing, reduces hallucination & batching with no WER degradation
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
@ -50,15 +50,14 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
<h2 align="left", id="highlights">New🚨</h2>
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
<h2 align="left" id="setup">Setup ⚙️</h2>
Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
@ -75,29 +74,27 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
`pip3 install torch torchvision torchaudio`
`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia`
See other methods [here.](https://pytorch.org/get-started/locally/)
See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
### 3. Install this repo
`pip install git+https://github.com/m-bain/whisperx.git@v3`
`pip install git+https://github.com/m-bain/whisperx.git`
If already installed, update package to most recent commit
`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
If wishing to modify this package, clone and install in editable mode:
```
$ git clone https://github.com/m-bain/whisperX.git@v3
$ git clone https://github.com/m-bain/whisperX.git
$ cd whisperX
$ git checkout v3
$ pip install -e .
```
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
### Speaker Diarization
To **enable Speaker. Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
@ -106,15 +103,11 @@ To **enable Speaker. Diarization**, include your Hugging Face access token that
### English
Run whisper on example segment (using default params)
Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
whisperx examples/sample01.wav
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
@ -123,6 +116,16 @@ Compare this to original whisper out the box, where many transcriptions are out
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
### Other languages
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
@ -132,7 +135,7 @@ Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`
#### E.g. German
whisperx --model large --language de examples/sample_de_01.wav
whisperx --model large-v2 --language de examples/sample_de_01.wav
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
@ -143,79 +146,108 @@ See more examples in other languages [here](EXAMPLES.md).
```python
import whisperx
import gc
device = "cuda"
audio_file = "audio.mp3"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
# transcribe with original whisper
model = whisperx.load_model("large-v2", device)
# 1. Transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=8)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment
# load alignment model and metadata
# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model
# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
# align whisper output
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)
print(result["segments"]) # after alignment
print(result_aligned["segments"]) # after alignment
print(result_aligned["word_segments"]) # after alignment
# delete model if low on GPU resources
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
# 3. Assign speaker labels
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(audio_file)
# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)
print(result["segments"]) # segments are now assigned speaker IDs
```
<h2 align="left" id="whisper-mod">Whisper Modifications</h2>
<h2 align="left" id="whisper-mod">Technical Details 👷‍♂️</h2>
In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
1. reduce batch size, e.g. `--batch_size 4`
2. use a smaller ASR model `--model base`
3. Use lighter compute type `--compute_type int8`
Transcription differences from openai's whisper:
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In Wthe WhisperX paper we show this reduces WER, and enables accurate batched inference
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
<h2 align="left" id="limitations">Limitations ⚠️</h2>
- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
- If setting `--vad_filter False`, then whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
- Overlapping speech is not handled particularly well by whisper nor whisperx
- Diariazation is far from perfect.
- Diarization is far from perfect (working on this with custom model v4 -- see contact me).
- Language specific wav2vec2 model is needed
<h2 align="left" id="contribute">Contribute 🧑‍🏫</h2>
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a merge request and some examples showing its success.
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
The next major upgrade we are working on is whisper with speaker diarization, so if you have any experience on this please share.
Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
<h2 align="left" id="coming-soon">Coming Soon 🗓</h2>
<h2 align="left" id="coming-soon">TODO 🗓</h2>
* [x] Multilingual init
* [x] Subtitle .ass output
* [x] Automatic align model selection based on language detection
* [x] Python usage
* [x] Character level timestamps
* [x] Incorporating speaker diarization
* [x] Model flush, for low gpu mem resources
* [x] Faster-whisper backend
* [x] Add max-line etc. see (openai's whisper utils.py)
* [x] Sentence-level segments (nltk toolbox)
* [x] Improve alignment logic
* [ ] update examples with diarization and word highlighting
* [ ] Subtitle .ass output <- bring this back (removed in v3)
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
* [ ] Allow silero-vad as alternative VAD option
* [ ] Add max-line etc. see (openai's whisper utils.py)
* [ ] Improve diarization (word level). *Harder than first thought...*
<h2 align="left" id="contact">Contact/Support 📇</h2>
Contact maxhbain@gmail.com for queries and licensing / early access to a model API with batched inference (transcribe 1hr audio in under 1min).
Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with with siginificantly improved diarization. To support v4 and get early access, get in touch.
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
@ -224,14 +256,18 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
Valuable VAD & Diarization Models from (pyannote.audio)[https://github.com/pyannote/pyannote-audio]
Great backend from (faster-whisper)[https://github.com/guillaumekln/faster-whisper] and (CTranslate2)[https://github.com/OpenNMT/CTranslate2]
Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio]
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
<h2 align="left" id="cite">Citation</h2>
If you use this in your research, please cite the paper:

View File

@ -4,4 +4,5 @@ faster-whisper
transformers
ffmpeg-python==0.2.0
pandas
setuptools==65.6.3
setuptools==65.6.3
nltk

View File

@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.0.2",
version="3.1.1",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",

View File

@ -1,3 +1,4 @@
from .transcribe import load_model
from .alignment import load_align_model, align
from .audio import load_audio
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline

View File

@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterator, Union
from typing import Iterator, Union, List
import numpy as np
import pandas as pd
@ -13,6 +13,8 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from .audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans
from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
import nltk
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
@ -38,6 +40,7 @@ DEFAULT_ALIGN_MODELS_HF = {
"fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
"el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
"tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
"da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
}
@ -79,391 +82,236 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
def align(
transcript: Iterator[dict],
transcript: Iterator[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
extend_duration: float = 0.0,
start_from_previous: bool = True,
interpolate_method: str = "nearest",
):
return_char_alignments: bool = False,
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
"""
Force align phoneme recognition predictions to known transcription
Parameters
----------
transcript: Iterator[dict]
The Whisper model instance
model: torch.nn.Module
Alignment model (wav2vec2)
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
device: str
cuda device
diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
diarization segments with speaker labels.
extend_duration: float
Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
If the gzip compression ratio is above this value, treat as failed
interpolate_method: str ["nearest", "linear", "ignore"]
Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
MAX_DURATION = audio.shape[1] / SAMPLE_RATE
model_dictionary = align_model_metadata["dictionary"]
model_lang = align_model_metadata["language"]
model_type = align_model_metadata["type"]
aligned_segments = []
prev_t2 = 0
char_segments_arr = {
"segment-idx": [],
"subsegment-idx": [],
"word-idx": [],
"char": [],
"start": [],
"end": [],
"score": [],
}
# 1. Preprocess to keep only characters in dictionary
for sdx, segment in enumerate(transcript):
while True:
segment_align_success = False
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
# strip spaces at beginning / end, but keep track of the amount.
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
transcription = segment["text"]
# split into words
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = text.split(" ")
else:
per_word = text
# TODO: convert number tokenizer / symbols to phonetic words for alignment.
# e.g. "$300" -> "three hundred dollars"
# currently "$300" is ignored since no characters present in the phonetic dictionary
# split into words
clean_char, clean_cdx = [], []
for cdx, char in enumerate(text):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
per_word = transcription.split(" ")
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(text) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
segment["clean_char"] = clean_char
segment["clean_cdx"] = clean_cdx
segment["clean_wdx"] = clean_wdx
segment["sentence_spans"] = sentence_spans
aligned_segments: List[SingleAlignedSegment] = []
# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):
t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
aligned_seg: SingleAlignedSegment = {
"start": t1,
"end": t2,
"text": text,
"words": [],
}
if return_char_alignments:
aligned_seg["chars"] = []
# check we can align
if len(segment["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg)
continue
if t1 >= MAX_DURATION or t2 - t1 < 0.02:
print("Failed to align segment: original start time longer than audio duration, skipping...")
aligned_segments.append(aligned_seg)
continue
text_clean = "".join(segment["clean_char"])
tokens = [model_dictionary[c] for c in text_clean]
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
# TODO: Probably can get some speedup gain with batched inference here
waveform_segment = audio[:, f1:f2]
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
per_word = transcription
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
clean_char, clean_cdx = [], []
for cdx, char in enumerate(transcription):
char_ = char.lower()
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")
# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
elif cdx > len(transcription) - num_trailing - 1:
pass
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
emission = emissions[0].cpu().detach()
clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd]):
clean_wdx.append(wdx)
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
# if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break
transcription_cleaned = "".join(clean_char)
tokens = [model_dictionary[c] for c in transcription_cleaned]
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
# we only pad if not using VAD filtering
if "seg_text" not in segment:
# pad according original timestamps
t1 = max(segment["start"] - extend_duration, 0)
t2 = min(segment["end"] + extend_duration, MAX_DURATION)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
aligned_segments.append(aligned_seg)
continue
# use prev_t2 as current t1 if it"s later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
char_segments = merge_repeats(path, text_clean)
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
break
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
break
duration = t2 -t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
# assign timestamps to aligned characters
char_segments_arr = []
word_idx = 0
for cdx, char in enumerate(text):
start, end, score = None, None, None
if cdx in segment["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3)
waveform_segment = audio[:, f1:f2]
char_segments_arr.append(
{
"char": char,
"start": start,
"end": end,
"score": score,
"word-idx": word_idx,
}
)
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
blank_id = 0
for char, code in model_dictionary.items():
if char == '[pad]' or char == '<pad>':
blank_id = code
trellis = get_trellis(emission, tokens, blank_id)
path = backtrack(trellis, emission, tokens, blank_id)
if path is None:
print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break
char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments)
# increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
if model_lang in LANGUAGES_WITHOUT_SPACES:
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1
char_segments_arr = pd.DataFrame(char_segments_arr)
# sub-segments
if "seg-text" not in segment:
segment["seg-text"] = [transcription]
seg_lens = [0] + [len(x) for x in segment["seg-text"]]
seg_lens_cumsum = list(np.cumsum(seg_lens))
sub_seg_idx = 0
wdx = 0
duration = t2 - t1
ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
for cdx, char in enumerate(transcription + " "):
is_last = False
if cdx == len(transcription):
break
elif cdx+1 == len(transcription):
is_last = True
start, end, score = None, None, None
if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)]
start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3)
score = char_seg.score
char_segments_arr["char"].append(char)
char_segments_arr["start"].append(start)
char_segments_arr["end"].append(end)
char_segments_arr["score"].append(score)
char_segments_arr["word-idx"].append(wdx)
char_segments_arr["segment-idx"].append(sdx)
char_segments_arr["subsegment-idx"].append(sub_seg_idx)
# word-level info
if model_lang in LANGUAGES_WITHOUT_SPACES:
# character == word
wdx += 1
elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx += 1
if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
wdx = 0
sub_seg_idx += 1
prev_t2 = segment["end"]
segment_align_success = True
# end while True loop
break
# reset prev_t2 due to drifting issues
if not segment_align_success:
prev_t2 = 0
aligned_subsegments = []
# assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
char_segments_arr = pd.DataFrame(char_segments_arr)
not_space = char_segments_arr["char"] != " "
sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min()
sentence_end = curr_chars["end"].max()
sentence_words = []
per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
for word_idx in curr_chars["word-idx"].unique():
word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
word_text = "".join(word_chars["char"].tolist()).strip()
if len(word_text) == 0:
continue
word_segments_arr = {}
# dont use space character for alignment
word_chars = word_chars[word_chars["char"] != " "]
# start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().values
word_start = word_chars["start"].min()
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr)
# -1 indicates unalignable
word_segment = {"word": word_text}
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
segments_arr = pd.DataFrame(segments_arr)
segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
if not np.isnan(word_start):
word_segment["start"] = word_start
if not np.isnan(word_end):
word_segment["end"] = word_end
if not np.isnan(word_score):
word_segment["score"] = word_score
# interpolate missing words / sub-segments
if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
else:
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
aligned_segments = []
aligned_segments_word = []
word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
sub_start = int(srow["subsegment-idx-start"])
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
wseg["start"].fillna(srow["start"], inplace=True)
wseg["end"].fillna(srow["end"], inplace=True)
wseg["segment-text-start"].fillna(0, inplace=True)
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
# fixes bug for single segment in transcript
cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
if 'level_1' in cseg: del cseg['level_1']
if 'level_0' in cseg: del cseg['level_0']
cseg.reset_index(inplace=True)
def get_raw_text(word_row):
return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
word_list = []
wdx = 0
curr_text = get_raw_text(wseg.iloc[wdx])
if not curr_text.startswith(" "):
curr_text = " " + curr_text
sentence_words.append(word_segment)
if len(wseg) > 1:
for _, wrow in wseg.iloc[1:].iterrows():
if wrow['start'] != wseg.iloc[wdx]['start']:
word_start = wseg.iloc[wdx]['start']
word_end = wseg.iloc[wdx]['end']
aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"words": sentence_words,
})
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": word_start,
"end": word_end
}
)
if return_char_alignments:
curr_chars = curr_chars[["char", "start", "end", "score"]]
curr_chars.fillna(-1, inplace=True)
curr_chars = curr_chars.to_dict("records")
curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
aligned_subsegments[-1]["chars"] = curr_chars
word_list.append(
{
"word": curr_text.rstrip(),
"start": word_start,
"end": word_end,
}
)
aligned_subsegments = pd.DataFrame(aligned_subsegments)
aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
# concatenate sentences with same timestamps
agg_dict = {"text": " ".join, "words": "sum"}
if return_char_alignments:
agg_dict["chars"] = "sum"
aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
aligned_subsegments = aligned_subsegments.to_dict('records')
aligned_segments += aligned_subsegments
curr_text = " "
curr_text += get_raw_text(wrow) + " "
wdx += 1
aligned_segments_word.append(
{
"text": curr_text.strip(),
"start": wseg.iloc[wdx]["start"],
"end": wseg.iloc[wdx]["end"]
}
)
word_list.append(
{
"word": curr_text.rstrip(),
"start": wseg.iloc[wdx]['start'],
"end": wseg.iloc[wdx]['end'],
}
)
aligned_segments.append(
{
"start": srow["start"],
"end": srow["end"],
"text": text,
"words": word_list,
"word-segments": wseg,
"char-segments": cseg
}
)
return {"segments": aligned_segments, "word_segments": aligned_segments_word}
# create word_segments list
word_segments: List[SingleWordSegment] = []
for segment in aligned_segments:
word_segments += segment["words"]
return {"segments": aligned_segments, "word_segments": word_segments}
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html

View File

@ -11,10 +11,10 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None):
vad_options=None, model=None, task="transcribe"):
'''Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
@ -31,7 +31,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
model = WhisperModel(whisper_arch, device=device, compute_type=compute_type)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language=language)
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Currently only works in non-timestamp mode.
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
return self.model.encode(features, to_cpu=to_cpu)
class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
"""
# TODO:
# - add support for timestamp mode
# - add support for custom inference kwargs
def __init__(
self,
model,
@ -208,7 +215,7 @@ class FasterWhisperPipeline(Pipeline):
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
):
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@ -230,7 +237,7 @@ class FasterWhisperPipeline(Pipeline):
else:
language = self.tokenizer.language_code
segments = []
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
text = out['text']
@ -238,7 +245,7 @@ class FasterWhisperPipeline(Pipeline):
text = text[0]
segments.append(
{
"text": out['text'],
"text": text,
"start": round(vad_segments[idx]['start'], 3),
"end": round(vad_segments[idx]['end'], 3)
}
@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
if __name__ == "__main__":
main_type = "simple"
import time
import jiwer
from tqdm import tqdm
from whisper.normalizers import EnglishTextNormalizer
from benchmark.tedlium import parse_tedlium_annos
if main_type == "complex":
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions
from faster_whisper.vad import (SpeechTimestampsMap,
get_speech_timestamps)
from whisperx.vad import load_vad_model, merge_chunks
from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
faster_t_options = TranscriptionOptions(
beam_size=5,
best_of=5,
patience=1,
length_penalty=1,
temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=False,
initial_prompt=None,
prefix=None,
suppress_blank=True,
suppress_tokens=[-1],
without_timestamps=True,
max_initial_timestamp=0.0,
word_timestamps=False,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、"
)
whisper_arch = "large-v2"
device = "cuda"
batch_size = 16
model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",)
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en")
model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1)
fn = "DanielKahneman_2010.wav"
wav_dir = f"/tmp/test/wav/"
vad_model = load_vad_model("cuda", 0.6, 0.3)
audio = load_audio(os.path.join(wav_dir, fn))
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
def data(audio, segments):
for seg in segments:
f1 = int(seg['start'] * SAMPLE_RATE)
f2 = int(seg['end'] * SAMPLE_RATE)
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
vad_method="pyannote"
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
if vad_method == "pyannote":
vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(vad_segments, 30)
elif vad_method == "silero":
vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30)
vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments]
new_segs = []
curr_start = vad_segments[0]['start']
curr_end = vad_segments[0]['end']
for seg in vad_segments[1:]:
if seg['end'] - curr_start > 30:
new_segs.append({"start": curr_start, "end": curr_end})
curr_start = seg['start']
curr_end = seg['end']
else:
curr_end = seg['end']
new_segs.append({"start": curr_start, "end": curr_end})
vad_segments = new_segs
text = []
# for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)):
for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)):
text.append(out['text'])
t2 = time.time()
if batch_size == 1:
text = [x[0] for x in text]
text = " ".join(text)
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))
elif main_type == "simple":
model = load_model(
"large-v2",
device="cuda",
language="en",
)
wav_dir = f"/tmp/test/wav/"
wer_li = []
time_li = []
for fn in os.listdir(wav_dir):
if fn == "RobertGupta_2010U.wav":
continue
# fn = "DanielKahneman_2010.wav"
base_fn = fn.split('.')[0]
audio_fp = os.path.join(wav_dir, fn)
audio = load_audio(audio_fp)
t1 = time.time()
out = model.transcribe(audio_fp, batch_size=8)["segments"]
t2 = time.time()
text = " ".join([x['text'] for x in out])
normalizer = EnglishTextNormalizer()
text = normalizer(text)
gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/"))
wer_result = jiwer.wer(gt_corpus, text)
print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn))
wer_li.append(wer_result)
time_li.append(t2-t1)
print("# Avg Mean...")
print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li)))
print("Time: %.2f" % (sum(time_li)/len(time_li)))

View File

@ -20,59 +20,44 @@ class DiarizationPipeline:
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
diarize_df.rename(columns={2: "speaker"}, inplace=True)
return diarize_df
def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
for seg in result_segments:
wdf = seg['word-segments']
if len(wdf['start'].dropna()) == 0:
wdf['start'] = seg['start']
wdf['end'] = seg['end']
speakers = []
for wdx, wrow in wdf.iterrows():
if not np.isnan(wrow['start']):
diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) == 0:
speaker = None
else:
speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
else:
speaker = None
speakers.append(speaker)
seg['word-segments']['speaker'] = speakers
speaker_count = pd.Series(speakers).value_counts()
if len(speaker_count) == 0:
seg["speaker"]= "UNKNOWN"
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
# remove no hit, otherwise we look for closest (even negative intersection...)
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
seg["speaker"] = speaker_count.index[0]
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
seg["speaker"] = speaker
# assign speaker to words
if 'words' in seg:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
# remove no hit
if not fill_nearest:
dia_tmp = diarize_df[diarize_df['intersection'] > 0]
else:
dia_tmp = diarize_df
if len(dia_tmp) > 0:
# sum over speakers
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
return transcript_result
# create word level segments for .srt
word_seg = []
for seg in result_segments:
wseg = pd.DataFrame(seg["word-segments"])
for wdx, wrow in wseg.iterrows():
if wrow["start"] is not None:
speaker = wrow['speaker']
if speaker is None or speaker == np.nan:
speaker = "UNKNOWN"
word_seg.append(
{
"start": wrow["start"],
"end": wrow["end"],
"text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
}
)
# TODO: create segments but split words on new speaker
return result_segments, word_seg
class Segment:
def __init__(self, start, end, speaker=None):

View File

@ -35,6 +35,7 @@ def cli():
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
@ -42,8 +43,8 @@ def cli():
# diarization params
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@ -64,14 +65,11 @@ def cli():
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
# parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
# parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
# parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
# fmt: on
args = parser.parse_args().__dict__
@ -88,6 +86,12 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task : str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True
return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_onset: float = args.pop("vad_onset")
@ -97,7 +101,6 @@ def cli():
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
# TODO: check model loading works.
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
@ -141,7 +144,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset},)
model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)
@ -175,7 +178,8 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
results.append((result, audio_path))
# Unload align model
@ -190,21 +194,13 @@ def cli():
tmp_results = results
print(">>Performing diarization...")
results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
result = {"segments": results_segments, "word_segments": word_segments}
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
# Remove pandas dataframes from result so that
# we can serialize the result with json
for seg in result["segments"]:
seg.pop("word-segments", None)
seg.pop("char-segments", None)
writer(result, audio_path, writer_args)
if __name__ == "__main__":

58
whisperx/types.py Normal file
View File

@ -0,0 +1,58 @@
from typing import TypedDict, Optional
class SingleWordSegment(TypedDict):
"""
A single word of a speech.
"""
word: str
start: float
end: float
score: float
class SingleCharSegment(TypedDict):
"""
A single char of a speech.
"""
char: str
start: float
end: float
score: float
class SingleSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech.
"""
start: float
end: float
text: str
class SingleAlignedSegment(TypedDict):
"""
A single segment (up to multiple sentences) of a speech with word alignment.
"""
start: float
end: float
text: str
words: list[SingleWordSegment]
chars: Optional[list[SingleCharSegment]]
class TranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: list[SingleSegment]
language: str
class AlignedTranscriptionResult(TypedDict):
"""
A list of segments and word segments of a speech.
"""
segments: list[SingleAlignedSegment]
word_segments: list[SingleWordSegment]

View File

@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
times = []
last = result["segments"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
long_pause = not preserve_segments
if "start" in timing:
long_pause = long_pause and timing["start"] - last > 3.0
else:
long_pause = False
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
or seg_break
):
# subtitle break
yield subtitle
yield subtitle, times
subtitle = []
times = []
line_count = 1
elif line_len > 0:
# line break
@ -260,40 +266,53 @@ class SubtitlesWriter(ResultWriter):
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
times.append((segment["start"], segment["end"], segment.get("speaker")))
if "start" in timing:
last = timing["start"]
if len(subtitle) > 0:
yield subtitle
yield subtitle, times
if "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
for subtitle, _ in iterate_subtitles():
sstart, ssend, speaker = _[0]
subtitle_start = self.format_timestamp(sstart)
subtitle_end = self.format_timestamp(ssend)
subtitle_text = " ".join([word["word"] for word in subtitle])
has_timing = any(["start" in word for word in subtitle])
# add [$SPEAKER_ID]: to each subtitle if speaker is available
prefix = ""
if speaker is not None:
prefix = f"[{speaker}]: "
if highlight_words and has_timing:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
if "start" in this_word:
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
yield start, end, prefix + " ".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
yield subtitle_start, subtitle_end, prefix + subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
if "speaker" in segment:
segment_text = f"[{segment['speaker']}]: {segment_text}"
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):