Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Compare commits: v3.3.0...3dfe6c6ea0 (70 commits)
Commits in this range (SHA1):

3dfe6c6ea0, d700b56c9c, b343241253, 6fe0a8784a, 5012650d0f, 108bd0c400, b2d50a027b, 36d552cad3,
7d36b832f9, d2a493e910, f5b40b5366, ac0c8bd79a, cd59f21d1a, 0aed874589, f10dbf6ab1, a7564c2ad6,
e7712f496e, 8e53866704, 3205436d58, 8c58c54635, 0d9807adc5, 4db839018c, f8d11df727, d2f0e53f71,
7489ebf876, 90256cc481, b41ebd4871, 63bc1903c1, 272714e07d, 44e8bf5bb6, 7b3c9ce629, 36d2622e27,
8bfa12193b, acbeba6057, fca563a782, 2117909bf6, de0d8fe313, 355f8e06f7, 86e2b3ee74, 70c639cdb5,
235536e28d, 12604a48ea, ffbc73664c, 289eadfc76, 22a93f2932, 1027367b79, 5e54b872a9, 6be02cccfa,
2f93e029c7, 024bc8481b, f286e7f3de, 73e644559d, 1ec527375a, 6695426a85, 7a98456321, aaddb83aa5,
c288f4812a, 4ebfb078c5, 65b2332e13, 69281f3a29, 734084cdf6, 9395b0de18, d57f9dc54c, a90bd1ce3f,
79eb8fa53d, 10b05fc43f, 26d9b46888, 9a8967f27e, 0f7f9f9f83, c60594fa3b
.github/workflows/build-and-release.yml (vendored, 23 changed lines)

@@ -11,25 +11,24 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
         with:
+          version: "0.5.14"
           python-version: "3.9"

-      - name: Install dependencies
-        run: |
-          python -m pip install build
+      - name: Check if lockfile is up to date
+        run: uv lock --check

-      - name: Build wheels
-        run: python -m build --wheel
+      - name: Build package
+        run: uv build

       - name: Release to Github
         uses: softprops/action-gh-release@v2
         with:
-          files: dist/*
+          files: dist/*.whl

       - name: Publish package to PyPi
-        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_API_TOKEN }}
+        run: uv publish
+        env:
+          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
.github/workflows/python-compatibility.yml (vendored, 18 changed lines)

@@ -5,7 +5,7 @@ on:
     branches: [main]
   pull_request:
     branches: [main]
   workflow_dispatch: # Allows manual triggering from GitHub UI

 jobs:
   test:
@@ -17,16 +17,18 @@ jobs:
     steps:
       - uses: actions/checkout@v4

-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
         with:
+          version: "0.5.14"
           python-version: ${{ matrix.python-version }}

-      - name: Install package
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
+      - name: Check if lockfile is up to date
+        run: uv lock --check
+
+      - name: Install the project
+        run: uv sync --all-extras

       - name: Test import
         run: |
-          python -c "import whisperx; print('Successfully imported whisperx')"
+          uv run python -c "import whisperx; print('Successfully imported whisperx')"
README.md (180 changed lines)

@@ -22,26 +22,20 @@
 </a>
 </p>

 <img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">

 <!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->

 <!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->

 This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.

 - ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
 - 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
 - 🎯 Accurate word-level timestamps using wav2vec2 alignment
 - 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
 - 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation

 **Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.

 **Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
@@ -54,85 +48,102 @@ This repository provides fast automatic speech recognition (70x realtime with la

 <h2 align="left", id="highlights">New🚨</h2>

 - 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
 - _WhisperX_ accepted at INTERSPEECH 2023
 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
 - v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
 - v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
+- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

-GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
+### 1. Simple Installation (Recommended)

-### 1. Create Python3.10 environment
+The easiest way to install WhisperX is through PyPi:

-`conda create --name whisperx python=3.10`
-
-`conda activate whisperx`
-
-### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
-
-`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
-
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
-
-### 3. Install WhisperX
-
-You have several installation options:
-
-#### Option A: Stable Release (recommended)
-Install the latest stable version from PyPI:

 ```bash
 pip install whisperx
 ```

-#### Option B: Development Version
-Install the latest development version directly from GitHub (may be unstable):
+Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):

 ```bash
-pip install git+https://github.com/m-bain/whisperx.git
+uvx whisperx
 ```

-If already installed, update to the most recent commit:
+### 2. Advanced Installation Options
+
+These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
+
+#### Option A: Install from GitHub
+
+To install directly from the GitHub repository:

 ```bash
-pip install git+https://github.com/m-bain/whisperx.git --upgrade
+uvx git+https://github.com/m-bain/whisperX.git
 ```

-#### Option C: Development Mode
-If you wish to modify the package, clone and install in editable mode:
+#### Option B: Developer Installation
+
+If you want to modify the code or contribute to the project:

 ```bash
 git clone https://github.com/m-bain/whisperX.git
 cd whisperX
-pip install -e .
+uv sync --all-extras --dev
 ```

 > **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

+### 3. Docker Images
+
+Execute pre-built WhisperX container images:
+
+```bash
+docker run --gpus all -it -v ".:/app" ghcr.io/jim60105/whisperx:base-en -- --output_format srt audio.mp3
+docker run --gpus all -it -v ".:/app" ghcr.io/jim60105/whisperx:large-v3-ja -- --output_format srt audio.mp3
+docker run --gpus all -it -v ".:/app" ghcr.io/jim60105/whisperx:no_model -- --model tiny --language en --output_format srt audio.mp3
+```
+
+Review the tag lists in this repository: [jim60105/docker-whisperX](https://github.com/jim60105/docker-whisperX)
+
+### Common Issues & Troubleshooting 🔧
+
+#### libcudnn Dependencies (GPU Users)
+
+If you're using WhisperX with GPU support and encounter errors like:
+
+- `Could not load library libcudnn_ops_infer.so.8`
+- `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
+- `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory`
+
+This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default.
+
+**Install cuDNN (example for apt based systems):**
+
+```bash
+sudo apt update
+sudo apt install libcudnn8 libcudnn8-dev -y
+```

 ### Speaker Diarization

 To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

 > **Note**<br>
 > As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.

 <h2 align="left" id="example">Usage 💬 (command line)</h2>

 ### English

 Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.

-whisperx examples/sample01.wav
+whisperx path/to/audio.wav

-Result using *WhisperX* with forced alignment to wav2vec2.0 large:
+Result using _WhisperX_ with forced alignment to wav2vec2.0 large:

 https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4

@@ -140,43 +151,40 @@ Compare this to original whisper out the box, where many transcriptions are out

 https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov

 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

-whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
+whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4

 To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):

-whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
+whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True

 To run on CPU instead of GPU (and for running on Mac OS X):

-whisperx examples/sample01.wav --compute_type int8
+whisperx path/to/audio.wav --compute_type int8

 ### Other languages

-The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
+The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
 Just pass in the `--language` code, and use the whisper `--model large`.

-Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
+Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.

 #### E.g. German
-whisperx --model large-v2 --language de examples/sample_de_01.wav
+whisperx --model large-v2 --language de path/to/audio.wav

 https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

 See more examples in other languages [here](EXAMPLES.md).

 ## Python usage 🐍

 ```python
 import whisperx
 import gc

 device = "cuda"
 audio_file = "audio.mp3"
 batch_size = 16 # reduce if low on GPU mem
 compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
@@ -193,7 +201,7 @@ result = model.transcribe(audio, batch_size=batch_size)
 print(result["segments"]) # before alignment

 # delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model

 # 2. Align whisper output
 model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
@@ -202,10 +210,10 @@ result = whisperx.align(result["segments"], model_a, metadata, audio, device, re
 print(result["segments"]) # after alignment

 # delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a

 # 3. Assign speaker labels
-diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
+diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

 # add min/max number of speakers if known
 diarize_segments = diarize_model(audio)
@@ -218,25 +226,27 @@ print(result["segments"]) # segments are now assigned speaker IDs

 ## Demos 🚀

 [](https://replicate.com/victor-upmeet/whisperx)
 [](https://replicate.com/daanelson/whisperx)
 [](https://replicate.com/carnifexer/whisperx)

 If you don't have access to your own GPUs, use the links above to try out WhisperX.

 <h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>

 For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).

 To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):

 1. reduce batch size, e.g. `--batch_size 4`
 2. use a smaller ASR model `--model base`
 3. Use lighter compute type `--compute_type int8`

 Transcription differences from openai's whisper:

 1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
 2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
 3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

 <h2 align="left" id="limitations">Limitations ⚠️</h2>

@@ -245,7 +255,6 @@ Transcription differences from openai's whisper:
 - Diarization is far from perfect
 - Language specific wav2vec2 model is needed

 <h2 align="left" id="contribute">Contribute 🧑🏫</h2>

 If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
@@ -254,43 +263,40 @@ Bug finding and pull requests are also highly appreciated to keep this project g

 <h2 align="left" id="coming-soon">TODO 🗓</h2>

-* [x] Multilingual init
+- [x] Multilingual init

-* [x] Automatic align model selection based on language detection
+- [x] Automatic align model selection based on language detection

-* [x] Python usage
+- [x] Python usage

-* [x] Incorporating speaker diarization
+- [x] Incorporating speaker diarization

-* [x] Model flush, for low gpu mem resources
+- [x] Model flush, for low gpu mem resources

-* [x] Faster-whisper backend
+- [x] Faster-whisper backend

-* [x] Add max-line etc. see (openai's whisper utils.py)
+- [x] Add max-line etc. see (openai's whisper utils.py)

-* [x] Sentence-level segments (nltk toolbox)
+- [x] Sentence-level segments (nltk toolbox)

-* [x] Improve alignment logic
+- [x] Improve alignment logic

-* [ ] update examples with diarization and word highlighting
+- [ ] update examples with diarization and word highlighting

-* [ ] Subtitle .ass output <- bring this back (removed in v3)
+- [ ] Subtitle .ass output <- bring this back (removed in v3)

-* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
+- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)

-* [ ] Allow silero-vad as alternative VAD option
+- [x] Allow silero-vad as alternative VAD option

-* [ ] Improve diarization (word level). *Harder than first thought...*
+- [ ] Improve diarization (word level). _Harder than first thought..._

 <h2 align="left" id="contact">Contact/Support 📇</h2>

 Contact maxhbain@gmail.com for queries.

 <a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>

 <h2 align="left" id="acks">Acknowledgements 🙏</h2>

 This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
@@ -299,8 +305,10 @@ Of course, this is builds on [openAI's whisper](https://github.com/openai/whispe
 Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
 And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio

-Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
+Valuable VAD & Diarization Models from:
+
+- [pyannote audio][https://github.com/pyannote/pyannote-audio]
+- [silero vad][https://github.com/snakers4/silero-vad]

 Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
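The README hunks above also switch the diarization entry point from `whisperx.DiarizationPipeline` to `whisperx.diarize.DiarizationPipeline`, which matches the new lazy `__init__.py` further down (it no longer re-exports the class at the package top level). A minimal sketch of the new call path, assuming the package is installed and a valid Hugging Face token; the token string and audio path are placeholders:

```python
import whisperx
import whisperx.diarize  # explicit submodule import; the package no longer re-exports DiarizationPipeline

device = "cuda"
audio = whisperx.load_audio("audio.mp3")  # placeholder path

# New call path from the updated README (previously whisperx.DiarizationPipeline(...)):
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_segments = diarize_model(audio)
```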
pyproject.toml (new file, 36 lines)

@@ -0,0 +1,36 @@
+[project]
+urls = { repository = "https://github.com/m-bain/whisperx" }
+authors = [{ name = "Max Bain" }]
+name = "whisperx"
+version = "3.3.4"
+description = "Time-Accurate Automatic Speech Recognition using Whisper."
+readme = "README.md"
+requires-python = ">=3.9, <3.13"
+license = { text = "BSD-2-Clause" }
+
+dependencies = [
+    "ctranslate2<4.5.0",
+    "faster-whisper>=1.1.1",
+    "nltk>=3.9.1",
+    "numpy>=2.0.2",
+    "onnxruntime>=1.19",
+    "pandas>=2.2.3",
+    "pyannote-audio>=3.3.2",
+    "torch>=2.5.1",
+    "torchaudio>=2.5.1",
+    "transformers>=4.48.0",
+]
+
+
+[project.scripts]
+whisperx = "whisperx.__main__:cli"
+
+[build-system]
+requires = ["setuptools"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["whisperx*"]
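The new `pyproject.toml` replaces the `setup.py`/`requirements.txt` pair deleted below and wires the `whisperx` console script to `whisperx.__main__:cli`. A small sketch of checking the declared metadata from Python after installation; this relies only on the standard `importlib.metadata` API (which the new CLI itself uses for `--version`), and the exact strings depend on the installed build:

```python
import importlib.metadata

# Metadata from the [project] table above, as seen by the installed distribution.
print(importlib.metadata.version("whisperx"))                      # "3.3.4" at this revision
print(importlib.metadata.metadata("whisperx")["Requires-Python"])  # the requires-python declared above
```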
requirements.txt (deleted, 8 lines)

@@ -1,8 +0,0 @@
-torch>=2
-torchaudio>=2
-faster-whisper==1.1.0
-ctranslate2<4.5.0
-transformers
-pandas
-setuptools>=65
-nltk
setup.py (deleted, 33 lines)

@@ -1,33 +0,0 @@
-import os
-
-import pkg_resources
-from setuptools import find_packages, setup
-
-with open("README.md", "r", encoding="utf-8") as f:
-    long_description = f.read()
-
-setup(
-    name="whisperx",
-    py_modules=["whisperx"],
-    version="3.3.0",
-    description="Time-Accurate Automatic Speech Recognition using Whisper.",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    python_requires=">=3.9, <3.13",
-    author="Max Bain",
-    url="https://github.com/m-bain/whisperx",
-    license="BSD-2-Clause",
-    packages=find_packages(exclude=["tests*"]),
-    install_requires=[
-        str(r)
-        for r in pkg_resources.parse_requirements(
-            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
-        )
-    ]
-    + [f"pyannote.audio==3.3.2"],
-    entry_points={
-        "console_scripts": ["whisperx=whisperx.transcribe:cli"],
-    },
-    include_package_data=True,
-    extras_require={"dev": ["pytest"]},
-)
(additional file, name not captured)

@@ -1,6 +1,5 @@
 import math
-from conjunctions import get_conjunctions, get_comma
+from whisperx.conjunctions import get_conjunctions, get_comma
-from typing import TextIO

 def normal_round(n):
     if n - math.floor(n) < 0.5:
whisperx/__init__.py

@@ -1,4 +1,31 @@
-from .transcribe import load_model
-from .alignment import load_align_model, align
-from .audio import load_audio
-from .diarize import assign_word_speakers, DiarizationPipeline
+import importlib
+
+
+def _lazy_import(name):
+    module = importlib.import_module(f"whisperx.{name}")
+    return module
+
+
+def load_align_model(*args, **kwargs):
+    alignment = _lazy_import("alignment")
+    return alignment.load_align_model(*args, **kwargs)
+
+
+def align(*args, **kwargs):
+    alignment = _lazy_import("alignment")
+    return alignment.align(*args, **kwargs)
+
+
+def load_model(*args, **kwargs):
+    asr = _lazy_import("asr")
+    return asr.load_model(*args, **kwargs)
+
+
+def load_audio(*args, **kwargs):
+    audio = _lazy_import("audio")
+    return audio.load_audio(*args, **kwargs)
+
+
+def assign_word_speakers(*args, **kwargs):
+    diarize = _lazy_import("diarize")
+    return diarize.assign_word_speakers(*args, **kwargs)
whisperx/__main__.py

@@ -1,4 +1,88 @@
-from .transcribe import cli
-
-
-cli()
+import argparse
+import importlib.metadata
+import platform
+
+import torch
+
+from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
+                            optional_int, str2bool)
+
+
+def cli():
+    # fmt: off
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
+    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
+    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
+    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
+    parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
+    parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
+
+    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
+    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
+    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+
+    # alignment params
+    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
+    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
+    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
+    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
+
+    # vad params
+    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
+    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
+    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
+    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
+
+    # diarization params
+    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
+    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
+    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
+    parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
+
+    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+
+    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
+
+    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
+
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+
+    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
+    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
+    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
+
+    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+
+    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
+
+    parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
+    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}", help="Show whisperx version information and exit")
+    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})", help="Show python version information and exit")
+    # fmt: on
+
+    args = parser.parse_args().__dict__
+
+    from whisperx.transcribe import transcribe_task
+
+    transcribe_task(args, parser)
+
+
+if __name__ == "__main__":
+    cli()
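`__main__.py` now defines `cli()` itself (instead of importing it from `whisperx.transcribe`) and adds two argparse version actions. A small sketch of exercising them programmatically, assuming the package is installed; note that argparse's `action="version"` prints and then raises `SystemExit`, so the sketch catches it:

```python
import sys

from whisperx.__main__ import cli  # same entry point that pyproject.toml wires to the `whisperx` command

for flag in ("--version", "--python-version"):
    sys.argv = ["whisperx", flag]
    try:
        cli()                # argparse handles the flag and prints the version string
    except SystemExit:       # then it exits; swallow it so the loop can continue
        pass
```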
whisperx/alignment.py

@@ -1,9 +1,11 @@
-""""
+"""
 Forced Alignment with Whisper
 C. Max Bain
 """
+import math
+
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Optional, Union, List

 import numpy as np
 import pandas as pd
@@ -11,10 +13,15 @@ import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

-from .audio import SAMPLE_RATE, load_audio
-from .utils import interpolate_nans
-from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
-import nltk
+from whisperx.audio import SAMPLE_RATE, load_audio
+from whisperx.utils import interpolate_nans
+from whisperx.types import (
+    AlignedTranscriptionResult,
+    SingleSegment,
+    SingleAlignedSegment,
+    SingleWordSegment,
+    SegmentData,
+)
 from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

 PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -62,10 +69,12 @@ DEFAULT_ALIGN_MODELS_HF = {
     "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
     "gl": "ifrz/wav2vec2-large-xlsr-galician",
     "ka": "xsway/wav2vec2-large-xlsr-georgian",
+    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
+    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
 }


-def load_align_model(language_code, device, model_name=None, model_dir=None):
+def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -131,6 +140,8 @@ def align(

     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
+    # Store temporary processing values
+    segment_data: dict[int, SegmentData] = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
         if print_progress:
@@ -163,10 +174,17 @@ def align(
             elif char_ in model_dictionary.keys():
                 clean_char.append(char_)
                 clean_cdx.append(cdx)
+            else:
+                # add placeholder
+                clean_char.append('*')
+                clean_cdx.append(cdx)

         clean_wdx = []
         for wdx, wrd in enumerate(per_word):
-            if any([c in model_dictionary.keys() for c in wrd]):
+            if any([c in model_dictionary.keys() for c in wrd.lower()]):
+                clean_wdx.append(wdx)
+            else:
+                # index for placeholder
                 clean_wdx.append(wdx)

@@ -175,11 +193,13 @@ def align(
         sentence_splitter = PunktSentenceTokenizer(punkt_param)
         sentence_spans = list(sentence_splitter.span_tokenize(text))

-        segment["clean_char"] = clean_char
-        segment["clean_cdx"] = clean_cdx
-        segment["clean_wdx"] = clean_wdx
-        segment["sentence_spans"] = sentence_spans
+        segment_data[sdx] = {
+            "clean_char": clean_char,
+            "clean_cdx": clean_cdx,
+            "clean_wdx": clean_wdx,
+            "sentence_spans": sentence_spans
+        }

     aligned_segments: List[SingleAlignedSegment] = []

     # 2. Get prediction matrix from alignment model & align
@@ -194,13 +214,14 @@ def align(
             "end": t2,
             "text": text,
             "words": [],
+            "chars": None,
         }

         if return_char_alignments:
             aligned_seg["chars"] = []

         # check we can align
-        if len(segment["clean_char"]) == 0:
+        if len(segment_data[sdx]["clean_char"]) == 0:
            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
            aligned_segments.append(aligned_seg)
            continue
@@ -210,8 +231,8 @@ def align(
            aligned_segments.append(aligned_seg)
            continue

-        text_clean = "".join(segment["clean_char"])
-        tokens = [model_dictionary[c] for c in text_clean]
+        text_clean = "".join(segment_data[sdx]["clean_char"])
+        tokens = [model_dictionary.get(c, -1) for c in text_clean]

        f1 = int(t1 * SAMPLE_RATE)
        f2 = int(t2 * SAMPLE_RATE)
@@ -244,7 +265,8 @@ def align(
            blank_id = code

        trellis = get_trellis(emission, tokens, blank_id)
-        path = backtrack(trellis, emission, tokens, blank_id)
+        # path = backtrack(trellis, emission, tokens, blank_id)
+        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)

        if path is None:
            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
@@ -253,7 +275,7 @@ def align(

        char_segments = merge_repeats(path, text_clean)

-        duration = t2 -t1
+        duration = t2 - t1
        ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

        # assign timestamps to aligned characters
@@ -261,8 +283,8 @@ def align(
        word_idx = 0
        for cdx, char in enumerate(text):
            start, end, score = None, None, None
-            if cdx in segment["clean_cdx"]:
-                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
+            if cdx in segment_data[sdx]["clean_cdx"]:
+                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                start = round(char_seg.start * ratio + t1, 3)
                end = round(char_seg.end * ratio + t1, 3)
                score = round(char_seg.score, 3)
@@ -288,10 +310,10 @@ def align(
        aligned_subsegments = []
        # assign sentence_idx to each character index
        char_segments_arr["sentence-idx"] = None
-        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
            curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
-            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2

            sentence_text = text[sstart:send]
            sentence_start = curr_chars["start"].min()
            end_chars = curr_chars[curr_chars["char"] != ' ']
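The per-segment scratch values (`clean_char`, `clean_cdx`, `clean_wdx`, `sentence_spans`) are no longer written back onto the input segments but collected in a separate `segment_data` dict keyed by segment index. The `SegmentData` type is imported from `whisperx.types`, which is not part of this diff; a hypothetical sketch of its shape, inferred only from the keys used in the hunks above, could look like this:

```python
from typing import List, Tuple, TypedDict


class SegmentData(TypedDict):
    """Hypothetical reconstruction of whisperx.types.SegmentData; the real definition may differ."""
    clean_char: List[str]                  # characters kept for alignment (or '*' placeholders)
    clean_cdx: List[int]                   # indices of those characters in the original segment text
    clean_wdx: List[int]                   # indices of words with at least one alignable character
    sentence_spans: List[Tuple[int, int]]  # (start, end) character spans from the Punkt sentence splitter
```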
@ -360,70 +382,203 @@ def align(
"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra diemsions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame):
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
            # trellis[t, :-1] + emission[t, tokens[1:]],
            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
        )
    return trellis


def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Processing token emission scores containing wildcards (vectorized version)

    Args:
        frame_emission: Emission probability vector for the current frame
        tokens: List of token indices
        blank_id: ID of the blank token

    Returns:
        tensor: Maximum probability score for each token position
    """
    assert 0 <= blank_id < len(frame_emission)

    # Convert tokens to a tensor if they are not already
    tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens

    # Create a mask to identify wildcard positions
    wildcard_mask = (tokens == -1)

    # Get scores for non-wildcard positions
    regular_scores = frame_emission[tokens.clamp(min=0)]  # clamp to avoid -1 index

    # Create a mask and compute the maximum value without modifying frame_emission
    max_valid_score = frame_emission.clone()   # Create a copy
    max_valid_score[blank_id] = float('-inf')  # Modify the copy to exclude the blank token
    max_valid_score = max_valid_score.max()

    # Use where operation to combine results
    result = torch.where(wildcard_mask, max_valid_score, regular_scores)

    return result
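A minimal sketch (not part of the diff) of the wildcard behaviour added above, assuming get_wildcard_emission is importable from whisperx.alignment: a token index of -1 acts as a wildcard and receives the best non-blank score of the frame, while ordinary indices keep their own score.

import torch
from whisperx.alignment import get_wildcard_emission  # assumes a module-level definition, as in the hunk above

# One frame of emission scores; index 0 is the blank token.
frame = torch.tensor([0.9, 0.1, 0.5, 0.3])

# Token 2 keeps its own score (0.5); the wildcard -1 takes the best
# non-blank score of the frame, which is also 0.5 here.
print(get_wildcard_emission(frame, [2, -1], blank_id=0))  # tensor([0.5000, 0.5000])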
@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
        # Score for token staying the same from time frame J-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from C-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # failed
        return None

    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        # p_change = emission[t - 1, tokens[j]]
        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SoS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]
@dataclass
class Path:
    points: List[Point]
    score: float


@dataclass
class BeamState:
    """State in beam search."""
    token_index: int   # Current token position
    time_index: int    # Current time step
    score: float       # Cumulative score
    path: List[Point]  # Path history


def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Standard CTC beam search backtracking implementation.

    Args:
        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
                                and N is the number of tokens (including the blank token).
        emission (torch.Tensor): The emission probabilities of shape (T, N).
        tokens (List[int]): List of token indices (excluding the blank token).
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.

    Returns:
        List[Point]: the best path
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    init_state = BeamState(
        token_index=J,
        time_index=T,
        score=trellis[T, J],
        path=[Point(J, T, emission[T, blank_id].exp().item())]
    )

    beams = [init_state]

    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            if t <= 0:
                continue

            p_stay = emission[t - 1, blank_id]
            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

            stay_score = trellis[t - 1, j]
            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

            # Stay
            if not math.isinf(stay_score):
                new_path = beam.path.copy()
                new_path.append(Point(j, t - 1, p_stay.exp().item()))
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=new_path
                ))

            # Change
            if j > 0 and not math.isinf(change_score):
                new_path = beam.path.copy()
                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
                next_beams.append(BeamState(
                    token_index=j - 1,
                    time_index=t - 1,
                    score=change_score,
                    path=new_path
                ))

        # sort by score
        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

        if not beams:
            break

    if not beams:
        return None

    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    return best_beam.path[::-1]
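A minimal sketch (not part of the diff) tying the pieces above together on synthetic log-probabilities; the emission values and token ids are made up, and the functions are assumed to be importable from whisperx.alignment.

import torch
from whisperx.alignment import get_trellis, backtrack, backtrack_beam

torch.manual_seed(0)
# 20 synthetic frames over a 5-symbol vocabulary; index 0 is the blank token.
emission = torch.randn(20, 5).log_softmax(dim=-1)
tokens = [0, 1, 3, 2]  # leading blank slot followed by the target token ids

trellis = get_trellis(emission, tokens, blank_id=0)
greedy_path = backtrack(trellis, emission, tokens, blank_id=0)
beam_path = backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5)
# Both are lists of Point(token_index, time_index, score) in time order.
print(len(greedy_path), len(beam_path))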
# Merge the labels
@dataclass
class Segment:
172 whisperx/asr.py
@ -1,17 +1,20 @@
import os
import warnings
from typing import List, Optional, Union
from typing import List, Union, Optional, NamedTuple
from dataclasses import replace

import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator

from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from whisperx.types import SingleSegment, TranscriptionResult
from .types import TranscriptionResult, SingleSegment
from whisperx.vads import Vad, Silero, Pyannote


def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = []
@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
    '''

    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
    def generate_segment_batched(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
        encoder_output=None,
    ):
        batch_size = features.shape[0]
        all_tokens = []
        prompt_reset_since = 0
@ -42,6 +51,7 @@ class WhisperModel(faster_whisper.WhisperModel):
            previous_tokens,
            without_timestamps=options.without_timestamps,
            prefix=options.prefix,
            hotwords=options.hotwords
        )

        encoder_output = self.encode(features)
@ -81,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):
        # unsqueeze if batch size = 1
        if len(features.shape) == 2:
            features = np.expand_dims(features, 0)
        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
        features = get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)

@ -94,17 +104,17 @@ class FasterWhisperPipeline(Pipeline):
    # - add support for custom inference kwargs

    def __init__(
        self,
        model,
        model: WhisperModel,
        vad,
        vad_params: dict,
        options : NamedTuple,
        options: TranscriptionOptions,
        tokenizer=None,
        tokenizer: Optional[Tokenizer] = None,
        device: Union[int, str, "torch.device"] = -1,
        framework = "pt",
        framework="pt",
        language : Optional[str] = None,
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs
        **kwargs,
    ):
        self.model = model
        self.tokenizer = tokenizer
@ -156,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
        return model_outputs

    def get_iterator(
        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
        self,
        inputs,
        num_workers: int,
        batch_size: int,
        preprocess_params: dict,
        forward_params: dict,
        postprocess_params: dict,
    ):
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
        if "TOKENIZERS_PARALLELISM" not in os.environ:
@ -171,7 +187,16 @@ class FasterWhisperPipeline(Pipeline):
        return final_iterator

    def transcribe(
        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, verbose=False
        self,
        audio: Union[str, np.ndarray],
        batch_size: Optional[int] = None,
        num_workers=0,
        language: Optional[str] = None,
        task: Optional[str] = None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
        verbose=False,
    ) -> TranscriptionResult:
        if isinstance(audio, str):
            audio = load_audio(audio)
@ -183,7 +208,16 @@ class FasterWhisperPipeline(Pipeline):
            # print(f2-f1)
            yield {'inputs': audio[f1:f2]}

        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
        # Pre-process audio and merge chunks as defined by the respective VAD child class
        # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
        if issubclass(type(self.vad_model), Vad):
            waveform = self.vad_model.preprocess_audio(audio)
            merge_chunks = self.vad_model.merge_chunks
        else:
            waveform = Pyannote.preprocess_audio(audio)
            merge_chunks = Pyannote.merge_chunks

        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(
            vad_segments,
            chunk_size,
@ -193,24 +227,30 @@ class FasterWhisperPipeline(Pipeline):
        if self.tokenizer is None:
            language = language or self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                self.model.model.is_multilingual, task=task,
                                                                language=language)
            self.tokenizer = Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                    self.model.model.is_multilingual, task=task,
                                                                    language=language)
                self.tokenizer = Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
            print(f"Suppressing numeral and symbol tokens")
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
@ -239,12 +279,11 @@ class FasterWhisperPipeline(Pipeline):

        # revert suppressed tokens if suppress_numerals is enabled
        if self.suppress_numerals:
            self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
            self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)

        return {"segments": segments, "language": language}

    def detect_language(self, audio: np.ndarray):
    def detect_language(self, audio: np.ndarray) -> str:
        if audio.shape[0] < N_SAMPLES:
            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
@ -258,33 +297,38 @@ class FasterWhisperPipeline(Pipeline):
        print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
        return language


def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language : Optional[str] = None,
               vad_model=None,
               vad_options=None,
               model : Optional[WhisperModel] = None,
               task="transcribe",
               download_root=None,
               local_files_only=False,
               threads=4):
    '''Load a Whisper model for inference.
def load_model(
    whisper_arch: str,
    device: str,
    device_index=0,
    compute_type="float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model: Optional[Vad] = None,
    vad_method: Optional[str] = "pyannote",
    vad_options: Optional[dict] = None,
    model: Optional[WhisperModel] = None,
    task="transcribe",
    download_root: Optional[str] = None,
    local_files_only=False,
    threads=4,
) -> FasterWhisperPipeline:
    """Load a Whisper model for inference.
    Args:
        whisper_arch: str - The name of the Whisper model to load.
        device: str - The device to load the model on.
        compute_type: str - The compute type to use for the model.
        options: dict - A dictionary of options to use for the model.
        language: str - The language of the model. (use English for now)
        model: Optional[WhisperModel] - The WhisperModel instance to use.
        download_root: Optional[str] - The root directory to download the model to.
        local_files_only: bool - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
        whisper_arch - The name of the Whisper model to load.
        device - The device to load the model on.
        compute_type - The compute type to use for the model.
        vad_method - The vad method to use. vad_model has higher priority if is not None.
        options - A dictionary of options to use for the model.
        language - The language of the model. (use English for now)
        model - The WhisperModel instance to use.
        download_root - The root directory to download the model to.
        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
    Returns:
        A Whisper pipeline.
    '''
    """

    if whisper_arch.endswith(".en"):
        language = "en"
@ -297,7 +341,7 @@ def load_model(whisper_arch,
                              local_files_only=local_files_only,
                              cpu_threads=threads)
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        print("No language specified, language will be first be detected for each audio file (increases inference time).")
        tokenizer = None
@ -338,9 +382,10 @@ def load_model(whisper_arch,
        suppress_numerals = default_asr_options["suppress_numerals"]
        del default_asr_options["suppress_numerals"]

    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
    default_asr_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "chunk_size": 30, # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }
@ -348,10 +393,17 @@ def load_model(whisper_arch,
    if vad_options is not None:
        default_vad_options.update(vad_options)

    # Note: manually assigned vad_model has higher priority than vad_method!
    if vad_model is not None:
        print("Use manually assigned vad_model. vad_method is ignored.")
        vad_model = vad_model
    else:
        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
        if vad_method == "silero":
            vad_model = Silero(**default_vad_options)
        elif vad_method == "pyannote":
            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
        else:
            raise ValueError(f"Invalid vad_method: {vad_method}")

    return FasterWhisperPipeline(
        model=model,
@ -361,4 +413,4 @@ def load_model(whisper_arch,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
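A minimal usage sketch (not part of the diff) of the new vad_method and local_files_only behaviour, assuming the top-level whisperx package re-exports load_model and load_audio; the model size, device, and audio path are placeholders.

import whisperx

pipeline = whisperx.load_model(
    "small",
    device="cpu",
    compute_type="int8",     # the float16 default is not supported on CPU
    vad_method="silero",     # or "pyannote" (the default); anything else raises ValueError
    vad_options={"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363},
    language="en",
)
audio = whisperx.load_audio("audio.wav")  # placeholder path
result = pipeline.transcribe(audio, batch_size=4, print_progress=True)
print(result["language"], len(result["segments"]))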
@ -7,7 +7,7 @@ import numpy as np
import torch
import torch.nn.functional as F

from .utils import exact_div
from whisperx.utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file and read as mono waveform, resampling as necessary
@ -1,5 +1,8 @@
# conjunctions.py

from typing import Set


conjunctions_by_language = {
    'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
    'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@ -36,8 +39,9 @@ commas_by_language = {
    'ur': '،'
}

def get_conjunctions(lang_code):
def get_conjunctions(lang_code: str) -> Set[str]:
    return conjunctions_by_language.get(lang_code, set())

def get_comma(lang_code):
    return commas_by_language.get(lang_code, ',')
def get_comma(lang_code: str) -> str:
    return commas_by_language.get(lang_code, ",")
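A small sanity-check sketch (not part of the diff) for the typed helpers above, assuming the module lives at whisperx.conjunctions as the header comment suggests.

from whisperx.conjunctions import get_conjunctions, get_comma

assert "because" in get_conjunctions("en")
assert get_comma("ur") == "،"          # Urdu comma, as listed in commas_by_language
assert get_comma("xx") == ","          # unknown languages fall back to a plain comma
assert get_conjunctions("xx") == set()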
@ -4,21 +4,29 @@ from pyannote.audio import Pipeline
from typing import Optional, Union
import torch

from .audio import load_audio, SAMPLE_RATE
from whisperx.audio import load_audio, SAMPLE_RATE
from whisperx.types import TranscriptionResult, AlignedTranscriptionResult


class DiarizationPipeline:
    def __init__(
        self,
        model_name="pyannote/speaker-diarization-3.1",
        model_name=None,
        use_auth_token=None,
        device: Optional[Union[str, torch.device]] = "cpu",
    ):
        if isinstance(device, str):
            device = torch.device(device)
        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
        model_config = model_name or "pyannote/speaker-diarization-3.1"
        self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)

    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
    def __call__(
        self,
        audio: Union[str, np.ndarray],
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio_data = {
@ -32,7 +40,11 @@ class DiarizationPipeline:
        return diarize_df


def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
def assign_word_speakers(
    diarize_df: pd.DataFrame,
    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
    fill_nearest=False,
) -> dict:
    transcript_segments = transcript_result["segments"]
    for seg in transcript_segments:
        # assign speaker to segment (if any)
@ -68,7 +80,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):


class Segment:
    def __init__(self, start, end, speaker=None):
    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
        self.start = start
        self.end = end
        self.speaker = speaker
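A minimal sketch (not part of the diff) of the updated pipeline: model_name=None now falls back to pyannote/speaker-diarization-3.1. The audio path is a placeholder, an HF_TOKEN environment variable is assumed for the gated pyannote model, and `result` stands for the output of the transcription and alignment steps shown elsewhere in this diff.

import os
import whisperx
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

audio = whisperx.load_audio("audio.wav")   # placeholder path
result = ...                               # aligned transcription from the previous steps

diarize_model = DiarizationPipeline(use_auth_token=os.environ["HF_TOKEN"], device="cpu")
diarize_df = diarize_model(audio, min_speakers=1, max_speakers=2)
result = assign_word_speakers(diarize_df, result)   # adds "speaker" keys to the segments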
@ -6,82 +6,27 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .alignment import align, load_align_model
|
from whisperx.alignment import align, load_align_model
|
||||||
from .asr import load_model
|
from whisperx.asr import load_model
|
||||||
from .audio import load_audio
|
from whisperx.audio import load_audio
|
||||||
from .diarize import DiarizationPipeline, assign_word_speakers
|
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||||
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
|
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
|
||||||
optional_int, str2bool)
|
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||||
|
"""Transcription task to be called from CLI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Dictionary of command-line arguments.
|
||||||
|
parser: argparse.ArgumentParser object.
|
||||||
|
"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
|
||||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
|
||||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
|
||||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
|
||||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
|
||||||
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
|
|
||||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
|
||||||
|
|
||||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
|
||||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
|
||||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
|
||||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
|
||||||
|
|
||||||
# alignment params
|
|
||||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
|
||||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
|
||||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
|
||||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
|
||||||
|
|
||||||
# vad params
|
|
||||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
|
||||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
|
||||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
|
||||||
|
|
||||||
# diarization params
|
|
||||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
|
||||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
|
||||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
|
||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
||||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
|
||||||
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
|
||||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
|
||||||
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
|
||||||
|
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
|
||||||
|
|
||||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
|
||||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
|
||||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
|
||||||
|
|
||||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
|
||||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
|
|
||||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
|
|
||||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
|
||||||
|
|
||||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
|
||||||
|
|
||||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
|
||||||
|
|
||||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
|
||||||
model_name: str = args.pop("model")
|
model_name: str = args.pop("model")
|
||||||
batch_size: int = args.pop("batch_size")
|
batch_size: int = args.pop("batch_size")
|
||||||
model_dir: str = args.pop("model_dir")
|
model_dir: str = args.pop("model_dir")
|
||||||
|
model_cache_only: bool = args.pop("model_cache_only")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
device: str = args.pop("device")
|
device: str = args.pop("device")
|
||||||
@ -95,7 +40,7 @@ def cli():
|
|||||||
align_model: str = args.pop("align_model")
|
align_model: str = args.pop("align_model")
|
||||||
interpolate_method: str = args.pop("interpolate_method")
|
interpolate_method: str = args.pop("interpolate_method")
|
||||||
no_align: bool = args.pop("no_align")
|
no_align: bool = args.pop("no_align")
|
||||||
task : str = args.pop("task")
|
task: str = args.pop("task")
|
||||||
if task == "translate":
|
if task == "translate":
|
||||||
# translation cannot be aligned
|
# translation cannot be aligned
|
||||||
no_align = True
|
no_align = True
|
||||||
@ -103,6 +48,7 @@ def cli():
|
|||||||
return_char_alignments: bool = args.pop("return_char_alignments")
|
return_char_alignments: bool = args.pop("return_char_alignments")
|
||||||
|
|
||||||
hf_token: str = args.pop("hf_token")
|
hf_token: str = args.pop("hf_token")
|
||||||
|
vad_method: str = args.pop("vad_method")
|
||||||
vad_onset: float = args.pop("vad_onset")
|
vad_onset: float = args.pop("vad_onset")
|
||||||
vad_offset: float = args.pop("vad_offset")
|
vad_offset: float = args.pop("vad_offset")
|
||||||
|
|
||||||
@ -111,6 +57,7 @@ def cli():
|
|||||||
diarize: bool = args.pop("diarize")
|
diarize: bool = args.pop("diarize")
|
||||||
min_speakers: int = args.pop("min_speakers")
|
min_speakers: int = args.pop("min_speakers")
|
||||||
max_speakers: int = args.pop("max_speakers")
|
max_speakers: int = args.pop("max_speakers")
|
||||||
|
diarize_model_name: str = args.pop("diarize_model")
|
||||||
print_progress: bool = args.pop("print_progress")
|
print_progress: bool = args.pop("print_progress")
|
||||||
|
|
||||||
if args["language"] is not None:
|
if args["language"] is not None:
|
||||||
@ -127,7 +74,9 @@ def cli():
|
|||||||
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
|
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
|
||||||
)
|
)
|
||||||
args["language"] = "en"
|
args["language"] = "en"
|
||||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
align_language = (
|
||||||
|
args["language"] if args["language"] is not None else "en"
|
||||||
|
) # default to loading english if not specified
|
||||||
|
|
||||||
temperature = args.pop("temperature")
|
temperature = args.pop("temperature")
|
||||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||||
@ -163,18 +112,41 @@ def cli():
|
|||||||
if args["max_line_count"] and not args["max_line_width"]:
|
if args["max_line_count"] and not args["max_line_width"]:
|
||||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||||
|
|
||||||
# Part 1: VAD & ASR Loop
|
# Part 1: VAD & ASR Loop
|
||||||
results = []
|
results = []
|
||||||
tmp_results = []
|
tmp_results = []
|
||||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
|
model = load_model(
|
||||||
|
model_name,
|
||||||
|
device=device,
|
||||||
|
device_index=device_index,
|
||||||
|
download_root=model_dir,
|
||||||
|
compute_type=compute_type,
|
||||||
|
language=args["language"],
|
||||||
|
asr_options=asr_options,
|
||||||
|
vad_method=vad_method,
|
||||||
|
vad_options={
|
||||||
|
"chunk_size": chunk_size,
|
||||||
|
"vad_onset": vad_onset,
|
||||||
|
"vad_offset": vad_offset,
|
||||||
|
},
|
||||||
|
task=task,
|
||||||
|
local_files_only=model_cache_only,
|
||||||
|
threads=faster_whisper_threads,
|
||||||
|
)
|
||||||
|
|
||||||
for audio_path in args.pop("audio"):
|
for audio_path in args.pop("audio"):
|
||||||
audio = load_audio(audio_path)
|
audio = load_audio(audio_path)
|
||||||
# >> VAD & ASR
|
# >> VAD & ASR
|
||||||
print(">>Performing transcription...")
|
print(">>Performing transcription...")
|
||||||
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress, verbose=verbose)
|
result: TranscriptionResult = model.transcribe(
|
||||||
|
audio,
|
||||||
|
batch_size=batch_size,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
print_progress=print_progress,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
results.append((result, audio_path))
|
results.append((result, audio_path))
|
||||||
|
|
||||||
# Unload Whisper and VAD
|
# Unload Whisper and VAD
|
||||||
@ -186,7 +158,9 @@ def cli():
|
|||||||
if not no_align:
|
if not no_align:
|
||||||
tmp_results = results
|
tmp_results = results
|
||||||
results = []
|
results = []
|
||||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
align_model, align_metadata = load_align_model(
|
||||||
|
align_language, device, model_name=align_model
|
||||||
|
)
|
||||||
for result, audio_path in tmp_results:
|
for result, audio_path in tmp_results:
|
||||||
# >> Align
|
# >> Align
|
||||||
if len(tmp_results) > 1:
|
if len(tmp_results) > 1:
|
||||||
@ -198,10 +172,23 @@ def cli():
|
|||||||
if align_model is not None and len(result["segments"]) > 0:
|
if align_model is not None and len(result["segments"]) > 0:
|
||||||
if result.get("language", "en") != align_metadata["language"]:
|
if result.get("language", "en") != align_metadata["language"]:
|
||||||
# load new language
|
# load new language
|
||||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
print(
|
||||||
align_model, align_metadata = load_align_model(result["language"], device)
|
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
||||||
|
)
|
||||||
|
align_model, align_metadata = load_align_model(
|
||||||
|
result["language"], device
|
||||||
|
)
|
||||||
print(">>Performing alignment...")
|
print(">>Performing alignment...")
|
||||||
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
|
result: AlignedTranscriptionResult = align(
|
||||||
|
result["segments"],
|
||||||
|
align_model,
|
||||||
|
align_metadata,
|
||||||
|
input_audio,
|
||||||
|
device,
|
||||||
|
interpolate_method=interpolate_method,
|
||||||
|
return_char_alignments=return_char_alignments,
|
||||||
|
print_progress=print_progress,
|
||||||
|
)
|
||||||
|
|
||||||
results.append((result, audio_path))
|
results.append((result, audio_path))
|
||||||
|
|
||||||
@ -213,19 +200,21 @@ def cli():
|
|||||||
# >> Diarize
|
# >> Diarize
|
||||||
if diarize:
|
if diarize:
|
||||||
if hf_token is None:
|
if hf_token is None:
|
||||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
print(
|
||||||
|
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
|
||||||
|
)
|
||||||
tmp_results = results
|
tmp_results = results
|
||||||
print(">>Performing diarization...")
|
print(">>Performing diarization...")
|
||||||
|
print(">>Using model:", diarize_model_name)
|
||||||
results = []
|
results = []
|
||||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||||
for result, input_audio_path in tmp_results:
|
for result, input_audio_path in tmp_results:
|
||||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
diarize_segments = diarize_model(
|
||||||
|
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
||||||
|
)
|
||||||
result = assign_word_speakers(diarize_segments, result)
|
result = assign_word_speakers(diarize_segments, result)
|
||||||
results.append((result, input_audio_path))
|
results.append((result, input_audio_path))
|
||||||
# >> Write
|
# >> Write
|
||||||
for result, audio_path in results:
|
for result, audio_path in results:
|
||||||
result["language"] = align_language
|
result["language"] = align_language
|
||||||
writer(result, audio_path, writer_args)
|
writer(result, audio_path, writer_args)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
cli()
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import TypedDict, Optional, List
|
from typing import TypedDict, Optional, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
class SingleWordSegment(TypedDict):
|
class SingleWordSegment(TypedDict):
|
||||||
@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
|
|||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentData(TypedDict):
|
||||||
|
"""
|
||||||
|
Temporary processing data used during alignment.
|
||||||
|
Contains cleaned and preprocessed data for each segment.
|
||||||
|
"""
|
||||||
|
clean_char: List[str] # Cleaned characters that exist in model dictionary
|
||||||
|
clean_cdx: List[int] # Original indices of cleaned characters
|
||||||
|
clean_wdx: List[int] # Indices of words containing valid characters
|
||||||
|
sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences
|
||||||
|
|
||||||
|
|
||||||
class SingleAlignedSegment(TypedDict):
|
class SingleAlignedSegment(TypedDict):
|
||||||
"""
|
"""
|
||||||
A single segment (up to multiple sentences) of a speech with word alignment.
|
A single segment (up to multiple sentences) of a speech with word alignment.
|
||||||
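An illustrative value (not part of the diff) for the new SegmentData type; the contents are made up and only show the expected shape.

from whisperx.types import SegmentData

example: SegmentData = {
    "clean_char": ["h", "e", "l", "l", "o"],
    "clean_cdx": [0, 1, 2, 3, 4],
    "clean_wdx": [0],
    "sentence_spans": [(0, 5)],
}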
@ -214,7 +214,12 @@ class WriteTXT(ResultWriter):

    def write_result(self, result: dict, file: TextIO, options: dict):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)
            speaker = segment.get("speaker")
            text = segment["text"].strip()
            if speaker is not None:
                print(f"[{speaker}]: {text}", file=file, flush=True)
            else:
                print(text, file=file, flush=True)


class SubtitlesWriter(ResultWriter):
@ -236,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
        line_count = 1
        # the next subtitle to yield (a list of word timings with whitespace)
        subtitle: list[dict] = []
        times = []
        times: list[tuple] = []
        last = result["segments"][0]["start"]
        for segment in result["segments"]:
            for i, original_timing in enumerate(segment["words"]):
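A minimal sketch (not part of the diff) of the effect of the WriteTXT change, using the get_writer factory from the same module; the result dict is made up, and the writer call mirrors the writer(result, audio_path, writer_args) call in transcribe.py shown earlier in this diff.

from whisperx.utils import get_writer

result = {
    "segments": [
        {"start": 0.0, "end": 1.5, "text": "Hello there.", "speaker": "SPEAKER_00"},
        {"start": 1.5, "end": 2.0, "text": "Hi."},
    ],
    "language": "en",
}
writer = get_writer("txt", ".")
writer(result, "audio.wav", {"max_line_width": None, "max_line_count": None, "highlight_words": False})
# audio.txt now contains:
# [SPEAKER_00]: Hello there.
# Hi.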
3 whisperx/vads/__init__.py Normal file
@ -0,0 +1,3 @@
from whisperx.vads.pyannote import Pyannote as Pyannote
from whisperx.vads.silero import Silero as Silero
from whisperx.vads.vad import Vad as Vad
|
@ -1,51 +1,44 @@
|
|||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import urllib
|
from typing import Callable, Text, Union
|
||||||
from typing import Callable, Optional, Text, Union
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
from pyannote.audio import Model
|
from pyannote.audio import Model
|
||||||
from pyannote.audio.core.io import AudioFile
|
from pyannote.audio.core.io import AudioFile
|
||||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||||
from pyannote.audio.pipelines.utils import PipelineModel
|
from pyannote.audio.pipelines.utils import PipelineModel
|
||||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
from pyannote.core import Annotation, SlidingWindowFeature
|
||||||
from tqdm import tqdm
|
from pyannote.core import Segment
|
||||||
|
|
||||||
from .diarize import Segment as SegmentX
|
from whisperx.diarize import Segment as SegmentX
|
||||||
|
from whisperx.vads.vad import Vad
|
||||||
|
|
||||||
# deprecated
|
|
||||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
|
||||||
|
|
||||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||||
model_dir = torch.hub._get_torch_home()
|
model_dir = torch.hub._get_torch_home()
|
||||||
|
|
||||||
vad_dir = os.path.dirname(os.path.abspath(__file__))
|
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok = True)
|
os.makedirs(model_dir, exist_ok = True)
|
||||||
if model_fp is None:
|
if model_fp is None:
|
||||||
# Dynamically resolve the path to the model file
|
# Dynamically resolve the path to the model file
|
||||||
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
|
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
|
||||||
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||||
else:
|
else:
|
||||||
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||||
|
|
||||||
# Check if the resolved model file exists
|
# Check if the resolved model file exists
|
||||||
if not os.path.exists(model_fp):
|
if not os.path.exists(model_fp):
|
||||||
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||||
|
|
||||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||||
|
|
||||||
model_bytes = open(model_fp, "rb").read()
|
model_bytes = open(model_fp, "rb").read()
|
||||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
|
|
||||||
)
|
|
||||||
|
|
||||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||||
hyperparameters = {"onset": vad_onset,
|
hyperparameters = {"onset": vad_onset,
|
||||||
"offset": vad_offset,
|
"offset": vad_offset,
|
||||||
"min_duration_on": 0.1,
|
"min_duration_on": 0.1,
|
||||||
"min_duration_off": 0.1}
|
"min_duration_off": 0.1}
|
||||||
@ -81,21 +74,21 @@ class Binarize:
|
|||||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||||
|
|
||||||
Modified by Max Bain to include WhisperX's min-cut operation
|
Modified by Max Bain to include WhisperX's min-cut operation
|
||||||
https://arxiv.org/abs/2303.00747
|
https://arxiv.org/abs/2303.00747
|
||||||
|
|
||||||
Pyannote-audio
|
Pyannote-audio
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
onset: float = 0.5,
|
onset: float = 0.5,
|
||||||
offset: Optional[float] = None,
|
offset: Optional[float] = None,
|
||||||
min_duration_on: float = 0.0,
|
min_duration_on: float = 0.0,
|
||||||
min_duration_off: float = 0.0,
|
min_duration_off: float = 0.0,
|
||||||
pad_onset: float = 0.0,
|
pad_onset: float = 0.0,
|
||||||
pad_offset: float = 0.0,
|
pad_offset: float = 0.0,
|
||||||
max_duration: float = float('inf')
|
max_duration: float = float('inf')
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -141,7 +134,7 @@ class Binarize:
|
|||||||
t = start
|
t = start
|
||||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||||
# currently active
|
# currently active
|
||||||
if is_active:
|
if is_active:
|
||||||
curr_duration = t - start
|
curr_duration = t - start
|
||||||
if curr_duration > self.max_duration:
|
if curr_duration > self.max_duration:
|
||||||
search_after = len(curr_scores) // 2
|
search_after = len(curr_scores) // 2
|
||||||
@ -151,8 +144,8 @@ class Binarize:
|
|||||||
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
||||||
active[region, k] = label
|
active[region, k] = label
|
||||||
start = curr_timestamps[min_score_div_idx]
|
start = curr_timestamps[min_score_div_idx]
|
||||||
curr_scores = curr_scores[min_score_div_idx+1:]
|
curr_scores = curr_scores[min_score_div_idx + 1:]
|
||||||
curr_timestamps = curr_timestamps[min_score_div_idx+1:]
|
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
|
||||||
# switching from active to inactive
|
# switching from active to inactive
|
||||||
elif y < self.offset:
|
elif y < self.offset:
|
||||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||||
@ -193,11 +186,11 @@ class Binarize:

class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):

        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
@ -236,72 +229,35 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
        return segmentations


-def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
-
-    active = Annotation()
-    for k, vad_t in enumerate(vad_arr):
-        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
-        active[region, k] = 1
-
-    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
-        active = active.support(collar=min_duration_off)
-
-    # remove tracks shorter than min_duration_on
-    if min_duration_on > 0:
-        for segment, track in list(active.itertracks()):
-            if segment.duration < min_duration_on:
-                del active[segment, track]
-
-    active = active.for_json()
-    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
-    return active_segs
-
-def merge_chunks(
-    segments,
-    chunk_size,
-    onset: float = 0.5,
-    offset: Optional[float] = None,
-):
-    """
-    Merge operation described in paper
-    """
-    curr_end = 0
-    merged_segments = []
-    seg_idxs = []
-    speaker_idxs = []
-
-    assert chunk_size > 0
-    binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
-    segments = binarize(segments)
-    segments_list = []
-    for speech_turn in segments.get_timeline():
-        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
-
-    if len(segments_list) == 0:
-        print("No active speech found in audio")
-        return []
-    # assert segments_list, "segments_list is empty."
-    # Make sur the starting point is the start of the segment.
-    curr_start = segments_list[0].start
-
-    for seg in segments_list:
-        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
-            merged_segments.append({
-                "start": curr_start,
-                "end": curr_end,
-                "segments": seg_idxs,
-            })
-            curr_start = seg.start
-            seg_idxs = []
-            speaker_idxs = []
-        curr_end = seg.end
-        seg_idxs.append((seg.start, seg.end))
-        speaker_idxs.append(seg.speaker)
-    # add final
-    merged_segments.append({
-        "start": curr_start,
-        "end": curr_end,
-        "segments": seg_idxs,
-    })
-    return merged_segments
+class Pyannote(Vad):
+
+    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
+        print(">>Performing voice activity detection using Pyannote...")
+        super().__init__(kwargs['vad_onset'])
+        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
+
+    def __call__(self, audio: AudioFile, **kwargs):
+        return self.vad_pipeline(audio)
+
+    @staticmethod
+    def preprocess_audio(audio):
+        return torch.from_numpy(audio).unsqueeze(0)
+
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float = 0.5,
+                     offset: Optional[float] = None,
+                     ):
+        assert chunk_size > 0
+        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
+        segments = binarize(segments)
+        segments_list = []
+        for speech_turn in segments.get_timeline():
+            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+        if len(segments_list) == 0:
+            print("No active speech found in audio")
+            return []
+        assert segments_list, "segments_list is empty."
+        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
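A hedged usage sketch for the new Pyannote wrapper introduced in the hunk above. The constructor keywords (device, use_auth_token, model_fp, vad_onset) come from the diff; the import path whisperx.vads.pyannote is an assumption based on the new whisperx/vads package shown below, and the audio dict shape follows pyannote's in-memory AudioFile convention.

import torch
from whisperx.vads.pyannote import Pyannote  # assumed module path for the class above

device = "cuda" if torch.cuda.is_available() else "cpu"
vad = Pyannote(device, use_auth_token=None, model_fp=None, vad_onset=0.5)

# __call__ forwards straight to the pyannote pipeline, so audio is passed in
# pyannote's in-memory form: a (channel, samples) waveform plus sample rate.
waveform = torch.zeros(1, 16000)  # 1 s of silence as stand-in audio
segments = vad({"waveform": waveform, "sample_rate": 16000})
chunks = Pyannote.merge_chunks(segments, chunk_size=30)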
66
whisperx/vads/silero.py
Normal file
@ -0,0 +1,66 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union

import torch

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

AudioFile = Union[Text, Path, IOBase, Mapping]


class Silero(Vad):
    # check again default values
    def __init__(self, **kwargs):
        print(">>Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']
        self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      trust_repo=True)
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def __call__(self, audio: AudioFile, **kwargs):
        """use silero to get segments of speech"""
        # Only accept 16000 Hz for now.
        # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
        # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(audio["waveform"],
                                                model=self.vad_pipeline,
                                                sampling_rate=sample_rate,
                                                max_speech_duration_s=self.chunk_size,
                                                threshold=self.vad_onset
                                                # min_silence_duration_ms = self.min_duration_off/1000
                                                # min_speech_duration_ms = self.min_duration_on/1000
                                                # ...
                                                # See silero documentation for full option list
                                                )
        return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]

    @staticmethod
    def preprocess_audio(audio):
        return audio

    @staticmethod
    def merge_chunks(segments_list,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        assert chunk_size > 0
        if len(segments_list) == 0:
            print("No active speech found in audio")
            return []
        assert segments_list, "segments_list is empty."
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
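A hedged usage sketch for the Silero wrapper added above. The required keywords (vad_onset, chunk_size) and the expected audio dict come straight from the class; note that torch.hub downloads the silero-vad model on first construction.

import torch
from whisperx.vads.silero import Silero  # path of the new file above

vad = Silero(vad_onset=0.5, chunk_size=30)

waveform = torch.zeros(16000)  # 1 s of 16 kHz silence as stand-in audio
segments = vad({"waveform": waveform, "sample_rate": 16000})
print(segments)  # list of SegmentX(start, end, "UNKNOWN"); empty for pure silence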
74
whisperx/vads/vad.py
Normal file
@ -0,0 +1,74 @@
from typing import Optional

import pandas as pd
from pyannote.core import Annotation, Segment


class Vad:
    def __init__(self, vad_onset):
        if not (0 < vad_onset < 1):
            raise ValueError(
                "vad_onset is a decimal value between 0 and 1."
            )

    @staticmethod
    def preprocess_audio(audio):
        pass

    # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float,
                     offset: Optional[float]):
        """
        Merge operation described in paper
        """
        curr_end = 0
        merged_segments = []
        seg_idxs: list[tuple] = []
        speaker_idxs: list[Optional[str]] = []

        curr_start = segments[0].start
        for seg in segments:
            if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
                merged_segments.append({
                    "start": curr_start,
                    "end": curr_end,
                    "segments": seg_idxs,
                })
                curr_start = seg.start
                seg_idxs = []
                speaker_idxs = []
            curr_end = seg.end
            seg_idxs.append((seg.start, seg.end))
            speaker_idxs.append(seg.speaker)
        # add final
        merged_segments.append({
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
        })

        return merged_segments

    # Unused function
    @staticmethod
    def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
        active = Annotation()
        for k, vad_t in enumerate(vad_arr):
            region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
            active[region, k] = 1

        if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
            active = active.support(collar=min_duration_off)

        # remove tracks shorter than min_duration_on
        if min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < min_duration_on:
                    del active[segment, track]

        active = active.for_json()
        active_segs = pd.DataFrame([x['segment'] for x in active['content']])
        return active_segs
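To make the chunking behaviour of Vad.merge_chunks above concrete, here is a small worked example with a stand-in segment type, so the sketch does not depend on whisperx.diarize.Segment. With chunk_size=30, the first two spans are merged into one chunk and the third starts a new one.

from dataclasses import dataclass
from whisperx.vads.vad import Vad

@dataclass
class Seg:  # stand-in with the .start/.end/.speaker attributes merge_chunks reads
    start: float
    end: float
    speaker: str = "UNKNOWN"

segs = [Seg(0.0, 10.0), Seg(12.0, 25.0), Seg(26.0, 40.0)]
chunks = Vad.merge_chunks(segs, chunk_size=30, onset=0.5, offset=None)
print(chunks)
# [{'start': 0.0, 'end': 25.0, 'segments': [(0.0, 10.0), (12.0, 25.0)]},
#  {'start': 26.0, 'end': 40.0, 'segments': [(26.0, 40.0)]}]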