11 Commits

Author SHA1 Message Date
036b5b0717 Merge c89b4f898f into d700b56c9c 2025-06-13 15:33:03 +02:00
d700b56c9c docs: add missing torch import to Python usage example in README 2025-06-08 03:34:49 -06:00
b343241253 feat: add diarize_model arg to CLI (#1101) 2025-05-31 13:32:31 +02:00
6fe0a8784a docs: add troubleshooting section for libcudnn dependencies in README 2025-05-31 05:20:06 -06:00
c89b4f898f fix: incorrect type annotation in get_writer return value 2025-05-13 02:45:33 +02:00
(The audio_path argument that the ResultWriter class's __call__ method takes is a str, not TextIO.)
5012650d0f chore: update lockfile 2025-05-03 16:25:43 +02:00
108bd0c400 chore: add lockfile check step to CI workflows 2025-05-03 16:25:43 +02:00
b2d50a027b chore: bump version 2025-05-03 11:38:54 +02:00
36d552cad3 fix: remove DiarizationPipeline from public API 2025-05-03 09:25:59 +02:00
7d36b832f9 refactor: update CLI entry point 2025-05-03 09:25:59 +02:00
d2a493e910 refactor: implement lazy loading for module imports in whisperx 2025-05-03 09:25:59 +02:00
10 changed files with 199 additions and 149 deletions

.github/workflows/build-and-release.yml

@@ -17,6 +17,9 @@ jobs:
           version: "0.5.14"
           python-version: "3.9"
+      - name: Check if lockfile is up to date
+        run: uv lock --check
       - name: Build package
         run: uv build

.github/workflows/python-compatibility.yml

@@ -23,6 +23,9 @@ jobs:
           version: "0.5.14"
           python-version: ${{ matrix.python-version }}
+      - name: Check if lockfile is up to date
+        run: uv lock --check
       - name: Install the project
         run: uv sync --all-extras

README.md

@@ -22,26 +22,20 @@
  </a>
</p>

<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">

<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->

<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->

This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.

- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
- 👯 Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
- 🗣 VAD preprocessing, reduces hallucination & batching with no WER degradation

**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. While it produces highly accurate transcriptions, the corresponding timestamps are at the utterance level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.

**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
@@ -54,12 +48,12 @@ This repository provides fast automatic speech recognition (70x realtime with la

<h2 align="left", id="highlights">New🚨</h2>

- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitling & better diarization
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
+- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.

<h2 align="left" id="setup">Setup ⚙️</h2>
@@ -103,6 +97,25 @@ uv sync --all-extras --dev

You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

+### Common Issues & Troubleshooting 🔧
+
+#### libcudnn Dependencies (GPU Users)
+
+If you're using WhisperX with GPU support and encounter errors like:
+
+- `Could not load library libcudnn_ops_infer.so.8`
+- `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
+- `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory`
+
+This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default.
+
+**Install cuDNN (example for apt-based systems):**
+
+```bash
+sudo apt update
+sudo apt install libcudnn8 libcudnn8-dev -y
+```
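As a quick sanity check before re-running WhisperX (an aside, not part of the diff above), you can ask PyTorch directly whether it can load cuDNN; both calls are standard PyTorch API:

```python
import torch

# Returns False / None when cuDNN is missing or cannot be loaded.
print("cuDNN available:", torch.backends.cudnn.is_available())
print("cuDNN version:", torch.backends.cudnn.version())
```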
### Speaker Diarization

To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
@@ -118,8 +131,7 @@ Run whisper on example segment (using default params, whisper small) add `--high

    whisperx path/to/audio.wav

-Result using *WhisperX* with forced alignment to wav2vec2.0 large:
+Result using _WhisperX_ with forced alignment to wav2vec2.0 large:

https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
@@ -127,12 +139,10 @@ Compare this to original whisper out the box, where many transcriptions are out

https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov

For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

    whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4

To label the transcript with speaker IDs (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):

    whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
@@ -143,27 +153,26 @@ To run on CPU instead of GPU (and for running on Mac OS X):

### Other languages

-The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
+The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).

Just pass in the `--language` code, and use the whisper `--model large`.

Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.

#### E.g. German

    whisperx --model large-v2 --language de path/to/audio.wav

https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

See more examples in other languages [here](EXAMPLES.md).

## Python usage 🐍
```python
import whisperx
import gc

device = "cuda"
audio_file = "audio.mp3"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
@@ -180,7 +189,7 @@ result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment

# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model

# 2. Align whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
@@ -189,10 +198,10 @@ result = whisperx.align(result["segments"], model_a, metadata, audio, device, re
print(result["segments"]) # after alignment

# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
-diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
+diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
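For reference, the README example continues with the step shown in the next hunk's header context, using the existing API (unchanged by this diff):

```python
# assign the diarized speaker labels back onto the transcript
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"])  # segments are now assigned speaker IDs
```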
@@ -205,25 +214,27 @@ print(result["segments"]) # segments are now assigned speaker IDs

## Demos 🚀

[![Replicate (large-v3)](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v3&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/victor-upmeet/whisperx)
[![Replicate (large-v2)](https://img.shields.io/static/v1?label=Replicate+WhisperX+large-v2&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/daanelson/whisperx)
[![Replicate (medium)](https://img.shields.io/static/v1?label=Replicate+WhisperX+medium&message=Demo+%26+Cloud+API&color=blue)](https://replicate.com/carnifexer/whisperx)

If you don't have access to your own GPUs, use the links above to try out WhisperX.

<h2 align="left" id="whisper-mod">Technical Details 👷‍♂️</h2>

For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).

To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):

1. reduce batch size, e.g. `--batch_size 4`
2. use a smaller ASR model `--model base`
3. use a lighter compute type `--compute_type int8`

Transcription differences from openai's whisper:

1. Transcription without timestamps. To enable single-pass batching, whisper inference is performed with `--without_timestamps True`; this ensures one forward pass per sample in the batch. However, this can cause discrepancies with the default whisper output.
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference.
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)

<h2 align="left" id="limitations">Limitations ⚠️</h2>
@@ -232,7 +243,6 @@ Transcription differences from openai's whisper:

- Diarization is far from perfect
- Language specific wav2vec2 model is needed

<h2 align="left" id="contribute">Contribute 🧑‍🏫</h2>

If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good, send a pull request and some examples showing its success.
@@ -241,43 +251,40 @@ Bug finding and pull requests are also highly appreciated to keep this project g

<h2 align="left" id="coming-soon">TODO 🗓</h2>

-* [x] Multilingual init
-* [x] Automatic align model selection based on language detection
-* [x] Python usage
-* [x] Incorporating speaker diarization
-* [x] Model flush, for low gpu mem resources
-* [x] Faster-whisper backend
-* [x] Add max-line etc. see (openai's whisper utils.py)
-* [x] Sentence-level segments (nltk toolbox)
-* [x] Improve alignment logic
-* [ ] update examples with diarization and word highlighting
-* [ ] Subtitle .ass output <- bring this back (removed in v3)
-* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
-* [x] Allow silero-vad as alternative VAD option
-* [ ] Improve diarization (word level). *Harder than first thought...*
+- [x] Multilingual init
+- [x] Automatic align model selection based on language detection
+- [x] Python usage
+- [x] Incorporating speaker diarization
+- [x] Model flush, for low gpu mem resources
+- [x] Faster-whisper backend
+- [x] Add max-line etc. see (openai's whisper utils.py)
+- [x] Sentence-level segments (nltk toolbox)
+- [x] Improve alignment logic
+- [ ] update examples with diarization and word highlighting
+- [ ] Subtitle .ass output <- bring this back (removed in v3)
+- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
+- [x] Allow silero-vad as alternative VAD option
+- [ ] Improve diarization (word level). _Harder than first thought..._

<h2 align="left" id="contact">Contact/Support 📇</h2>

Contact maxhbain@gmail.com for queries.

<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>

<h2 align="left" id="acks">Acknowledgements 🙏</h2>

This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
@@ -286,8 +293,8 @@ Of course, this is builds on [openAI's whisper](https://github.com/openai/whispe

Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)

And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio

Valuable VAD & Diarization Models from:

- [pyannote audio](https://github.com/pyannote/pyannote-audio)
- [silero vad](https://github.com/snakers4/silero-vad)

pyproject.toml

@@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
-version = "3.3.3"
+version = "3.3.4"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
@@ -23,7 +23,7 @@ dependencies = [

[project.scripts]
-whisperx = "whisperx.transcribe:cli"
+whisperx = "whisperx.__main__:cli"

[build-system]
requires = ["setuptools"]

@@ -33,4 +33,4 @@ include-package-data = true

[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]

uv.lock generated

@@ -2787,7 +2787,7 @@ wheels = [

[[package]]
name = "whisperx"
-version = "3.3.3"
+version = "3.3.4"
source = { editable = "." }
dependencies = [
    { name = "ctranslate2" },

whisperx/__init__.py

@@ -1,7 +1,31 @@
-from whisperx.alignment import load_align_model as load_align_model, align as align
-from whisperx.asr import load_model as load_model
-from whisperx.audio import load_audio as load_audio
-from whisperx.diarize import (
-    assign_word_speakers as assign_word_speakers,
-    DiarizationPipeline as DiarizationPipeline,
-)
+import importlib
+
+
+def _lazy_import(name):
+    module = importlib.import_module(f"whisperx.{name}")
+    return module
+
+
+def load_align_model(*args, **kwargs):
+    alignment = _lazy_import("alignment")
+    return alignment.load_align_model(*args, **kwargs)
+
+
+def align(*args, **kwargs):
+    alignment = _lazy_import("alignment")
+    return alignment.align(*args, **kwargs)
+
+
+def load_model(*args, **kwargs):
+    asr = _lazy_import("asr")
+    return asr.load_model(*args, **kwargs)
+
+
+def load_audio(*args, **kwargs):
+    audio = _lazy_import("audio")
+    return audio.load_audio(*args, **kwargs)
+
+
+def assign_word_speakers(*args, **kwargs):
+    diarize = _lazy_import("diarize")
+    return diarize.assign_word_speakers(*args, **kwargs)
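The wrappers above keep `import whisperx` cheap by deferring each heavy submodule import to the first call. The same laziness can also be expressed with a module-level `__getattr__` (PEP 562, Python 3.7+); a sketch for comparison only — this is not what the diff implements, and `mypkg` is a placeholder name:

```python
# mypkg/__init__.py — illustrative alternative, not whisperx's implementation
import importlib

_SUBMODULE_OF = {
    "load_model": "asr",
    "load_audio": "audio",
    "load_align_model": "alignment",
    "align": "alignment",
    "assign_word_speakers": "diarize",
}


def __getattr__(attr):
    # Invoked only when normal attribute lookup fails, so the heavy
    # submodule is imported on first access rather than at package import.
    if attr in _SUBMODULE_OF:
        module = importlib.import_module(f"mypkg.{_SUBMODULE_OF[attr]}")
        return getattr(module, attr)
    raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")
```

One consequence of the actual change, mirrored in the README hunk above: classes such as DiarizationPipeline are no longer re-exported at the top level and must be reached as `whisperx.diarize.DiarizationPipeline`.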

whisperx/__main__.py

@@ -1,4 +1,88 @@
-from whisperx.transcribe import cli
-
-
-cli()
+import argparse
+import importlib.metadata
+import platform
+
+import torch
+
+from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
+                            optional_int, str2bool)
+
+
+def cli():
+    # fmt: off
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
+    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
+    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
+    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
+    parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
+    parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
+    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
+    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
+    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+    # alignment params
+    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
+    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
+    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
+    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
+    # vad params
+    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
+    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
+    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
+    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
+    # diarization params
+    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
+    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers in audio file")
+    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers in audio file")
+    parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
+    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
+    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
+    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
+    parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+    parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+    parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
+    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
+    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
+    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
+    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
+    parser.add_argument("--print_progress", type=str2bool, default=False, help="if True, progress will be printed in transcribe() and align() methods.")
+    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}", help="Show whisperx version information and exit")
+    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})", help="Show python version information and exit")
+    # fmt: on
+
+    args = parser.parse_args().__dict__
+
+    from whisperx.transcribe import transcribe_task
+
+    transcribe_task(args, parser)
+
+
+if __name__ == "__main__":
+    cli()
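Since the console script now points at `whisperx.__main__:cli` (see the pyproject.toml hunk above), `python -m whisperx` and the `whisperx` command share a single entry point. A hedged sketch of driving it programmatically, e.g. from a test; the audio file name is a placeholder:

```python
import sys

from whisperx.__main__ import cli

# Simulate: whisperx audio.wav --model tiny --device cpu --compute_type int8 --no_align
sys.argv = ["whisperx", "audio.wav", "--model", "tiny", "--device", "cpu",
            "--compute_type", "int8", "--no_align"]
cli()  # parses sys.argv, then hands the options dict to transcribe_task()
```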

whisperx/diarize.py

@@ -11,13 +11,14 @@ from whisperx.types import TranscriptionResult, AlignedTranscriptionResult

class DiarizationPipeline:
    def __init__(
        self,
-        model_name="pyannote/speaker-diarization-3.1",
+        model_name=None,
        use_auth_token=None,
        device: Optional[Union[str, torch.device]] = "cpu",
    ):
        if isinstance(device, str):
            device = torch.device(device)
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
+        model_config = model_name or "pyannote/speaker-diarization-3.1"
+        self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)

    def __call__(
        self,
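Moving the default into the body (`model_name or "pyannote/speaker-diarization-3.1"`) preserves the old behaviour when callers pass nothing, while the new `--diarize_model` flag (wired through in transcribe.py below) can select another checkpoint. A hedged usage sketch; the token value is a placeholder:

```python
from whisperx.diarize import DiarizationPipeline

hf_token = "hf_..."  # placeholder: your Hugging Face read token

# model_name=None falls back to pyannote/speaker-diarization-3.1
pipeline = DiarizationPipeline(use_auth_token=hf_token, device="cuda")

# any compatible pyannote diarization pipeline can be named explicitly
pipeline = DiarizationPipeline(
    model_name="pyannote/speaker-diarization",
    use_auth_token=hf_token,
    device="cuda",
)
```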

whisperx/transcribe.py

@@ -1,10 +1,7 @@
import argparse
import gc
import os
-import sys
import warnings
-import importlib.metadata
-import platform

import numpy as np
import torch
@@ -14,85 +11,18 @@ from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
-from whisperx.utils import (
-    LANGUAGES,
-    TO_LANGUAGE_CODE,
-    get_writer,
-    optional_float,
-    optional_int,
-    str2bool,
-)
+from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer

-def cli():
+def transcribe_task(args: dict, parser: argparse.ArgumentParser):
+    """Transcription task to be called from CLI.
+
+    Args:
+        args: Dictionary of command-line arguments.
+        parser: argparse.ArgumentParser object.
+    """
     # fmt: off
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
-    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
-    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
-    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
-    parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
-    parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
-    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
-    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
-    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
-    # alignment params
-    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
-    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
-    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
-    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
-    # vad params
-    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
-    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
-    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
-    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
-    # diarization params
-    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
-    parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers in audio file")
-    parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers in audio file")
-    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
-    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
-    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-    parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
-    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
-    parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
-    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
-    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
-    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
-    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
-    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
-    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
-    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
-    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
-    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
-    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
-    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
-    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    parser.add_argument("--print_progress", type=str2bool, default=False, help="if True, progress will be printed in transcribe() and align() methods.")
-    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}", help="Show whisperx version information and exit")
-    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})", help="Show python version information and exit")
-    # fmt: on
-
-    args = parser.parse_args().__dict__
-
     model_name: str = args.pop("model")
     batch_size: int = args.pop("batch_size")
     model_dir: str = args.pop("model_dir")
@@ -127,6 +57,7 @@ def cli():
     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")
+    diarize_model_name: str = args.pop("diarize_model")
     print_progress: bool = args.pop("print_progress")

     if args["language"] is not None:
@@ -274,8 +205,9 @@ def cli():
         )
         tmp_results = results
         print(">>Performing diarization...")
+        print(">>Using model:", diarize_model_name)
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(
                 input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
@@ -286,7 +218,3 @@ def cli():
     for result, audio_path in results:
         result["language"] = align_language
         writer(result, audio_path, writer_args)
-
-
-if __name__ == "__main__":
-    cli()

whisperx/utils.py

@@ -410,7 +410,7 @@ class WriteJSON(ResultWriter):

def get_writer(
    output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
+) -> Callable[[dict, str, dict], None]:
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
@@ -425,7 +425,7 @@ def get_writer(
    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

-        def write_all(result: dict, file: TextIO, options: dict):
+        def write_all(result: dict, file: str, options: dict):
            for writer in all_writers:
                writer(result, file, options)
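The corrected annotation matches the call site in transcribe.py above (`writer(result, audio_path, writer_args)`): the second argument is the audio file's path as a `str`, from which the writer derives the output filename, not an open `TextIO` handle. A minimal hedged sketch of calling a writer; the `result` dict here is illustrative:

```python
from whisperx.utils import get_writer

# Illustrative minimal transcription result.
result = {
    "language": "en",
    "segments": [{"start": 0.0, "end": 1.5, "text": "Hello world."}],
}
options = {"max_line_width": None, "max_line_count": None, "highlight_words": False}

writer = get_writer("srt", ".")       # returns a WriteSRT bound to the output dir
writer(result, "audio.wav", options)  # writes ./audio.srt, named after the audio path
```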