Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)
Compare commits
21 Commits
| SHA1 |
| --- |
| 1caddfb564 |
| 7ad554c64f |
| 4603f010a5 |
| 24008aa1ed |
| 07361ba1d7 |
| 4e2ac4e4e9 |
| d2116b98ca |
| d8f0ef4a19 |
| 1b62c61c71 |
| 2d59eb9726 |
| cb53661070 |
| 2a6830492c |
| da3aabe181 |
| 067189248f |
| b666523004 |
| 69e038cbc4 |
| 9fb51412c0 |
| a693a779fa |
| 64ca208cc8 |
| 5becc99e56 |
| 5b85c5433f |
.gitignore (vendored): 1 change
@@ -1,2 +1,3 @@
 whisperx.egg-info/
 **/__pycache__/
+.ipynb_checkpoints
Dockerfile (new file): 19 lines
@@ -0,0 +1,19 @@
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \
    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-get update && \
    apt-get install -y wget && \
    wget -qO - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
    apt-get update && \
    apt-get install -y git && \
    apt-get install libsndfile1 -y && \
    apt-get clean

RUN pip install --upgrade pip
RUN pip install --upgrade setuptools
RUN pip install git+https://github.com/m-bain/whisperx.git
RUN pip install jupyter ipykernel
EXPOSE 8888
# Use external volume for data
ENV NVIDIA_VISIBLE_DEVICES 1
CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--NotebookApp.token=''", "--NotebookApp.password=''", "--allow-root"]
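For reference, an image built from the Dockerfile above could be used roughly as follows. The tag `whisperx-jupyter` and the host data path are placeholders rather than anything defined by the repository:

```bash
# build the image from the repository root (the tag is arbitrary)
docker build -t whisperx-jupyter .

# run with GPU access, mounting an external volume for data and exposing the notebook port
docker run --gpus all -p 8888:8888 -v /path/to/data:/data whisperx-jupyter
```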
README.md: 171 lines changed

@@ -13,36 +13,36 @@
 <img src="https://img.shields.io/github/license/m-bain/whisperX.svg"
 alt="GitHub license">
 </a>
+<a href="https://arxiv.org/abs/2303.00747">
+<img src="http://img.shields.io/badge/Arxiv-2303.00747-B31B1B.svg"
+alt="ArXiv paper">
+</a>
 <a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2Fm-bain%2FwhisperX">
 <img src="https://img.shields.io/twitter/url/https/github.com/m-bain/whisperX.svg?style=social" alt="Twitter">
 </a>
 </p>

-<p align="center">
-<a href="#what-is-it">What is it</a> •
-<a href="#setup">Setup</a> •
-<a href="#example">Usage</a> •
-<a href="#other-languages">Multilingual</a> •
-<a href="#contribute">Contribute</a> •
-<a href="EXAMPLES.md">More examples</a> •
-<a href="https://arxiv.org/abs/2303.00747">Paper</a>
-</p>

 <img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">

-<p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and speech-activity batching.
+<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->

-</p>

-<h2 align="left", id="what-is-it">What is it 🔎</h2>
+<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->

-This repository refines the timestamps of openAI's Whisper model via forced aligment with phoneme-based ASR models (e.g. wav2vec2.0) and VAD preprocesssing, multilingual use-case.

-**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds.
+This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.

+- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
+- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
+- 🎯 Accurate word-level timestamps using wav2vec2 alignment
+- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (labels each segment/word with speaker ID)
+- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation

+**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.

 **Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).

@@ -50,34 +50,40 @@ This repository refines the timestamps of openAI's Whisper model via forced alig

 **Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.

-<h2 align="left", id="highlights">New🚨</h2>
+**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.

-- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
+- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
-- v2 released, code cleanup, imports whisper library, batched inference from paper not included (contact for licensing / batched model API). VAD filtering is now turned on by default, as in the paper.
+- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
 - Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
 - VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
 - Character level timestamps (see `*.char.ass` file output)
 - Diarization (still in beta, add `--diarize`)

+<h2 align="left", id="highlights">New🚨</h2>

+- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
+- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
+- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
+- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

 GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).

-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment

-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`

 `conda activate whisperx`

-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch 2.0, e.g. for Linux and Windows CUDA11.7:

-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`

-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)

 ### 3. Install this repo
@@ -89,15 +95,13 @@ If already installed, update package to most recent commit

 If wishing to modify this package, clone and install in editable mode:
 ```
-$ git clone https://github.com/m-bain/whisperX.git@v3
+$ git clone https://github.com/m-bain/whisperX.git
 $ cd whisperX
-$ git checkout v3
 $ pip install -e .
 ```

 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization
 To **enable Speaker Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation), [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection), and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
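Putting the above together, a diarization-enabled run would look roughly like this (a sketch; `YOUR_HF_TOKEN` is a placeholder for your own Hugging Face token):

```bash
whisperx examples/sample01.wav --model large-v2 --diarize --hf_token YOUR_HF_TOKEN
```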
@@ -106,15 +110,11 @@ To **enable Speaker Diarization**, include your Hugging Face access token that

 ### English

-Run whisper on example segment (using default params)
+Run whisper on example segment (using default params, whisper small). Add `--highlight_words True` to visualise word timings in the .srt file.

 whisperx examples/sample01.wav

-For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

-whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H

 Result using *WhisperX* with forced alignment to wav2vec2.0 large:

 https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
@@ -123,6 +123,16 @@ Compare this to original whisper out the box, where many transcriptions are out

 https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov

+For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

+whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4

+To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):

+whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True

 ### Other languages

 The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
@@ -132,7 +142,7 @@ Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`

 #### E.g. German
-whisperx --model large --language de examples/sample_de_01.wav
+whisperx --model large-v2 --language de examples/sample_de_01.wav

 https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

@@ -143,79 +153,108 @@ See more examples in other languages [here](EXAMPLES.md).

 ```python
 import whisperx
+import gc

 device = "cuda"
 audio_file = "audio.mp3"
+batch_size = 16 # reduce if low on GPU mem
+compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)

-# transcribe with original whisper
+# 1. Transcribe with original whisper (batched)
-model = whisperx.load_model("large-v2", device)
+model = whisperx.load_model("large-v2", device, compute_type=compute_type)

 audio = whisperx.load_audio(audio_file)
-result = model.transcribe(audio, batch_size=8)
+result = model.transcribe(audio, batch_size=batch_size)

 print(result["segments"]) # before alignment

-# load alignment model and metadata
+# delete model if low on GPU resources
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model

+# 2. Align whisper output
 model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

-# align whisper output
+print(result["segments"]) # after alignment
-result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device)

-print(result_aligned["segments"]) # after alignment
+# delete model if low on GPU resources
-print(result_aligned["word_segments"]) # after alignment
+# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a

+# 3. Assign speaker labels
+diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

+# add min/max number of speakers if known
+diarize_segments = diarize_model(audio_file)
+# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)

+result = whisperx.assign_word_speakers(diarize_segments, result)
+print(diarize_segments)
+print(result["segments"]) # segments are now assigned speaker IDs
 ```
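A rough command-line equivalent of the Python pipeline above, using the same model, batch size and compute type (the token value is again a placeholder), would be:

```bash
whisperx audio.mp3 --model large-v2 --batch_size 16 --compute_type float16 --diarize --hf_token YOUR_HF_TOKEN
```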

-<h2 align="left" id="whisper-mod">Whisper Modifications</h2>
+<h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>

-In addition to forced alignment, the following two modifications have been made to the whisper transcription method:
+For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).

-1. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
+To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
+1. reduce batch size, e.g. `--batch_size 4`
+2. use a smaller ASR model `--model base`
+3. Use lighter compute type `--compute_type int8`
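Combined in a single command, these memory-saving options would look something like the following (illustrative values only):

```bash
whisperx examples/sample01.wav --model base --batch_size 4 --compute_type int8
```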

+Transcription differences from openai's whisper:
+1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output.
+2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
+3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
<h2 align="left" id="limitations">Limitations ⚠️</h2>
|
<h2 align="left" id="limitations">Limitations ⚠️</h2>
|
||||||
|
|
||||||
- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
|
- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
|
||||||
- If setting `--vad_filter False`, then whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
|
|
||||||
- Overlapping speech is not handled particularly well by whisper nor whisperx
|
- Overlapping speech is not handled particularly well by whisper nor whisperx
|
||||||
- Diariazation is far from perfect.
|
- Diarization is far from perfect (working on this with custom model v4 -- see contact me).
|
||||||
|
- Language specific wav2vec2 model is needed
|
||||||
|
|
||||||
|
|
||||||
<h2 align="left" id="contribute">Contribute 🧑🏫</h2>
|
<h2 align="left" id="contribute">Contribute 🧑🏫</h2>
|
||||||
|
|
||||||
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a merge request and some examples showing its success.
|
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
|
||||||
|
|
||||||
The next major upgrade we are working on is whisper with speaker diarization, so if you have any experience on this please share.
|
Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
|
||||||
|
|
||||||
<h2 align="left" id="coming-soon">Coming Soon 🗓</h2>
|
<h2 align="left" id="coming-soon">TODO 🗓</h2>
|
||||||
|
|
||||||
* [x] Multilingual init
|
* [x] Multilingual init
|
||||||
|
|
||||||
* [x] Subtitle .ass output
|
|
||||||
|
|
||||||
* [x] Automatic align model selection based on language detection
|
* [x] Automatic align model selection based on language detection
|
||||||
|
|
||||||
* [x] Python usage
|
* [x] Python usage
|
||||||
|
|
||||||
* [x] Character level timestamps
|
|
||||||
|
|
||||||
* [x] Incorporating speaker diarization
|
* [x] Incorporating speaker diarization
|
||||||
|
|
||||||
* [x] Model flush, for low gpu mem resources
|
* [x] Model flush, for low gpu mem resources
|
||||||
|
|
||||||
* [x] Faster-whisper backend
|
* [x] Faster-whisper backend
|
||||||
|
|
||||||
|
* [x] Add max-line etc. see (openai's whisper utils.py)
|
||||||
|
|
||||||
|
* [x] Sentence-level segments (nltk toolbox)
|
||||||
|
|
||||||
|
* [x] Improve alignment logic
|
||||||
|
|
||||||
|
* [ ] update examples with diarization and word highlighting
|
||||||
|
|
||||||
|
* [ ] Subtitle .ass output <- bring this back (removed in v3)
|
||||||
|
|
||||||
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||||
|
|
||||||
* [ ] Allow silero-vad as alternative VAD option
|
* [ ] Allow silero-vad as alternative VAD option
|
||||||
|
|
||||||
* [ ] Add max-line etc. see (openai's whisper utils.py)
|
|
||||||
|
|
||||||
* [ ] Improve diarization (word level). *Harder than first thought...*
|
* [ ] Improve diarization (word level). *Harder than first thought...*
|
||||||
|
|
||||||
|
|
||||||
<h2 align="left" id="contact">Contact/Support 📇</h2>
|
<h2 align="left" id="contact">Contact/Support 📇</h2>
|
||||||
|
|
||||||
Contact maxhbain@gmail.com for queries and licensing / early access to a model API with batched inference (transcribe 1hr audio in under 1min).
|
|
||||||
|
Contact maxhbain@gmail.com for queries. WhisperX v4 development is underway with with siginificantly improved diarization. To support v4 and get early access, get in touch.
|
||||||
|
|
||||||
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
|
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
|
||||||
|
|
||||||
@@ -224,14 +263,18 @@ Contact maxhbain@gmail.com for queries and licensing / early access to a model A

 This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.

 Of course, this builds on [openAI's whisper](https://github.com/openai/whisper).
-And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio

-Valuable VAD & Diarization Models from (pyannote.audio)[https://github.com/pyannote/pyannote-audio]

-Great backend from (faster-whisper)[https://github.com/guillaumekln/faster-whisper] and (CTranslate2)[https://github.com/OpenNMT/CTranslate2]
+Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)

+Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

+Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏

+Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.

 <h2 align="left" id="cite">Citation</h2>
 If you use this in your research, please cite the paper:
notebooks/whisperx.ipynb (new file): 91 lines

@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "11fc5246",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /opt/conda/lib/python3.8/site-packages/torchvision/image.so: undefined symbol: _ZNK3c1010TensorImpl36is_contiguous_nondefault_policy_implENS_12MemoryFormatE\n",
" warn(f\"Failed to load image Python extension: {e}\")\n"
]
},
{
"ename": "OutOfMemoryError",
"evalue": "CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.79 GiB total capacity; 5.76 GiB already allocated; 59.19 MiB free; 6.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_66/1447832577.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m# transcribe with original whisper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwhisper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"large\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranscribe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maudio_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/whisper/__init__.py\u001b[0m in \u001b[0;36mload_model\u001b[0;34m(name, device, download_root, in_memory)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_alignment_heads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malignment_heads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 987\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 989\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 990\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 991\u001b[0m def register_backward_hook(\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 639\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 664\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 665\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 985\u001b[0m return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,\n\u001b[1;32m 986\u001b[0m non_blocking, memory_format=convert_to_format)\n\u001b[0;32m--> 987\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 988\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.79 GiB total capacity; 5.76 GiB already allocated; 59.19 MiB free; 6.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
]
}
],
"source": [
"import whisperx\n",
"import whisper\n",
"\n",
"device = \"cuda\" \n",
"audio_file = \"audio.mp3\"\n",
"\n",
"# transcribe with original whisper\n",
"model = whisper.load_model(\"large\", device)\n",
"result = model.transcribe(audio_file)\n",
"\n",
"print(result[\"segments\"]) # before alignment\n",
"\n",
"# load alignment model and metadata\n",
"model_a, metadata = whisperx.load_align_model(language_code=result[\"language\"], device=device)\n",
"\n",
"# align whisper output\n",
"result_aligned = whisperx.align(result[\"segments\"], model_a, metadata, audio_file, device)\n",
"\n",
"print(result_aligned[\"segments\"]) # after alignment\n",
"print(result_aligned[\"word_segments\"]) # after alignment"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b63e6170",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -1,8 +1,8 @@
-torch==1.11.0
+torch==2.0.0
-torchaudio==0.11.0
+torchaudio==2.0.1
-pyannote.audio
 faster-whisper
 transformers
 ffmpeg-python==0.2.0
 pandas
 setuptools==65.6.3
+nltk
setup.py: 6 lines changed

@@ -6,8 +6,8 @@ from setuptools import setup, find_packages
 setup(
 name="whisperx",
 py_modules=["whisperx"],
-version="3.0.0",
+version="3.1.0",
-description="Time-Accurate Automatic Speech Recognition using Whisper.",
+description="Time-Accurate Automatic Speech Recognition.",
 readme="README.md",
 python_requires=">=3.8",
 author="Max Bain",
@@ -19,7 +19,7 @@ setup(
 for r in pkg_resources.parse_requirements(
 open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
 )
-],
+] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
 entry_points = {
 'console_scripts': ['whisperx=whisperx.transcribe:cli'],
 },
@@ -1,3 +1,4 @@
 from .transcribe import load_model
 from .alignment import load_align_model, align
 from .audio import load_audio
+from .diarize import assign_word_speakers, DiarizationPipeline
@@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
+import nltk

 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

@@ -38,6 +39,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
 "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
 "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+"he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
 }

@@ -83,44 +85,13 @@ def align(
 align_model_metadata: dict,
 audio: Union[str, np.ndarray, torch.Tensor],
 device: str,
-extend_duration: float = 0.0,
-start_from_previous: bool = True,
 interpolate_method: str = "nearest",
+return_char_alignments: bool = False,
 ):
 """
-Force align phoneme recognition predictions to known transcription
+Align phoneme recognition predictions to known transcription.

-Parameters
-----------
-transcript: Iterator[dict]
-The Whisper model instance

-model: torch.nn.Module
-Alignment model (wav2vec2)

-audio: Union[str, np.ndarray, torch.Tensor]
-The path to the audio file to open, or the audio waveform

-device: str
-cuda device

-diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
-diarization segments with speaker labels.

-extend_duration: float
-Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds

-If the gzip compression ratio is above this value, treat as failed

-interpolate_method: str ["nearest", "linear", "ignore"]
-Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
-"nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.

-Returns
--------
-A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
-the spoken language ("language"), which is detected when `decode_options["language"]` is None.
 """

 if not torch.is_tensor(audio):
 if isinstance(audio, str):
 audio = load_audio(audio)
@@ -134,42 +105,21 @@ def align(
 model_lang = align_model_metadata["language"]
 model_type = align_model_metadata["type"]

-aligned_segments = []
+# 1. Preprocess to keep only characters in dictionary

-prev_t2 = 0

-char_segments_arr = {
-"segment-idx": [],
-"subsegment-idx": [],
-"word-idx": [],
-"char": [],
-"start": [],
-"end": [],
-"score": [],
-}

 for sdx, segment in enumerate(transcript):
-while True:
-segment_align_success = False

 # strip spaces at beginning / end, but keep track of the amount.
 num_leading = len(segment["text"]) - len(segment["text"].lstrip())
 num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
-transcription = segment["text"]
+text = segment["text"]

-# TODO: convert number tokenizer / symbols to phonetic words for alignment.
-# e.g. "$300" -> "three hundred dollars"
-# currently "$300" is ignored since no characters present in the phonetic dictionary

 # split into words
 if model_lang not in LANGUAGES_WITHOUT_SPACES:
-per_word = transcription.split(" ")
+per_word = text.split(" ")
 else:
-per_word = transcription
+per_word = text

-# first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
 clean_char, clean_cdx = [], []
-for cdx, char in enumerate(transcription):
+for cdx, char in enumerate(text):
 char_ = char.lower()
 # wav2vec2 models use "|" character to represent spaces
 if model_lang not in LANGUAGES_WITHOUT_SPACES:
@@ -178,7 +128,7 @@ def align(
 # ignore whitespace at beginning and end of transcript
 if cdx < num_leading:
 pass
-elif cdx > len(transcription) - num_trailing - 1:
+elif cdx > len(text) - num_trailing - 1:
 pass
 elif char_ in model_dictionary.keys():
 clean_char.append(char_)
@@ -189,35 +139,49 @@ def align(
 if any([c in model_dictionary.keys() for c in wrd]):
 clean_wdx.append(wdx)

-# if no characters are in the dictionary, then we skip this segment...
+sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
-if len(clean_char) == 0:
+segment["clean_char"] = clean_char
+segment["clean_cdx"] = clean_cdx
+segment["clean_wdx"] = clean_wdx
+segment["sentence_spans"] = sentence_spans

+aligned_segments = []

+# 2. Get prediction matrix from alignment model & align
+for sdx, segment in enumerate(transcript):
+t1 = segment["start"]
+t2 = segment["end"]
+text = segment["text"]

+aligned_seg = {
+"start": t1,
+"end": t2,
+"text": text,
+"words": [],
+}

+if return_char_alignments:
+aligned_seg["chars"] = []

+# check we can align
+if len(segment["clean_char"]) == 0:
 print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
-break
+aligned_segments.append(aligned_seg)
+continue

-transcription_cleaned = "".join(clean_char)
+if t1 >= MAX_DURATION or t2 - t1 < 0.02:
-tokens = [model_dictionary[c] for c in transcription_cleaned]

-# we only pad if not using VAD filtering
-if "seg_text" not in segment:
-# pad according original timestamps
-t1 = max(segment["start"] - extend_duration, 0)
-t2 = min(segment["end"] + extend_duration, MAX_DURATION)

-# use prev_t2 as current t1 if it"s later
-if start_from_previous and t1 < prev_t2:
-t1 = prev_t2

-# check if timestamp range is still valid
-if t1 >= MAX_DURATION:
 print("Failed to align segment: original start time longer than audio duration, skipping...")
-break
+aligned_segments.append(aligned_seg)
-if t2 - t1 < 0.02:
+continue
-print("Failed to align segment: duration smaller than 0.02s time precision")
-break
+text_clean = "".join(segment["clean_char"])
+tokens = [model_dictionary[c] for c in text_clean]

 f1 = int(t1 * SAMPLE_RATE)
 f2 = int(t2 * SAMPLE_RATE)

+# TODO: Probably can get some speedup gain with batched inference here
 waveform_segment = audio[:, f1:f2]

 with torch.inference_mode():
@@ -231,233 +195,117 @@ def align(

        emission = emissions[0].cpu().detach()

Removed: from this point the previous implementation ran get_trellis/backtrack without an explicit blank token id, looped over `transcription + " "` while tracking per-character word (wdx) and sub-segment (sub_seg_idx) counters against the cumulative "seg-text" lengths, and accumulated dict-of-list columns (char, start, end, score, word-idx, segment-idx, subsegment-idx) together with prev_t2 / segment_align_success retry bookkeeping. It then turned those columns into pandas DataFrames (char_segments_arr, word_segments_arr, segments_arr) via groupbys over segment/sub-segment/word indices, optionally interpolated missing word and sub-segment timestamps with interpolate_nans (falling back to the original segment times when a segment stayed unaligned, e.g. all numerals/symbols), and reassembled per-segment word lists, returning {"segments": aligned_segments, "word_segments": aligned_segments_word} with the raw "word-segments"/"char-segments" DataFrames attached to each segment.

Added in its place:

        blank_id = 0
        for char, code in model_dictionary.items():
            if char == '[pad]' or char == '<pad>':
                blank_id = code

        trellis = get_trellis(emission, tokens, blank_id)
        path = backtrack(trellis, emission, tokens, blank_id)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue

        char_segments = merge_repeats(path, text_clean)

        duration = t2 - t1
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # assign timestamps to aligned characters
        char_segments_arr = []
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
            if cdx in segment["clean_cdx"]:
                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)

            char_segments_arr.append(
                {
                    "char": char,
                    "start": start,
                    "end": end,
                    "score": score,
                    "word-idx": word_idx,
                }
            )

            # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now...
            if model_lang in LANGUAGES_WITHOUT_SPACES:
                word_idx += 1
            elif cdx == len(text) - 1 or text[cdx+1] == " ":
                word_idx += 1

        char_segments_arr = pd.DataFrame(char_segments_arr)

        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            sentence_end = curr_chars["end"].max()
            sentence_words = []

            for word_idx in curr_chars["word-idx"].unique():
                word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
                word_text = "".join(word_chars["char"].tolist()).strip()
                if len(word_text) == 0:
                    continue
                word_start = word_chars["start"].min()
                word_end = word_chars["end"].max()
                word_score = round(word_chars["score"].mean(), 3)

                # -1 indicates unalignable
                word_segment = {"word": word_text}

                if not np.isnan(word_start):
                    word_segment["start"] = word_start
                if not np.isnan(word_end):
                    word_segment["end"] = word_end
                if not np.isnan(word_score):
                    word_segment["score"] = word_score

                sentence_words.append(word_segment)

            aligned_subsegments.append({
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "words": sentence_words,
            })

            if return_char_alignments:
                curr_chars = curr_chars[["char", "start", "end", "score"]]
                curr_chars.fillna(-1, inplace=True)
                curr_chars = curr_chars.to_dict("records")
                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
                aligned_subsegments[-1]["chars"] = curr_chars

        aligned_subsegments = pd.DataFrame(aligned_subsegments)
        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
        # concatenate sentences with same timestamps
        agg_dict = {"text": " ".join, "words": "sum"}
        if return_char_alignments:
            agg_dict["chars"] = "sum"
        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
        aligned_subsegments = aligned_subsegments.to_dict('records')
        aligned_segments += aligned_subsegments

    # create word_segments list
    word_segments = []
    for segment in aligned_segments:
        word_segments += segment["words"]

    return {"segments": aligned_segments, "word_segments": word_segments}

"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
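To make the timestamp mapping in the rewritten block above concrete, here is a small sketch with toy numbers (none of these values come from the repository); it simply re-applies the `ratio` and `round(frame * ratio + t1, 3)` arithmetic used for each aligned character:

```python
# Toy example of the character-to-time mapping used in align()
# (assumes a mono waveform, i.e. waveform_segment.size(0) == 1).
duration = 4.0        # t2 - t1 for one segment, in seconds
t1 = 10.0             # absolute start of that segment in the full audio
trellis_frames = 201  # trellis.size(0)

ratio = duration * 1 / (trellis_frames - 1)       # 0.02 s per trellis step

# a character that merge_repeats() says spans trellis frames 50..55
char_start_frame, char_end_frame = 50, 55
start = round(char_start_frame * ratio + t1, 3)   # 11.0
end = round(char_end_frame * ratio + t1, 3)       # 11.1
print(start, end)
```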
155 whisperx/asr.py

@@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l

 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode.
+    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''

     def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):

@@ -140,6 +140,13 @@ class WhisperModel(faster_whisper.WhisperModel):
         return self.model.encode(features, to_cpu=to_cpu)

 class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+    """
+    # TODO:
+    # - add support for timestamp mode
+    # - add support for custom inference kwargs
+
     def __init__(
         self,
         model,

@@ -261,149 +268,3 @@ class FasterWhisperPipeline(Pipeline):
         language = language_token[2:-2]
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language

Removed: the ad-hoc `if __name__ == "__main__":` benchmark harness that used to close whisperx/asr.py (~150 lines). Its "complex" path hand-built a WhisperModel / FasterWhisperPipeline with explicit faster-whisper TranscriptionOptions (beam_size=5, best_of=5, temperatures 0.0–1.0, condition_on_previous_text=False, without_timestamps=True, ...), chunked each TED-LIUM wav under /tmp/test/wav/ with either the pyannote VAD (load_vad_model + merge_chunks, capped at 30 s per chunk) or silero's get_speech_timestamps, ran batched inference (batch_size=16), and printed per-file and average WER (jiwer with EnglishTextNormalizer against parse_tedlium_annos) together with wall-clock time. Its "simple" path ran the same evaluation through load_model("large-v2", device="cuda", language="en") and model.transcribe(audio_fp, batch_size=8).
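In place of that harness, batched transcription is driven through the pipeline wrapper directly. A minimal sketch, assuming the package-level `load_model` / `load_audio` exports; the model size, file path, device and batch size are placeholders:

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")   # placeholder path, loaded as 16 kHz mono

# load_model wraps WhisperModel + FasterWhisperPipeline (see the hunks above)
model = whisperx.load_model("large-v2", device, compute_type="float16")
result = model.transcribe(audio, batch_size=16)

for seg in result["segments"]:
    print(f"[{seg['start']:.2f} -> {seg['end']:.2f}] {seg['text']}")
```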
whisperx/diarize.py

@@ -1,73 +1,63 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
         diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
         diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
         diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
+        diarize_df.rename(columns={2: "speaker"}, inplace=True)
         return diarize_df

Removed: the previous assign_word_speakers(diarize_df, result_segments, fill_nearest=False), which walked each segment's "word-segments" DataFrame, assigned every timestamped word the speaker of the single best-overlapping diarization row, set the segment speaker to the most frequent word speaker (falling back to "UNKNOWN"), collected word-level "[SPEAKER]: ..." snippets into word_seg for the word-level .srt, and returned (result_segments, word_seg).

Added in its place:

def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
    transcript_segments = transcript_result["segments"]
    for seg in transcript_segments:
        # assign speaker to segment (if any)
        diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
        diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
        # remove no hit, otherwise we look for closest (even negative intersection...)
        if not fill_nearest:
            dia_tmp = diarize_df[diarize_df['intersection'] > 0]
        else:
            dia_tmp = diarize_df
        if len(dia_tmp) > 0:
            # sum over speakers
            speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
            seg["speaker"] = speaker

        # assign speaker to words
        if 'words' in seg:
            for word in seg['words']:
                if 'start' in word:
                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
                    diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
                    # remove no hit
                    if not fill_nearest:
                        dia_tmp = diarize_df[diarize_df['intersection'] > 0]
                    else:
                        dia_tmp = diarize_df
                    if len(dia_tmp) > 0:
                        # sum over speakers
                        speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
                        word["speaker"] = speaker

    return transcript_result

 class Segment:
     def __init__(self, start, end, speaker=None):
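A short usage sketch for the updated diarization interface. The tiny hand-made transcript below only illustrates the shape that align() returns; in practice it comes from the transcription and alignment steps, and the token, path and speaker counts are placeholders:

```python
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

device = "cuda"
audio_file = "audio.wav"   # placeholder path

# minimal aligned transcript in the same shape align() produces
result = {
    "segments": [
        {"start": 0.5, "end": 2.0, "text": "hello there",
         "words": [{"word": "hello", "start": 0.5, "end": 1.0, "score": 0.9},
                   {"word": "there", "start": 1.2, "end": 2.0, "score": 0.8}]},
    ],
    "word_segments": [],
}

diarize_model = DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_df = diarize_model(audio_file, min_speakers=2, max_speakers=2)  # start / end / speaker rows

# attaches a "speaker" key to each segment and to every word that has timestamps,
# picking the speaker whose diarization turns overlap it the most
result = assign_word_speakers(diarize_df, result)
print(result["segments"][0].get("speaker"))
```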
whisperx/transcribe.py

@@ -35,6 +35,7 @@ def cli():
     parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
     parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
     parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
+    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")

     # vad params
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")

@@ -42,8 +43,8 @@ def cli():

     # diarization params
     parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
-    parser.add_argument("--min_speakers", default=None, type=int)
-    parser.add_argument("--max_speakers", default=None, type=int)
+    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
+    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")

     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")

@@ -64,15 +65,11 @@ def cli():
     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
-    # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
-    # parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
-    # parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
-    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
     # fmt: on

     args = parser.parse_args().__dict__

@@ -86,13 +83,10 @@ def cli():
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)

-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
-
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
+    return_char_alignments: bool = args.pop("return_char_alignments")

     hf_token: str = args.pop("hf_token")
     vad_onset: float = args.pop("vad_onset")

@@ -102,7 +96,6 @@ def cli():
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")

-    # TODO: check model loading works.

     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:

@@ -180,7 +173,8 @@ def cli():
             print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
             align_model, align_metadata = load_align_model(result["language"], device)
         print(">>Performing alignment...")
-        result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method)
+        result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)

         results.append((result, audio_path))

         # Unload align model

@@ -193,22 +187,15 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
-            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
-            result = {"segments": results_segments, "word_segments": word_segments}
+            result = assign_word_speakers(diarize_segments, result)
             results.append((result, input_audio_path))

     # >> Write
     for result, audio_path in results:
-        # Remove pandas dataframes from result so that
-        # we can serialize the result with json
-        for seg in result["segments"]:
-            seg.pop("word-segments", None)
-            seg.pop("char-segments", None)
-
         writer(result, audio_path, writer_args)

 if __name__ == "__main__":
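The new `--return_char_alignments` flag maps onto the extra keyword argument of `align()`. A Python-side sketch, assuming the package-level `load_model` / `load_audio` / `load_align_model` / `align` exports; the path, language and model size are placeholders:

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("audio.wav")   # placeholder path

model = whisperx.load_model("large-v2", device, compute_type="float16")
segments = model.transcribe(audio, batch_size=16)["segments"]

model_a, metadata = whisperx.load_align_model("en", device)
aligned = whisperx.align(segments, model_a, metadata, audio, device,
                         interpolate_method="nearest",
                         return_char_alignments=True)

first = aligned["segments"][0]
print(first["words"][0])    # {'word': ..., 'start': ..., 'end': ..., 'score': ...}
print(first["chars"][:3])   # per-character timings, only emitted with the new flag
```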
whisperx/utils.py

@@ -231,11 +231,16 @@ class SubtitlesWriter(ResultWriter):
             line_count = 1
             # the next subtitle to yield (a list of word timings with whitespace)
             subtitle: list[dict] = []
-            last = result["segments"][0]["words"][0]["start"]
+            times = []
+            last = result["segments"][0]["start"]
             for segment in result["segments"]:
                 for i, original_timing in enumerate(segment["words"]):
                     timing = original_timing.copy()
-                    long_pause = not preserve_segments and timing["start"] - last > 3.0
+                    long_pause = not preserve_segments
+                    if "start" in timing:
+                        long_pause = long_pause and timing["start"] - last > 3.0
+                    else:
+                        long_pause = False
                     has_room = line_len + len(timing["word"]) <= max_line_width
                     seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                     if line_len > 0 and has_room and not long_pause and not seg_break:

@@ -251,8 +256,9 @@ class SubtitlesWriter(ResultWriter):
                         or seg_break
                     ):
                         # subtitle break
-                        yield subtitle
+                        yield subtitle, times
                         subtitle = []
+                        times = []
                         line_count = 1
                     elif line_len > 0:
                         # line break

@@ -260,25 +266,36 @@ class SubtitlesWriter(ResultWriter):
                         timing["word"] = "\n" + timing["word"]
                         line_len = len(timing["word"].strip())
                     subtitle.append(timing)
+                    times.append((segment["start"], segment["end"], segment.get("speaker")))
+                    if "start" in timing:
                         last = timing["start"]
             if len(subtitle) > 0:
-                yield subtitle
+                yield subtitle, times

         if "words" in result["segments"][0]:
-            for subtitle in iterate_subtitles():
-                subtitle_start = self.format_timestamp(subtitle[0]["start"])
-                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
-                subtitle_text = "".join([word["word"] for word in subtitle])
-                if highlight_words:
+            for subtitle, _ in iterate_subtitles():
+                sstart, ssend, speaker = _[0]
+                subtitle_start = self.format_timestamp(sstart)
+                subtitle_end = self.format_timestamp(ssend)
+                subtitle_text = " ".join([word["word"] for word in subtitle])
+                has_timing = any(["start" in word for word in subtitle])
+
+                # add [$SPEAKER_ID]: to each subtitle if speaker is available
+                prefix = ""
+                if speaker is not None:
+                    prefix = f"[{speaker}]: "
+
+                if highlight_words and has_timing:
                     last = subtitle_start
                     all_words = [timing["word"] for timing in subtitle]
                     for i, this_word in enumerate(subtitle):
+                        if "start" in this_word:
                             start = self.format_timestamp(this_word["start"])
                             end = self.format_timestamp(this_word["end"])
                             if last != start:
                                 yield last, start, subtitle_text

-                            yield start, end, "".join(
+                            yield start, end, prefix + " ".join(
                                 [
                                     re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                     if j == i

@@ -288,12 +305,14 @@ class SubtitlesWriter(ResultWriter):
                             )
                             last = end
                 else:
-                    yield subtitle_start, subtitle_end, subtitle_text
+                    yield subtitle_start, subtitle_end, prefix + subtitle_text
         else:
             for segment in result["segments"]:
                 segment_start = self.format_timestamp(segment["start"])
                 segment_end = self.format_timestamp(segment["end"])
                 segment_text = segment["text"].strip().replace("-->", "->")
+                if "speaker" in segment:
+                    segment_text = f"[{segment['speaker']}]: {segment_text}"
                 yield segment_start, segment_end, segment_text

     def format_timestamp(self, seconds: float):
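Writing an (optionally diarized) result to subtitles goes through the writer shown above. A sketch assuming `get_writer` in whisperx.utils keeps the upstream Whisper signature and that `result` has the aligned shape used in the earlier examples; the paths, speaker label and word timings are placeholders:

```python
from whisperx.utils import get_writer

# minimal aligned result; real ones come from whisperx.align() / assign_word_speakers()
result = {
    "segments": [
        {"start": 0.5, "end": 2.0, "text": "hello there", "speaker": "SPEAKER_00",
         "words": [{"word": "hello", "start": 0.5, "end": 1.0, "score": 0.9},
                   {"word": "there", "start": 1.2, "end": 2.0, "score": 0.8}]},
    ],
    "language": "en",
}

writer = get_writer("srt", ".")   # assumed upstream-style (format, output_dir) helper
writer_args = {"highlight_words": False, "max_line_count": None, "max_line_width": None}
writer(result, "audio.wav", writer_args)
# cues are prefixed with "[SPEAKER_00]: " because the segment carries a "speaker" key
```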