2 Commits

22 changed files with 335 additions and 3507 deletions


@@ -10,22 +10,28 @@ jobs:
     steps:
       - name: Checkout
        uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v5
-        with:
-          version: "0.5.14"
+        with:
+          ref: ${{ github.ref_name }}
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.9"
-      - name: Build package
-        run: uv build
+      - name: Install dependencies
+        run: |
+          python -m pip install build
+      - name: Build wheels
+        run: python -m build --wheel
       - name: Release to Github
         uses: softprops/action-gh-release@v2
         with:
-          files: dist/*.whl
+          files: dist/*
       - name: Publish package to PyPi
-        run: uv publish
-        env:
-          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
+        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}


@@ -17,15 +17,16 @@ jobs:
     steps:
       - uses: actions/checkout@v4
-      - name: Install uv
-        uses: astral-sh/setup-uv@v5
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
         with:
-          version: "0.5.14"
           python-version: ${{ matrix.python-version }}
-      - name: Install the project
-        run: uv sync --all-extras
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install .
       - name: Test import
         run: |
-          uv run python -c "import whisperx; print('Successfully imported whisperx')"
+          python -c "import whisperx; print('Successfully imported whisperx')"

.github/workflows/tmp.yml (new file)

@@ -0,0 +1,35 @@
name: Python Compatibility Test (PyPi)
on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: # Allows manual triggering from GitHub UI
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install package
run: |
pip install whisperx
- name: Print packages
run: |
pip list
- name: Test import
run: |
python -c "import whisperx; print('Successfully imported whisperx')"

README.md

@@ -62,41 +62,54 @@ This repository provides fast automatic speech recognition (70x realtime with la
 - Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

-### 1. Simple Installation (Recommended)
+GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).

-The easiest way to install WhisperX is through PyPi:
+### 1. Create Python3.10 environment
+
+`conda create --name whisperx python=3.10`
+
+`conda activate whisperx`
+
+### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
+
+`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
+
+See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
+
+### 3. Install WhisperX
+
+You have several installation options:
+
+#### Option A: Stable Release (recommended)
+Install the latest stable version from PyPI:

 ```bash
 pip install whisperx
 ```

-Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
+#### Option B: Development Version
+Install the latest development version directly from GitHub (may be unstable):

 ```bash
-uvx whisperx
+pip install git+https://github.com/m-bain/whisperx.git
 ```

-### 2. Advanced Installation Options
-
-These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
-
-#### Option A: Install from GitHub
-
-To install directly from the GitHub repository:
+If already installed, update to the most recent commit:

 ```bash
-uvx git+https://github.com/m-bain/whisperX.git
+pip install git+https://github.com/m-bain/whisperx.git --upgrade
 ```

-#### Option B: Developer Installation
-
-If you want to modify the code or contribute to the project:
+#### Option C: Development Mode
+If you wish to modify the package, clone and install in editable mode:

 ```bash
 git clone https://github.com/m-bain/whisperX.git
 cd whisperX
-uv sync --all-extras --dev
+pip install -e .
 ```

 > **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

@@ -104,19 +117,19 @@ uv sync --all-extras --dev
 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization

 To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

 > **Note**<br>
 > As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.

 <h2 align="left" id="example">Usage 💬 (command line)</h2>

 ### English

 Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.

-    whisperx path/to/audio.wav
+    whisperx examples/sample01.wav

 Result using *WhisperX* with forced alignment to wav2vec2.0 large:

@@ -130,27 +143,27 @@ https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-
 For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.

-    whisperx path/to/audio.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
+    whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4

 To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):

-    whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
+    whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True

 To run on CPU instead of GPU (and for running on Mac OS X):

-    whisperx path/to/audio.wav --compute_type int8
+    whisperx examples/sample01.wav --compute_type int8

 ### Other languages

-The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/f2da2f858e99e4211fe4f64b5f2938b007827e17/whisperx/alignment.py#L24-L58).
+The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
 Just pass in the `--language` code, and use the whisper `--model large`.

-Currently default models provided for `{en, fr, de, es, it}` via torchaudio pipelines and many other languages via Hugging Face. Please find the list of currently supported languages under `DEFAULT_ALIGN_MODELS_HF` on [alignment.py](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py). If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
+Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.

 #### E.g. German

-    whisperx --model large-v2 --language de path/to/audio.wav
+    whisperx --model large-v2 --language de examples/sample_de_01.wav

 https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

@@ -265,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
 * [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)

-* [x] Allow silero-vad as alternative VAD option
+* [ ] Allow silero-vad as alternative VAD option

 * [ ] Improve diarization (word level). *Harder than first thought...*

@@ -287,9 +300,7 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
 And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio

-Valuable VAD & Diarization Models from:
-- [pyannote audio][https://github.com/pyannote/pyannote-audio]
-- [silero vad][https://github.com/snakers4/silero-vad]
+Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)

 Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
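The README hunks above only show the command-line usage. As a reference, here is a minimal hedged sketch of the equivalent Python-API flow using the names that `whisperx/__init__.py` exports on both sides of this diff (`load_model`, `load_audio`, `load_align_model`, `align`, `DiarizationPipeline`, `assign_word_speakers`); the audio path, Hugging Face token, and batch size are illustrative placeholders, not values from this repository.

```python
import whisperx

device = "cuda"  # or "cpu" together with compute_type="int8", as the README notes
audio = whisperx.load_audio("path/to/audio.wav")  # placeholder path

# 1. Transcribe with the batched faster-whisper backend
model = whisperx.load_model("large-v2", device, compute_type="float16")
result = model.transcribe(audio, batch_size=16)

# 2. Forced alignment with a language-specific wav2vec2 model
align_model, metadata = whisperx.load_align_model(
    language_code=result["language"], device=device
)
result = whisperx.align(result["segments"], align_model, metadata, audio, device)

# 3. Optional speaker diarization (needs an accepted pyannote model + HF token)
diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_...", device=device)
diarize_segments = diarize_model("path/to/audio.wav")
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"])
```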

pyproject.toml (deleted)

@@ -1,36 +0,0 @@
[project]
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.3.3"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
license = { text = "BSD-2-Clause" }
dependencies = [
"ctranslate2<4.5.0",
"faster-whisper>=1.1.1",
"nltk>=3.9.1",
"numpy>=2.0.2",
"onnxruntime>=1.19",
"pandas>=2.2.3",
"pyannote-audio>=3.3.2",
"torch>=2.5.1",
"torchaudio>=2.5.1",
"transformers>=4.48.0",
]
[project.scripts]
whisperx = "whisperx.transcribe:cli"
[build-system]
requires = ["setuptools"]
[tool.setuptools]
include-package-data = true
[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]

requirements.txt (new file)

@@ -0,0 +1,8 @@
torch>=2
torchaudio>=2
faster-whisper==1.1.0
ctranslate2>=4.5.0
transformers
pandas
setuptools>=65
nltk

setup.py (new file)

@@ -0,0 +1,33 @@
import os
import pkg_resources
from setuptools import find_packages, setup
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.3.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.9, <3.13",
author="Max Bain",
url="https://github.com/m-bain/whisperx",
license="BSD-2-Clause",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
]
+ [f"pyannote.audio==3.3.2"],
entry_points={
"console_scripts": ["whisperx=whisperx.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest"]},
)

uv.lock (generated): diff suppressed because it is too large.

whisperx/SubtitlesProcessor.py

@@ -1,5 +1,6 @@
 import math
-from whisperx.conjunctions import get_conjunctions, get_comma
+from .conjunctions import get_conjunctions, get_comma
+from typing import TextIO

 def normal_round(n):
     if n - math.floor(n) < 0.5:

whisperx/__init__.py

@@ -1,7 +1,4 @@
-from whisperx.alignment import load_align_model as load_align_model, align as align
-from whisperx.asr import load_model as load_model
-from whisperx.audio import load_audio as load_audio
-from whisperx.diarize import (
-    assign_word_speakers as assign_word_speakers,
-    DiarizationPipeline as DiarizationPipeline,
-)
+from .transcribe import load_model
+from .alignment import load_align_model, align
+from .audio import load_audio
+from .diarize import assign_word_speakers, DiarizationPipeline

whisperx/__main__.py

@@ -1,4 +1,4 @@
-from whisperx.transcribe import cli
+from .transcribe import cli

 cli()

whisperx/alignment.py

@@ -1,9 +1,7 @@
-"""
+""""
 Forced Alignment with Whisper
 C. Max Bain
 """
-import math
 from dataclasses import dataclass
 from typing import Iterable, Optional, Union, List

@@ -13,15 +11,10 @@ import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

-from whisperx.audio import SAMPLE_RATE, load_audio
-from whisperx.utils import interpolate_nans
-from whisperx.types import (
-    AlignedTranscriptionResult,
-    SingleSegment,
-    SingleAlignedSegment,
-    SingleWordSegment,
-    SegmentData,
-)
+from .audio import SAMPLE_RATE, load_audio
+from .utils import interpolate_nans
+from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
+import nltk
 from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

 PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']

@@ -69,8 +62,6 @@ DEFAULT_ALIGN_MODELS_HF = {
     "eu": "stefan-it/wav2vec2-large-xlsr-53-basque",
     "gl": "ifrz/wav2vec2-large-xlsr-galician",
     "ka": "xsway/wav2vec2-large-xlsr-georgian",
-    "lv": "jimregan/wav2vec2-large-xlsr-latvian-cv",
-    "tl": "Khalsuu/filipino-wav2vec2-l-xls-r-300m-official",
 }
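The hunk above drops the Latvian and Tagalog defaults from `DEFAULT_ALIGN_MODELS_HF`. A hedged sketch of how that table is normally consumed through `load_align_model`, including passing an explicit Hugging Face checkpoint for a language that no longer has a default; the `"cpu"` device and the example language codes are assumptions for illustration, and the models are downloaded on first use.

```python
import whisperx

# A language with a default entry: the table (torchaudio pipelines or the
# HF dict above) picks the alignment checkpoint automatically.
align_model, metadata = whisperx.load_align_model("de", device="cpu")
print(metadata["language"], metadata["type"])  # e.g. "de", "torchaudio" or "huggingface"

# A language without a default after this change (e.g. Latvian): a phoneme
# ASR checkpoint from the HF hub can still be supplied explicitly. The model
# name below is the one removed from the table above.
align_model, metadata = whisperx.load_align_model(
    "lv", device="cpu", model_name="jimregan/wav2vec2-large-xlsr-latvian-cv"
)
```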
@@ -140,8 +131,6 @@ def align(
     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
-
-    # Store temporary processing values
-    segment_data: dict[int, SegmentData] = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
         if print_progress:

@@ -174,17 +163,10 @@ def align(
             elif char_ in model_dictionary.keys():
                 clean_char.append(char_)
                 clean_cdx.append(cdx)
-            else:
-                # add placeholder
-                clean_char.append('*')
-                clean_cdx.append(cdx)

         clean_wdx = []
         for wdx, wrd in enumerate(per_word):
-            if any([c in model_dictionary.keys() for c in wrd.lower()]):
-                clean_wdx.append(wdx)
-            else:
-                # index for placeholder
-                clean_wdx.append(wdx)
+            if any([c in model_dictionary.keys() for c in wrd]):
+                clean_wdx.append(wdx)

@@ -193,12 +175,10 @@ def align(
         sentence_splitter = PunktSentenceTokenizer(punkt_param)
         sentence_spans = list(sentence_splitter.span_tokenize(text))

-        segment_data[sdx] = {
-            "clean_char": clean_char,
-            "clean_cdx": clean_cdx,
-            "clean_wdx": clean_wdx,
-            "sentence_spans": sentence_spans
-        }
+        segment["clean_char"] = clean_char
+        segment["clean_cdx"] = clean_cdx
+        segment["clean_wdx"] = clean_wdx
+        segment["sentence_spans"] = sentence_spans

     aligned_segments: List[SingleAlignedSegment] = []

@@ -214,14 +194,13 @@ def align(
             "end": t2,
             "text": text,
             "words": [],
-            "chars": None,
         }

         if return_char_alignments:
             aligned_seg["chars"] = []

         # check we can align
-        if len(segment_data[sdx]["clean_char"]) == 0:
+        if len(segment["clean_char"]) == 0:
             print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             aligned_segments.append(aligned_seg)
             continue

@@ -231,8 +210,8 @@ def align(
             aligned_segments.append(aligned_seg)
             continue

-        text_clean = "".join(segment_data[sdx]["clean_char"])
-        tokens = [model_dictionary.get(c, -1) for c in text_clean]
+        text_clean = "".join(segment["clean_char"])
+        tokens = [model_dictionary[c] for c in text_clean]

         f1 = int(t1 * SAMPLE_RATE)
         f2 = int(t2 * SAMPLE_RATE)

@@ -265,8 +244,7 @@ def align(
             blank_id = code

         trellis = get_trellis(emission, tokens, blank_id)
-        # path = backtrack(trellis, emission, tokens, blank_id)
-        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
+        path = backtrack(trellis, emission, tokens, blank_id)

         if path is None:
             print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')

@@ -275,7 +253,7 @@ def align(
         char_segments = merge_repeats(path, text_clean)

-        duration = t2 - t1
+        duration = t2 -t1
         ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)

         # assign timestamps to aligned characters

@@ -283,8 +261,8 @@ def align(
         word_idx = 0
         for cdx, char in enumerate(text):
             start, end, score = None, None, None
-            if cdx in segment_data[sdx]["clean_cdx"]:
-                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
+            if cdx in segment["clean_cdx"]:
+                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
                 start = round(char_seg.start * ratio + t1, 3)
                 end = round(char_seg.end * ratio + t1, 3)
                 score = round(char_seg.score, 3)

@@ -310,9 +288,9 @@ def align(
         aligned_subsegments = []
         # assign sentence_idx to each character index
         char_segments_arr["sentence-idx"] = None
-        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
+        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
             curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
-            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx

             sentence_text = text[sstart:send]
             sentence_start = curr_chars["start"].min()
@@ -382,202 +360,69 @@ def align(
 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
 """

 def get_trellis(emission, tokens, blank_id=0):
     num_frame = emission.size(0)
     num_tokens = len(tokens)

-    trellis = torch.zeros((num_frame, num_tokens))
-    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
-    trellis[0, 1:] = -float("inf")
-    trellis[-num_tokens + 1:, 0] = float("inf")
+    # Trellis has extra diemsions for both time axis and tokens.
+    # The extra dim for tokens represents <SoS> (start-of-sentence)
+    # The extra dim for time axis is for simplification of the code.
+    trellis = torch.empty((num_frame + 1, num_tokens + 1))
+    trellis[0, 0] = 0
+    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
+    trellis[0, -num_tokens:] = -float("inf")
+    trellis[-num_tokens:, 0] = float("inf")

-    for t in range(num_frame - 1):
+    for t in range(num_frame):
         trellis[t + 1, 1:] = torch.maximum(
             # Score for staying at the same token
             trellis[t, 1:] + emission[t, blank_id],
             # Score for changing to the next token
-            # trellis[t, :-1] + emission[t, tokens[1:]],
-            trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
+            trellis[t, :-1] + emission[t, tokens],
         )
     return trellis

-
-def get_wildcard_emission(frame_emission, tokens, blank_id):
-    """Processing token emission scores containing wildcards (vectorized version)
-
-    Args:
-        frame_emission: Emission probability vector for the current frame
-        tokens: List of token indices
-        blank_id: ID of the blank token
-
-    Returns:
-        tensor: Maximum probability score for each token position
-    """
-    assert 0 <= blank_id < len(frame_emission)
-
-    # Convert tokens to a tensor if they are not already
-    tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
-
-    # Create a mask to identify wildcard positions
-    wildcard_mask = (tokens == -1)
-
-    # Get scores for non-wildcard positions
-    regular_scores = frame_emission[tokens.clamp(min=0)]  # clamp to avoid -1 index
-
-    # Create a mask and compute the maximum value without modifying frame_emission
-    max_valid_score = frame_emission.clone()   # Create a copy
-    max_valid_score[blank_id] = float('-inf')  # Modify the copy to exclude the blank token
-    max_valid_score = max_valid_score.max()
-
-    # Use where operation to combine results
-    result = torch.where(wildcard_mask, max_valid_score, regular_scores)
-
-    return result

 @dataclass
 class Point:
     token_index: int
     time_index: int
     score: float

 def backtrack(trellis, emission, tokens, blank_id=0):
-    t, j = trellis.size(0) - 1, trellis.size(1) - 1
-
-    path = [Point(j, t, emission[t, blank_id].exp().item())]
-    while j > 0:
-        # Should not happen but just in case
-        assert t > 0
-
-        # 1. Figure out if the current position was stay or change
-        # Frame-wise score of stay vs change
-        p_stay = emission[t - 1, blank_id]
-        # p_change = emission[t - 1, tokens[j]]
-        p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
-
-        # Context-aware score for stay vs change
-        stayed = trellis[t - 1, j] + p_stay
-        changed = trellis[t - 1, j - 1] + p_change
-
-        # Update position
-        t -= 1
-        if changed > stayed:
-            j -= 1
-
-        # Store the path with frame-wise probability.
-        prob = (p_change if changed > stayed else p_stay).exp().item()
-        path.append(Point(j, t, prob))
-
-    # Now j == 0, which means, it reached the SoS.
-    # Fill up the rest for the sake of visualization
-    while t > 0:
-        prob = emission[t - 1, blank_id].exp().item()
-        path.append(Point(j, t - 1, prob))
-        t -= 1
-
-    return path[::-1]
-
-
-@dataclass
-class Path:
-    points: List[Point]
-    score: float
-
-
-@dataclass
-class BeamState:
-    """State in beam search."""
-    token_index: int   # Current token position
-    time_index: int    # Current time step
-    score: float       # Cumulative score
-    path: List[Point]  # Path history
-
-
-def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
-    """Standard CTC beam search backtracking implementation.
-
-    Args:
-        trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps
-            and N is the number of tokens (including the blank token).
-        emission (torch.Tensor): The emission probabilities of shape (T, N).
-        tokens (List[int]): List of token indices (excluding the blank token).
-        blank_id (int, optional): The ID of the blank token. Defaults to 0.
-        beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5.
-
-    Returns:
-        List[Point]: the best path
-    """
-    T, J = trellis.size(0) - 1, trellis.size(1) - 1
-
-    init_state = BeamState(
-        token_index=J,
-        time_index=T,
-        score=trellis[T, J],
-        path=[Point(J, T, emission[T, blank_id].exp().item())]
-    )
-
-    beams = [init_state]
-
-    while beams and beams[0].token_index > 0:
-        next_beams = []
-
-        for beam in beams:
-            t, j = beam.time_index, beam.token_index
-
-            if t <= 0:
-                continue
-
-            p_stay = emission[t - 1, blank_id]
-            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]
-
-            stay_score = trellis[t - 1, j]
-            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')
-
-            # Stay
-            if not math.isinf(stay_score):
-                new_path = beam.path.copy()
-                new_path.append(Point(j, t - 1, p_stay.exp().item()))
-                next_beams.append(BeamState(
-                    token_index=j,
-                    time_index=t - 1,
-                    score=stay_score,
-                    path=new_path
-                ))
-
-            # Change
-            if j > 0 and not math.isinf(change_score):
-                new_path = beam.path.copy()
-                new_path.append(Point(j - 1, t - 1, p_change.exp().item()))
-                next_beams.append(BeamState(
-                    token_index=j - 1,
-                    time_index=t - 1,
-                    score=change_score,
-                    path=new_path
-                ))
-
-        # sort by score
-        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]
-
-        if not beams:
-            break
-
-    if not beams:
-        return None
-
-    best_beam = beams[0]
-    t = best_beam.time_index
-    j = best_beam.token_index
-    while t > 0:
-        prob = emission[t - 1, blank_id].exp().item()
-        best_beam.path.append(Point(j, t - 1, prob))
-        t -= 1
-
-    return best_beam.path[::-1]
+    # Note:
+    # j and t are indices for trellis, which has extra dimensions
+    # for time and tokens at the beginning.
+    # When referring to time frame index `T` in trellis,
+    # the corresponding index in emission is `T-1`.
+    # Similarly, when referring to token index `J` in trellis,
+    # the corresponding index in transcript is `J-1`.
+    j = trellis.size(1) - 1
+    t_start = torch.argmax(trellis[:, j]).item()
+
+    path = []
+    for t in range(t_start, 0, -1):
+        # 1. Figure out if the current position was stay or change
+        # Note (again):
+        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
+        # Score for token staying the same from time frame J-1 to T.
+        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
+        # Score for token changing from C-1 at T-1 to J at T.
+        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+        # 2. Store the path with frame-wise probability.
+        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
+        # Return token index and time index in non-trellis coordinate.
+        path.append(Point(j - 1, t - 1, prob))
+
+        # 3. Update the token
+        if changed > stayed:
+            j -= 1
+            if j == 0:
+                break
+    else:
+        # failed
+        return None
+    return path[::-1]

 # Merge the labels
 @dataclass
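For context on the removed `get_wildcard_emission` helper: it treats a token index of `-1` as a wildcard that may match any non-blank class, which is what allowed out-of-dictionary characters to be mapped to a `*` placeholder during alignment. Below is a self-contained toy sketch of that idea (not the library function itself); the tensor values are made up for illustration.

```python
import torch

def wildcard_scores(frame_emission: torch.Tensor, tokens, blank_id: int = 0):
    """Toy restatement of the removed helper: a token index of -1 acts as a
    wildcard and receives the best non-blank score in the frame, while real
    token indices receive their own class score."""
    tokens = torch.as_tensor(tokens)
    wildcard = tokens == -1
    regular = frame_emission[tokens.clamp(min=0)]      # clamp avoids indexing with -1
    best_non_blank = frame_emission.clone()
    best_non_blank[blank_id] = float("-inf")           # wildcard may not match blank
    return torch.where(wildcard, best_non_blank.max(), regular)

frame = torch.log_softmax(torch.tensor([0.1, 2.0, 0.5, 1.5]), dim=0)
print(wildcard_scores(frame, [1, -1, 3]))  # middle slot gets the max over classes 1..3
```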

whisperx/asr.py

@@ -1,5 +1,6 @@
 import os
-from typing import List, Optional, Union
+import warnings
+from typing import List, NamedTuple, Optional, Union
 from dataclasses import replace

 import ctranslate2

@@ -11,9 +12,9 @@ from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_stor
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator

-from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from whisperx.types import SingleSegment, TranscriptionResult
-from whisperx.vads import Vad, Silero, Pyannote
+from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from .types import SingleSegment, TranscriptionResult
+from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks

 def find_numeral_symbol_tokens(tokenizer):

@@ -51,7 +52,6 @@ class WhisperModel(faster_whisper.WhisperModel):
             previous_tokens,
             without_timestamps=options.without_timestamps,
             prefix=options.prefix,
-            hotwords=options.hotwords
         )

         encoder_output = self.encode(features)

@@ -106,7 +106,7 @@ class FasterWhisperPipeline(Pipeline):
     def __init__(
         self,
         model: WhisperModel,
-        vad,
+        vad: VoiceActivitySegmentation,
         vad_params: dict,
         options: TranscriptionOptions,
         tokenizer: Optional[Tokenizer] = None,

@@ -208,16 +208,7 @@ class FasterWhisperPipeline(Pipeline):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}

-        # Pre-process audio and merge chunks as defined by the respective VAD child class
-        # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
-        if issubclass(type(self.vad_model), Vad):
-            waveform = self.vad_model.preprocess_audio(audio)
-            merge_chunks = self.vad_model.merge_chunks
-        else:
-            waveform = Pyannote.preprocess_audio(audio)
-            merge_chunks = Pyannote.merge_chunks
-
-        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,

@@ -305,8 +296,7 @@ def load_model(
     compute_type="float16",
     asr_options: Optional[dict] = None,
     language: Optional[str] = None,
-    vad_model: Optional[Vad]= None,
-    vad_method: Optional[str] = "pyannote",
+    vad_model: Optional[VoiceActivitySegmentation] = None,
     vad_options: Optional[dict] = None,
     model: Optional[WhisperModel] = None,
     task="transcribe",

@@ -319,7 +309,6 @@ def load_model(
         whisper_arch - The name of the Whisper model to load.
         device - The device to load the model on.
         compute_type - The compute type to use for the model.
-        vad_method - The vad method to use. vad_model has higher priority if is not None.
         options - A dictionary of options to use for the model.
         language - The language of the model. (use English for now)
         model - The WhisperModel instance to use.

@@ -385,7 +374,6 @@ def load_model(
     default_asr_options = TranscriptionOptions(**default_asr_options)

     default_vad_options = {
-        "chunk_size": 30, # needed by silero since binarization happens before merge_chunks
         "vad_onset": 0.500,
         "vad_offset": 0.363
     }

@@ -393,17 +381,10 @@ def load_model(
     if vad_options is not None:
         default_vad_options.update(vad_options)

-    # Note: manually assigned vad_model has higher priority than vad_method!
     if vad_model is not None:
-        print("Use manually assigned vad_model. vad_method is ignored.")
         vad_model = vad_model
     else:
-        if vad_method == "silero":
-            vad_model = Silero(**default_vad_options)
-        elif vad_method == "pyannote":
-            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
-        else:
-            raise ValueError(f"Invalid vad_method: {vad_method}")
+        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)

     return FasterWhisperPipeline(
         model=model,
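The last hunk above removes the `vad_method` switch from `load_model`, leaving only the pyannote path. A hedged sketch of how callers select the VAD backend on each side of this change follows; the two calls target different versions of the package, so only one of them is valid against any given install, and the model name, device, and compute type are illustrative.

```python
import whisperx

# Pre-revert API (left-hand side of the diff): the backend is chosen by name,
# and "chunk_size" is part of the VAD options because silero binarizes before
# merge_chunks.
model = whisperx.load_model(
    "small", device="cpu", compute_type="int8",
    vad_method="silero",  # or "pyannote" (the default)
    vad_options={"chunk_size": 30, "vad_onset": 0.5, "vad_offset": 0.363},
)

# Post-revert API (right-hand side): pyannote only; a ready-made VAD object
# can still be injected via the vad_model argument instead.
model = whisperx.load_model(
    "small", device="cpu", compute_type="int8",
    vad_options={"vad_onset": 0.5, "vad_offset": 0.363},
)
```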

whisperx/audio.py

@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from whisperx.utils import exact_div
+from .utils import exact_div

 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000

whisperx/diarize.py

@@ -4,8 +4,8 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch

-from whisperx.audio import load_audio, SAMPLE_RATE
-from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
+from .audio import load_audio, SAMPLE_RATE
+from .types import TranscriptionResult, AlignedTranscriptionResult

 class DiarizationPipeline:

@@ -79,7 +79,7 @@ def assign_word_speakers(
 class Segment:
-    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
+    def __init__(self, start, end, speaker=None):
         self.start = start
         self.end = end
         self.speaker = speaker

whisperx/transcribe.py

@@ -1,20 +1,17 @@
 import argparse
 import gc
 import os
-import sys
 import warnings
-import importlib.metadata
-import platform

 import numpy as np
 import torch

-from whisperx.alignment import align, load_align_model
-from whisperx.asr import load_model
-from whisperx.audio import load_audio
-from whisperx.diarize import DiarizationPipeline, assign_word_speakers
-from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
-from whisperx.utils import (
+from .alignment import align, load_align_model
+from .asr import load_model
+from .audio import load_audio
+from .diarize import DiarizationPipeline, assign_word_speakers
+from .types import AlignedTranscriptionResult, TranscriptionResult
+from .utils import (
     LANGUAGES,
     TO_LANGUAGE_CODE,
     get_writer,

@@ -29,7 +26,6 @@ def cli():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
-    parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")

@@ -50,7 +46,6 @@ def cli():
     parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")

     # vad params
-    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
     parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")

@@ -88,15 +83,12 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
-    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
-    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
     # fmt: on

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     batch_size: int = args.pop("batch_size")
     model_dir: str = args.pop("model_dir")
-    model_cache_only: bool = args.pop("model_cache_only")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")

@@ -118,7 +110,6 @@ def cli():
     return_char_alignments: bool = args.pop("return_char_alignments")
     hf_token: str = args.pop("hf_token")
-    vad_method: str = args.pop("vad_method")
     vad_onset: float = args.pop("vad_onset")
     vad_offset: float = args.pop("vad_offset")
@ -143,9 +134,7 @@ def cli():
f"{model_name} is an English-only model but received '{args['language']}'; using English instead." f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
) )
args["language"] = "en" args["language"] = "en"
align_language = ( align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
args["language"] if args["language"] is not None else "en"
) # default to loading english if not specified
temperature = args.pop("temperature") temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None: if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@ -186,24 +175,7 @@ def cli():
results = [] results = []
tmp_results = [] tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir) # model = load_model(model_name, device=device, download_root=model_dir)
model = load_model( model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
model_name,
device=device,
device_index=device_index,
download_root=model_dir,
compute_type=compute_type,
language=args["language"],
asr_options=asr_options,
vad_method=vad_method,
vad_options={
"chunk_size": chunk_size,
"vad_onset": vad_onset,
"vad_offset": vad_offset,
},
task=task,
local_files_only=model_cache_only,
threads=faster_whisper_threads,
)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
audio = load_audio(audio_path) audio = load_audio(audio_path)
@ -227,9 +199,7 @@ def cli():
if not no_align: if not no_align:
tmp_results = results tmp_results = results
results = [] results = []
align_model, align_metadata = load_align_model( align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
align_language, device, model_name=align_model
)
for result, audio_path in tmp_results: for result, audio_path in tmp_results:
# >> Align # >> Align
if len(tmp_results) > 1: if len(tmp_results) > 1:
@ -241,12 +211,8 @@ def cli():
if align_model is not None and len(result["segments"]) > 0: if align_model is not None and len(result["segments"]) > 0:
if result.get("language", "en") != align_metadata["language"]: if result.get("language", "en") != align_metadata["language"]:
# load new language # load new language
print( print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..." align_model, align_metadata = load_align_model(result["language"], device)
)
align_model, align_metadata = load_align_model(
result["language"], device
)
print(">>Performing alignment...") print(">>Performing alignment...")
result: AlignedTranscriptionResult = align( result: AlignedTranscriptionResult = align(
result["segments"], result["segments"],
@ -269,17 +235,13 @@ def cli():
# >> Diarize # >> Diarize
if diarize: if diarize:
if hf_token is None: if hf_token is None:
print( print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
)
tmp_results = results tmp_results = results
print(">>Performing diarization...") print(">>Performing diarization...")
results = [] results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results: for result, input_audio_path in tmp_results:
diarize_segments = diarize_model( diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
)
result = assign_word_speakers(diarize_segments, result) result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path)) results.append((result, input_audio_path))
# >> Write # >> Write
@ -287,6 +249,5 @@ def cli():
result["language"] = align_language result["language"] = align_language
writer(result, audio_path, writer_args) writer(result, audio_path, writer_args)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

whisperx/types.py

@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional, List, Tuple
+from typing import TypedDict, Optional, List

 class SingleWordSegment(TypedDict):

@@ -30,17 +30,6 @@ class SingleSegment(TypedDict):
     text: str

-class SegmentData(TypedDict):
-    """
-    Temporary processing data used during alignment.
-    Contains cleaned and preprocessed data for each segment.
-    """
-    clean_char: List[str]  # Cleaned characters that exist in model dictionary
-    clean_cdx: List[int]  # Original indices of cleaned characters
-    clean_wdx: List[int]  # Indices of words containing valid characters
-    sentence_spans: List[Tuple[int, int]]  # Start and end indices of sentences

 class SingleAlignedSegment(TypedDict):
     """
     A single segment (up to multiple sentences) of a speech with word alignment.

whisperx/utils.py

@@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
         line_count = 1
         # the next subtitle to yield (a list of word timings with whitespace)
         subtitle: list[dict] = []
-        times: list[tuple] = []
+        times = []
         last = result["segments"][0]["start"]
         for segment in result["segments"]:
             for i, original_timing in enumerate(segment["words"]):

whisperx/vads/pyannote.py → whisperx/vad.py

@@ -1,29 +1,32 @@
+import hashlib
 import os
-from typing import Callable, Text, Union
-from typing import Optional
+import urllib
+from typing import Callable, Optional, Text, Union

 import numpy as np
+import pandas as pd
 import torch
 from pyannote.audio import Model
 from pyannote.audio.core.io import AudioFile
 from pyannote.audio.pipelines import VoiceActivityDetection
 from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.core import Annotation, SlidingWindowFeature
-from pyannote.core import Segment
+from pyannote.core import Annotation, Segment, SlidingWindowFeature
+from tqdm import tqdm

-from whisperx.diarize import Segment as SegmentX
-from whisperx.vads.vad import Vad
+from .diarize import Segment as SegmentX

-# deprecated
 VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
     model_dir = torch.hub._get_torch_home()
-    main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    vad_dir = os.path.dirname(os.path.abspath(__file__))
     os.makedirs(model_dir, exist_ok = True)
     if model_fp is None:
         # Dynamically resolve the path to the model file
-        model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
+        model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
         model_fp = os.path.abspath(model_fp)  # Ensure the path is absolute
     else:
         model_fp = os.path.abspath(model_fp)  # Ensure any provided path is absolute

@@ -36,6 +39,10 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non
         raise RuntimeError(f"{model_fp} exists and is not a regular file")

     model_bytes = open(model_fp, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )

     vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
     hyperparameters = {"onset": vad_onset,

@@ -144,8 +151,8 @@ class Binarize:
                         region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                         active[region, k] = label
                         start = curr_timestamps[min_score_div_idx]
-                        curr_scores = curr_scores[min_score_div_idx + 1:]
-                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
+                        curr_scores = curr_scores[min_score_div_idx+1:]
+                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
                 # switching from active to inactive
                 elif y < self.offset:
                     region = Segment(start - self.pad_onset, t + self.pad_offset)

@@ -229,26 +236,41 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
         return segmentations

-class Pyannote(Vad):
-
-    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
-        print(">>Performing voice activity detection using Pyannote...")
-        super().__init__(kwargs['vad_onset'])
-        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
-
-    def __call__(self, audio: AudioFile, **kwargs):
-        return self.vad_pipeline(audio)
-
-    @staticmethod
-    def preprocess_audio(audio):
-        return torch.from_numpy(audio).unsqueeze(0)
-
-    @staticmethod
-    def merge_chunks(segments,
-                     chunk_size,
-                     onset: float = 0.5,
-                     offset: Optional[float] = None,
-                     ):
+def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
+    active = Annotation()
+    for k, vad_t in enumerate(vad_arr):
+        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
+        active[region, k] = 1
+
+    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
+        active = active.support(collar=min_duration_off)
+
+    # remove tracks shorter than min_duration_on
+    if min_duration_on > 0:
+        for segment, track in list(active.itertracks()):
+            if segment.duration < min_duration_on:
+                del active[segment, track]
+
+    active = active.for_json()
+    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
+    return active_segs
+
+def merge_chunks(
+    segments,
+    chunk_size,
+    onset: float = 0.5,
+    offset: Optional[float] = None,
+):
+    """
+    Merge operation described in paper
+    """
+    curr_end = 0
+    merged_segments = []
+    seg_idxs = []
+    speaker_idxs = []
+
     assert chunk_size > 0
     binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
     segments = binarize(segments)

@@ -259,5 +281,27 @@ class Pyannote(Vad):
     if len(segments_list) == 0:
         print("No active speech found in audio")
         return []
-    assert segments_list, "segments_list is empty."
-    return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
+    # assert segments_list, "segments_list is empty."
+    # Make sur the starting point is the start of the segment.
+    curr_start = segments_list[0].start
+
+    for seg in segments_list:
+        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
+            merged_segments.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": seg_idxs,
+            })
+            curr_start = seg.start
+            seg_idxs = []
+            speaker_idxs = []
+        curr_end = seg.end
+        seg_idxs.append((seg.start, seg.end))
+        speaker_idxs.append(seg.speaker)
+    # add final
+    merged_segments.append({
+        "start": curr_start,
+        "end": curr_end,
+        "segments": seg_idxs,
+    })
+    return merged_segments
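`merge_chunks` above packs binarized speech turns into windows of at most `chunk_size` seconds, which later become the batched transcription inputs. A toy re-implementation of just that merge rule, using a stand-in dataclass instead of the pyannote segment objects the real function receives:

```python
from dataclasses import dataclass

@dataclass
class Turn:  # stand-in for the SegmentX objects used above
    start: float
    end: float

def merge_turns(turns, chunk_size):
    """Toy version of the merge rule above: pack consecutive speech turns
    into windows whose total span stays within `chunk_size` seconds."""
    merged, seg_idxs = [], []
    curr_start, curr_end = turns[0].start, 0.0
    for seg in turns:
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
            curr_start, seg_idxs = seg.start, []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
    merged.append({"start": curr_start, "end": curr_end, "segments": seg_idxs})
    return merged

# Four speech turns, 30 s window: the first three are merged, the last starts a new window.
print(merge_turns([Turn(0, 4), Turn(5, 12), Turn(14, 29), Turn(31, 40)], chunk_size=30))
```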

whisperx/vads/__init__.py (deleted)

@@ -1,3 +0,0 @@
from whisperx.vads.pyannote import Pyannote as Pyannote
from whisperx.vads.silero import Silero as Silero
from whisperx.vads.vad import Vad as Vad

whisperx/vads/silero.py (deleted)

@@ -1,66 +0,0 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union
import torch
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
AudioFile = Union[Text, Path, IOBase, Mapping]
class Silero(Vad):
# check again default values
def __init__(self, **kwargs):
print(">>Performing voice activity detection using Silero...")
super().__init__(kwargs['vad_onset'])
self.vad_onset = kwargs['vad_onset']
self.chunk_size = kwargs['chunk_size']
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
trust_repo=True)
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
def __call__(self, audio: AudioFile, **kwargs):
"""use silero to get segments of speech"""
# Only accept 16000 Hz for now.
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
sample_rate = audio["sample_rate"]
if sample_rate != 16000:
raise ValueError("Only 16000Hz sample rate is allowed")
timestamps = self.get_speech_timestamps(audio["waveform"],
model=self.vad_pipeline,
sampling_rate=sample_rate,
max_speech_duration_s=self.chunk_size,
threshold=self.vad_onset
# min_silence_duration_ms = self.min_duration_off/1000
# min_speech_duration_ms = self.min_duration_on/1000
# ...
# See silero documentation for full option list
)
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
@staticmethod
def preprocess_audio(audio):
return audio
@staticmethod
def merge_chunks(segments_list,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
if len(segments_list) == 0:
print("No active speech found in audio")
return []
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
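The deleted `Silero` wrapper above is a thin layer over the `snakers4/silero-vad` torch.hub model. A hedged sketch of the underlying calls it wraps follows; the hub repo and util names are taken from the code above, the first run downloads the model over the network, and the zero waveform is a stand-in for real 16 kHz audio.

```python
import torch

# Load the silero-vad JIT model plus its helper utilities, as the wrapper did.
model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad", model="silero_vad",
    onnx=False, trust_repo=True,
)
get_speech_timestamps = utils[0]  # first element of the returned util tuple

waveform = torch.zeros(16000 * 5)  # 5 s of silence as placeholder input
timestamps = get_speech_timestamps(
    waveform, model=model, sampling_rate=16000,
    threshold=0.5, max_speech_duration_s=30,
)
# Convert sample offsets to seconds, mirroring the SegmentX conversion above.
print([(t["start"] / 16000, t["end"] / 16000) for t in timestamps])
```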

whisperx/vads/vad.py (deleted)

@@ -1,74 +0,0 @@
from typing import Optional
import pandas as pd
from pyannote.core import Annotation, Segment
class Vad:
def __init__(self, vad_onset):
if not (0 < vad_onset < 1):
raise ValueError(
"vad_onset is a decimal value between 0 and 1."
)
@staticmethod
def preprocess_audio(audio):
pass
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float,
offset: Optional[float]):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs: list[tuple]= []
speaker_idxs: list[Optional[str]] = []
curr_start = segments[0].start
for seg in segments:
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
# Unused function
@staticmethod
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs