mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
13 Commits
v3.3.3
...
b93e9b6f57
Author | SHA1 | Date | |
---|---|---|---|
b93e9b6f57 | |||
844736e4e4 | |||
220fec9aea | |||
1631c3040f | |||
d700b56c9c | |||
b343241253 | |||
6fe0a8784a | |||
5012650d0f | |||
108bd0c400 | |||
b2d50a027b | |||
36d552cad3 | |||
7d36b832f9 | |||
d2a493e910 |
3
.github/workflows/build-and-release.yml
vendored
3
.github/workflows/build-and-release.yml
vendored
@ -17,6 +17,9 @@ jobs:
|
||||
version: "0.5.14"
|
||||
python-version: "3.9"
|
||||
|
||||
- name: Check if lockfile is up to date
|
||||
run: uv lock --check
|
||||
|
||||
- name: Build package
|
||||
run: uv build
|
||||
|
||||
|
3
.github/workflows/python-compatibility.yml
vendored
3
.github/workflows/python-compatibility.yml
vendored
@ -23,6 +23,9 @@ jobs:
|
||||
version: "0.5.14"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Check if lockfile is up to date
|
||||
run: uv lock --check
|
||||
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras
|
||||
|
||||
|
105
README.md
105
README.md
@ -22,26 +22,20 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">
|
||||
|
||||
|
||||
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
|
||||
|
||||
|
||||
<!-- <h2 align="left", id="what-is-it">What is it 🔎</h2> -->
|
||||
|
||||
|
||||
This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
|
||||
|
||||
- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
|
||||
- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
|
||||
- 🎯 Accurate word-level timestamps using wav2vec2 alignment
|
||||
- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
|
||||
- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
|
||||
- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
|
||||
|
||||
|
||||
|
||||
**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
|
||||
|
||||
**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
|
||||
@ -54,12 +48,12 @@ This repository provides fast automatic speech recognition (70x realtime with la
|
||||
|
||||
<h2 align="left", id="highlights">New🚨</h2>
|
||||
|
||||
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
|
||||
- _WhisperX_ accepted at INTERSPEECH 2023
|
||||
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
|
||||
- _WhisperX_ accepted at INTERSPEECH 2023
|
||||
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
|
||||
- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
|
||||
- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
|
||||
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
|
||||
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with \*60-70x REAL TIME speed.
|
||||
|
||||
<h2 align="left" id="setup">Setup ⚙️</h2>
|
||||
|
||||
@ -103,6 +97,25 @@ uv sync --all-extras --dev
|
||||
|
||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||
|
||||
### Common Issues & Troubleshooting 🔧
|
||||
|
||||
#### libcudnn Dependencies (GPU Users)
|
||||
|
||||
If you're using WhisperX with GPU support and encounter errors like:
|
||||
|
||||
- `Could not load library libcudnn_ops_infer.so.8`
|
||||
- `Unable to load any of {libcudnn_cnn.so.9.1.0, libcudnn_cnn.so.9.1, libcudnn_cnn.so.9, libcudnn_cnn.so}`
|
||||
- `libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory`
|
||||
|
||||
This means your system is missing the CUDA Deep Neural Network library (cuDNN). This library is needed for GPU acceleration but isn't always installed by default.
|
||||
|
||||
**Install cuDNN (example for apt based systems):**
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install libcudnn8 libcudnn8-dev -y
|
||||
```
|
||||
|
||||
### Speaker Diarization
|
||||
|
||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||
@ -118,8 +131,7 @@ Run whisper on example segment (using default params, whisper small) add `--high
|
||||
|
||||
whisperx path/to/audio.wav
|
||||
|
||||
|
||||
Result using *WhisperX* with forced alignment to wav2vec2.0 large:
|
||||
Result using _WhisperX_ with forced alignment to wav2vec2.0 large:
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
|
||||
|
||||
@ -127,12 +139,10 @@ Compare this to original whisper out the box, where many transcriptions are out
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
|
||||
|
||||
|
||||
For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
|
||||
|
||||
whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
|
||||
|
||||
|
||||
To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
|
||||
|
||||
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
|
||||
@ -143,27 +153,26 @@ To run on CPU instead of GPU (and for running on Mac OS X):
|
||||
|
||||
### Other languages
|
||||
|
||||
The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
|
||||
The phoneme ASR alignment model is _language-specific_, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
|
||||
Just pass in the `--language` code, and use the whisper `--model large`.
|
||||
|
||||
Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
|
||||
|
||||
|
||||
#### E.g. German
|
||||
|
||||
whisperx --model large-v2 --language de path/to/audio.wav
|
||||
|
||||
https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
|
||||
|
||||
|
||||
See more examples in other languages [here](EXAMPLES.md).
|
||||
|
||||
## Python usage 🐍
|
||||
## Python usage 🐍
|
||||
|
||||
```python
|
||||
import whisperx
|
||||
import gc
|
||||
import gc
|
||||
|
||||
device = "cuda"
|
||||
device = "cuda"
|
||||
audio_file = "audio.mp3"
|
||||
batch_size = 16 # reduce if low on GPU mem
|
||||
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
|
||||
@ -180,7 +189,7 @@ result = model.transcribe(audio, batch_size=batch_size)
|
||||
print(result["segments"]) # before alignment
|
||||
|
||||
# delete model if low on GPU resources
|
||||
# import gc; gc.collect(); torch.cuda.empty_cache(); del model
|
||||
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model
|
||||
|
||||
# 2. Align whisper output
|
||||
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
|
||||
@ -189,10 +198,10 @@ result = whisperx.align(result["segments"], model_a, metadata, audio, device, re
|
||||
print(result["segments"]) # after alignment
|
||||
|
||||
# delete model if low on GPU resources
|
||||
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
|
||||
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a
|
||||
|
||||
# 3. Assign speaker labels
|
||||
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
|
||||
diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
|
||||
|
||||
# add min/max number of speakers if known
|
||||
diarize_segments = diarize_model(audio)
|
||||
@ -205,25 +214,27 @@ print(result["segments"]) # segments are now assigned speaker IDs
|
||||
|
||||
## Demos 🚀
|
||||
|
||||
[](https://replicate.com/victor-upmeet/whisperx)
|
||||
[](https://replicate.com/daanelson/whisperx)
|
||||
[](https://replicate.com/carnifexer/whisperx)
|
||||
[](https://replicate.com/victor-upmeet/whisperx)
|
||||
[](https://replicate.com/daanelson/whisperx)
|
||||
[](https://replicate.com/carnifexer/whisperx)
|
||||
|
||||
If you don't have access to your own GPUs, use the links above to try out WhisperX.
|
||||
If you don't have access to your own GPUs, use the links above to try out WhisperX.
|
||||
|
||||
<h2 align="left" id="whisper-mod">Technical Details 👷♂️</h2>
|
||||
|
||||
For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
|
||||
|
||||
To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
|
||||
|
||||
1. reduce batch size, e.g. `--batch_size 4`
|
||||
2. use a smaller ASR model `--model base`
|
||||
3. Use lighter compute type `--compute_type int8`
|
||||
2. use a smaller ASR model `--model base`
|
||||
3. Use lighter compute type `--compute_type int8`
|
||||
|
||||
Transcription differences from openai's whisper:
|
||||
|
||||
1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
|
||||
2. VAD-based segment transcription, unlike the buffered transcription of openai's. In the WhisperX paper we show this reduces WER, and enables accurate batched inference
|
||||
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
|
||||
3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
|
||||
|
||||
<h2 align="left" id="limitations">Limitations ⚠️</h2>
|
||||
|
||||
@ -232,7 +243,6 @@ Transcription differences from openai's whisper:
|
||||
- Diarization is far from perfect
|
||||
- Language specific wav2vec2 model is needed
|
||||
|
||||
|
||||
<h2 align="left" id="contribute">Contribute 🧑🏫</h2>
|
||||
|
||||
If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
|
||||
@ -241,43 +251,40 @@ Bug finding and pull requests are also highly appreciated to keep this project g
|
||||
|
||||
<h2 align="left" id="coming-soon">TODO 🗓</h2>
|
||||
|
||||
* [x] Multilingual init
|
||||
- [x] Multilingual init
|
||||
|
||||
* [x] Automatic align model selection based on language detection
|
||||
- [x] Automatic align model selection based on language detection
|
||||
|
||||
* [x] Python usage
|
||||
- [x] Python usage
|
||||
|
||||
* [x] Incorporating speaker diarization
|
||||
- [x] Incorporating speaker diarization
|
||||
|
||||
* [x] Model flush, for low gpu mem resources
|
||||
- [x] Model flush, for low gpu mem resources
|
||||
|
||||
* [x] Faster-whisper backend
|
||||
- [x] Faster-whisper backend
|
||||
|
||||
* [x] Add max-line etc. see (openai's whisper utils.py)
|
||||
- [x] Add max-line etc. see (openai's whisper utils.py)
|
||||
|
||||
* [x] Sentence-level segments (nltk toolbox)
|
||||
- [x] Sentence-level segments (nltk toolbox)
|
||||
|
||||
* [x] Improve alignment logic
|
||||
- [x] Improve alignment logic
|
||||
|
||||
* [ ] update examples with diarization and word highlighting
|
||||
- [ ] update examples with diarization and word highlighting
|
||||
|
||||
* [ ] Subtitle .ass output <- bring this back (removed in v3)
|
||||
- [ ] Subtitle .ass output <- bring this back (removed in v3)
|
||||
|
||||
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||
- [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
|
||||
|
||||
* [x] Allow silero-vad as alternative VAD option
|
||||
|
||||
* [ ] Improve diarization (word level). *Harder than first thought...*
|
||||
- [x] Allow silero-vad as alternative VAD option
|
||||
|
||||
- [ ] Improve diarization (word level). _Harder than first thought..._
|
||||
|
||||
<h2 align="left" id="contact">Contact/Support 📇</h2>
|
||||
|
||||
|
||||
Contact maxhbain@gmail.com for queries.
|
||||
|
||||
<a href="https://www.buymeacoffee.com/maxhbain" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="41" width="174"></a>
|
||||
|
||||
|
||||
<h2 align="left" id="acks">Acknowledgements 🙏</h2>
|
||||
|
||||
This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
|
||||
@ -286,8 +293,8 @@ Of course, this is builds on [openAI's whisper](https://github.com/openai/whispe
|
||||
Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
|
||||
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
|
||||
|
||||
|
||||
Valuable VAD & Diarization Models from:
|
||||
|
||||
- [pyannote audio][https://github.com/pyannote/pyannote-audio]
|
||||
- [silero vad][https://github.com/snakers4/silero-vad]
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
urls = { repository = "https://github.com/m-bain/whisperx" }
|
||||
authors = [{ name = "Max Bain" }]
|
||||
name = "whisperx"
|
||||
version = "3.3.3"
|
||||
version = "3.4.0"
|
||||
description = "Time-Accurate Automatic Speech Recognition using Whisper."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9, <3.13"
|
||||
@ -23,7 +23,7 @@ dependencies = [
|
||||
|
||||
|
||||
[project.scripts]
|
||||
whisperx = "whisperx.transcribe:cli"
|
||||
whisperx = "whisperx.__main__:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
@ -33,4 +33,4 @@ include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["whisperx*"]
|
||||
include = ["whisperx*"]
|
||||
|
@ -1,7 +1,31 @@
|
||||
from whisperx.alignment import load_align_model as load_align_model, align as align
|
||||
from whisperx.asr import load_model as load_model
|
||||
from whisperx.audio import load_audio as load_audio
|
||||
from whisperx.diarize import (
|
||||
assign_word_speakers as assign_word_speakers,
|
||||
DiarizationPipeline as DiarizationPipeline,
|
||||
)
|
||||
import importlib
|
||||
|
||||
|
||||
def _lazy_import(name):
|
||||
module = importlib.import_module(f"whisperx.{name}")
|
||||
return module
|
||||
|
||||
|
||||
def load_align_model(*args, **kwargs):
|
||||
alignment = _lazy_import("alignment")
|
||||
return alignment.load_align_model(*args, **kwargs)
|
||||
|
||||
|
||||
def align(*args, **kwargs):
|
||||
alignment = _lazy_import("alignment")
|
||||
return alignment.align(*args, **kwargs)
|
||||
|
||||
|
||||
def load_model(*args, **kwargs):
|
||||
asr = _lazy_import("asr")
|
||||
return asr.load_model(*args, **kwargs)
|
||||
|
||||
|
||||
def load_audio(*args, **kwargs):
|
||||
audio = _lazy_import("audio")
|
||||
return audio.load_audio(*args, **kwargs)
|
||||
|
||||
|
||||
def assign_word_speakers(*args, **kwargs):
|
||||
diarize = _lazy_import("diarize")
|
||||
return diarize.assign_word_speakers(*args, **kwargs)
|
||||
|
@ -1,4 +1,89 @@
|
||||
from whisperx.transcribe import cli
|
||||
import argparse
|
||||
import importlib.metadata
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
|
||||
optional_int, str2bool)
|
||||
|
||||
|
||||
cli()
|
||||
def cli():
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
||||
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
|
||||
parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
||||
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
|
||||
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
|
||||
from whisperx.transcribe import transcribe_task
|
||||
|
||||
transcribe_task(args, parser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
@ -11,13 +11,14 @@ from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
|
||||
class DiarizationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="pyannote/speaker-diarization-3.1",
|
||||
model_name=None,
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
model_config = model_name or "pyannote/speaker-diarization-3.1"
|
||||
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -25,25 +26,81 @@ class DiarizationPipeline:
|
||||
num_speakers: Optional[int] = None,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
):
|
||||
return_embeddings: bool = False,
|
||||
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
|
||||
"""
|
||||
Perform speaker diarization on audio.
|
||||
|
||||
Args:
|
||||
audio: Path to audio file or audio array
|
||||
num_speakers: Exact number of speakers (if known)
|
||||
min_speakers: Minimum number of speakers to detect
|
||||
max_speakers: Maximum number of speakers to detect
|
||||
return_embeddings: Whether to return speaker embeddings
|
||||
|
||||
Returns:
|
||||
If return_embeddings is True:
|
||||
Tuple of (diarization dataframe, speaker embeddings dictionary)
|
||||
Otherwise:
|
||||
Just the diarization dataframe
|
||||
"""
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio_data = {
|
||||
'waveform': torch.from_numpy(audio[None, :]),
|
||||
'sample_rate': SAMPLE_RATE
|
||||
}
|
||||
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
|
||||
|
||||
if return_embeddings:
|
||||
diarization, embeddings = self.model(
|
||||
audio_data,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
return_embeddings=True,
|
||||
)
|
||||
else:
|
||||
diarization = self.model(
|
||||
audio_data,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
)
|
||||
embeddings = None
|
||||
|
||||
diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
|
||||
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
|
||||
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
|
||||
return diarize_df
|
||||
|
||||
if return_embeddings and embeddings is not None:
|
||||
speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
|
||||
return diarize_df, speaker_embeddings
|
||||
|
||||
# For backwards compatibility
|
||||
if return_embeddings:
|
||||
return diarize_df, None
|
||||
else:
|
||||
return diarize_df
|
||||
|
||||
|
||||
def assign_word_speakers(
|
||||
diarize_df: pd.DataFrame,
|
||||
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
|
||||
fill_nearest=False,
|
||||
) -> dict:
|
||||
speaker_embeddings: Optional[dict[str, list[float]]] = None,
|
||||
fill_nearest: bool = False,
|
||||
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
|
||||
"""
|
||||
Assign speakers to words and segments in the transcript.
|
||||
|
||||
Args:
|
||||
diarize_df: Diarization dataframe from DiarizationPipeline
|
||||
transcript_result: Transcription result to augment with speaker labels
|
||||
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
|
||||
fill_nearest: If True, assign speakers even when there's no direct time overlap
|
||||
|
||||
Returns:
|
||||
Updated transcript_result with speaker assignments and optionally embeddings
|
||||
"""
|
||||
transcript_segments = transcript_result["segments"]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
@ -74,8 +131,12 @@ def assign_word_speakers(
|
||||
# sum over speakers
|
||||
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
||||
word["speaker"] = speaker
|
||||
|
||||
return transcript_result
|
||||
|
||||
# Add speaker embeddings to the result if provided
|
||||
if speaker_embeddings is not None:
|
||||
transcript_result["speaker_embeddings"] = speaker_embeddings
|
||||
|
||||
return transcript_result
|
||||
|
||||
|
||||
class Segment:
|
||||
|
@ -1,10 +1,7 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import importlib.metadata
|
||||
import platform
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -14,85 +11,18 @@ from whisperx.asr import load_model
|
||||
from whisperx.audio import load_audio
|
||||
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
|
||||
from whisperx.utils import (
|
||||
LANGUAGES,
|
||||
TO_LANGUAGE_CODE,
|
||||
get_writer,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
|
||||
|
||||
|
||||
def cli():
|
||||
def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
"""Transcription task to be called from CLI.
|
||||
|
||||
Args:
|
||||
args: Dictionary of command-line arguments.
|
||||
parser: argparse.ArgumentParser object.
|
||||
"""
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
|
||||
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
|
||||
parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
|
||||
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
# alignment params
|
||||
parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
|
||||
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
|
||||
parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
|
||||
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
|
||||
|
||||
# vad params
|
||||
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
|
||||
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
|
||||
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
|
||||
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
|
||||
|
||||
# diarization params
|
||||
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
|
||||
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
|
||||
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
||||
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
|
||||
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
|
||||
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
batch_size: int = args.pop("batch_size")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
@ -127,7 +57,12 @@ def cli():
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
diarize_model_name: str = args.pop("diarize_model")
|
||||
print_progress: bool = args.pop("print_progress")
|
||||
return_speaker_embeddings: bool = args.pop("speaker_embeddings")
|
||||
|
||||
if return_speaker_embeddings and not diarize:
|
||||
warnings.warn("--speaker_embeddings has no effect without --diarize")
|
||||
|
||||
if args["language"] is not None:
|
||||
args["language"] = args["language"].lower()
|
||||
@ -274,19 +209,19 @@ def cli():
|
||||
)
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
print(">>Using model:", diarize_model_name)
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(
|
||||
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
||||
diarize_segments, speaker_embeddings = diarize_model(
|
||||
input_audio_path,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
return_embeddings=return_speaker_embeddings
|
||||
)
|
||||
result = assign_word_speakers(diarize_segments, result)
|
||||
result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
|
||||
results.append((result, input_audio_path))
|
||||
# >> Write
|
||||
for result, audio_path in results:
|
||||
result["language"] = align_language
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
Reference in New Issue
Block a user