15 Commits

Author SHA1 Message Date
f5b40b5366 chore: update version to 3.3.3 in pyproject.toml and uv.lock 2025-05-01 11:08:54 +02:00
ac0c8bd79a feat: add version and Python version arguments to CLI 2025-05-01 11:08:54 +02:00
cd59f21d1a fix: downgrade ctranslate2 dependency version 2025-05-01 11:08:54 +02:00
0aed874589 Remove duplicated item "lv": "latvian" 2025-04-12 11:08:15 +02:00
f10dbf6ab1 fix: update setuptools configuration to include package discovery for whisperx 2025-03-25 18:49:44 +01:00
a7564c2ad6 docs: update installation instructions 2025-03-25 17:02:41 +01:00
e7712f496e refactor: update import statements to use explicit module paths across multiple files 2025-03-25 16:24:21 +01:00
8e53866704 feat: pass hotwords argument to get_prompt (#1073) (Co-authored-by: Jade Moillic <jade.moillic@radiofrance.com>) 2025-03-24 10:47:47 +01:00
3205436d58 Merge pull request #1002 from Barabazs/feat/uv 2025-03-23 12:59:46 +00:00
d2f0e53f71 chore: remove tmp workflow 2025-02-12 08:23:23 +01:00
7489ebf876 feat: update build and release workflow to use uv for package installation and publishing 2025-02-12 08:23:23 +01:00
90256cc481 feat: use uv recommended setup 2025-02-12 08:23:23 +01:00
b41ebd4871 chore: add numpy to deps 2025-02-12 08:23:23 +01:00
63bc1903c1 feat: update Python compatibility workflow to use uv 2025-02-12 08:23:23 +01:00
272714e07d feat: use uv for building package 2025-02-12 08:23:23 +01:00
19 changed files with 3137 additions and 212 deletions

View File

@@ -11,25 +11,21 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
         with:
+          version: "0.5.14"
           python-version: "3.9"
-      - name: Install dependencies
-        run: |
-          python -m pip install build
-      - name: Build wheels
-        run: python -m build --wheel
+      - name: Build package
+        run: uv build
       - name: Release to Github
         uses: softprops/action-gh-release@v2
         with:
-          files: dist/*
+          files: dist/*.whl
       - name: Publish package to PyPi
-        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_API_TOKEN }}
+        run: uv publish
+        env:
+          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -17,16 +17,15 @@ jobs:
     steps:
       - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
         with:
+          version: "0.5.14"
           python-version: ${{ matrix.python-version }}
-      - name: Install package
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
+      - name: Install the project
+        run: uv sync --all-extras
       - name: Test import
         run: |
-          python -c "import whisperx; print('Successfully imported whisperx')"
+          uv run python -c "import whisperx; print('Successfully imported whisperx')"

View File

@@ -1,35 +0,0 @@
name: Python Compatibility Test (PyPi)

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install package
        run: |
          pip install whisperx
      - name: Print packages
        run: |
          pip list
      - name: Test import
        run: |
          python -c "import whisperx; print('Successfully imported whisperx')"

View File

@@ -62,54 +62,41 @@ This repository provides fast automatic speech recognition (70x realtime with la
 - Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>

-Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
-
-GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-
-### 1. Create Python3.10 environment
-
-`conda create --name whisperx python=3.10`
-
-`conda activate whisperx`
-
-### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
-
-`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
-
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
-
-### 3. Install WhisperX
-
-You have several installation options:
-
-#### Option A: Stable Release (recommended)
-Install the latest stable version from PyPI:
+### 1. Simple Installation (Recommended)
+
+The easiest way to install WhisperX is through PyPi:

 ```bash
 pip install whisperx
 ```

-#### Option B: Development Version
-Install the latest development version directly from GitHub (may be unstable):
+Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):

 ```bash
-pip install git+https://github.com/m-bain/whisperx.git
+uvx whisperx
 ```

-If already installed, update to the most recent commit:
+### 2. Advanced Installation Options
+
+These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
+
+#### Option A: Install from GitHub
+
+To install directly from the GitHub repository:

 ```bash
-pip install git+https://github.com/m-bain/whisperx.git --upgrade
+uvx git+https://github.com/m-bain/whisperX.git
 ```

-#### Option C: Development Mode
-If you wish to modify the package, clone and install in editable mode:
+#### Option B: Developer Installation
+
+If you want to modify the code or contribute to the project:

 ```bash
 git clone https://github.com/m-bain/whisperX.git
 cd whisperX
-pip install -e .
+uv sync --all-extras --dev
 ```

 > **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
@@ -117,12 +104,12 @@ pip install -e .
 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization

 To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

 > **Note**<br>
 > As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.

 <h2 align="left" id="example">Usage 💬 (command line)</h2>

 ### English
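
The diarization note above maps directly onto the Python API. A minimal sketch, assuming whisperx is installed, a Hugging Face token is exported as `HF_TOKEN`, and the pyannote model terms have been accepted; the file name, device and speaker counts are placeholder values (the `DiarizationPipeline` / `assign_word_speakers` names match the `transcribe.py` diff further down):

```python
# Sketch only: diarization from Python, mirroring what the CLI does with --hf_token.
import os

import whisperx

device = "cpu"  # or "cuda"
model = whisperx.load_model("large-v2", device=device, compute_type="int8")
audio = whisperx.load_audio("example.wav")
result = model.transcribe(audio, batch_size=4)

diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ["HF_TOKEN"], device=device)
diarize_segments = diarize_model("example.wav", min_speakers=1, max_speakers=2)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"])
```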

pyproject.toml (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
[project]
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.3.3"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
license = { text = "BSD-2-Clause" }
dependencies = [
    "ctranslate2<4.5.0",
    "faster-whisper>=1.1.1",
    "nltk>=3.9.1",
    "numpy>=2.0.2",
    "onnxruntime>=1.19",
    "pandas>=2.2.3",
    "pyannote-audio>=3.3.2",
    "torch>=2.5.1",
    "torchaudio>=2.5.1",
    "transformers>=4.48.0",
]

[project.scripts]
whisperx = "whisperx.transcribe:cli"

[build-system]
requires = ["setuptools"]

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
where = ["."]
include = ["whisperx*"]
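
The `[project.scripts]` table takes over from the `entry_points` block in the deleted setup.py further down. Roughly, the `whisperx` command that installers generate is equivalent to this sketch (the wrapper itself is produced by pip/uv, so the code is purely illustrative):

```python
# Illustrative equivalent of the console script generated from
# whisperx = "whisperx.transcribe:cli" in [project.scripts].
import sys

from whisperx.transcribe import cli

if __name__ == "__main__":
    sys.exit(cli())
```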

View File

@@ -1,8 +0,0 @@
torch>=2
torchaudio>=2
faster-whisper==1.1.0
ctranslate2<4.5.0
transformers
pandas
setuptools>=65
nltk

View File

@@ -1,33 +0,0 @@
import os

import pkg_resources
from setuptools import find_packages, setup

with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="whisperx",
    py_modules=["whisperx"],
    version="3.3.1",
    description="Time-Accurate Automatic Speech Recognition using Whisper.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">=3.9, <3.13",
    author="Max Bain",
    url="https://github.com/m-bain/whisperx",
    license="BSD-2-Clause",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ]
    + [f"pyannote.audio==3.3.2"],
    entry_points={
        "console_scripts": ["whisperx=whisperx.transcribe:cli"],
    },
    include_package_data=True,
    extras_require={"dev": ["pytest"]},
)

uv.lock (generated file, 2905 lines)

File diff suppressed because it is too large.

View File

@@ -1,6 +1,5 @@
 import math
-from .conjunctions import get_conjunctions, get_comma
-from typing import TextIO
+from whisperx.conjunctions import get_conjunctions, get_comma

 def normal_round(n):
     if n - math.floor(n) < 0.5:

View File

@@ -1,4 +1,7 @@
-from .alignment import load_align_model, align
-from .audio import load_audio
-from .diarize import assign_word_speakers, DiarizationPipeline
-from .asr import load_model
+from whisperx.alignment import load_align_model as load_align_model, align as align
+from whisperx.asr import load_model as load_model
+from whisperx.audio import load_audio as load_audio
+from whisperx.diarize import (
+    assign_word_speakers as assign_word_speakers,
+    DiarizationPipeline as DiarizationPipeline,
+)
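
The `name as name` spellings are the explicit re-export idiom from PEP 484: the aliases change nothing at runtime but tell type checkers these symbols are intentionally part of the package's public surface. A quick check, assuming whisperx is installed (the printed names are the ones re-exported above):

```python
# Runtime behaviour is unchanged by the "as" aliases; they only matter to type checkers.
import whisperx

print(whisperx.load_model, whisperx.load_audio, whisperx.load_align_model)
print(whisperx.align, whisperx.assign_word_speakers, whisperx.DiarizationPipeline)
```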

View File

@@ -1,4 +1,4 @@
-from .transcribe import cli
+from whisperx.transcribe import cli

 cli()

View File

@@ -13,9 +13,9 @@ import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

-from .audio import SAMPLE_RATE, load_audio
-from .utils import interpolate_nans
-from .types import (
+from whisperx.audio import SAMPLE_RATE, load_audio
+from whisperx.utils import interpolate_nans
+from whisperx.types import (
     AlignedTranscriptionResult,
     SingleSegment,
     SingleAlignedSegment,

View File

@@ -11,14 +11,12 @@ from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_stor
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator

-from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .types import SingleSegment, TranscriptionResult
-from .vads import Vad, Silero, Pyannote
+from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from whisperx.types import SingleSegment, TranscriptionResult
+from whisperx.vads import Vad, Silero, Pyannote

 def find_numeral_symbol_tokens(tokenizer):
-    """
-    Finds tokens that represent numeral and symbols.
-    """
     numeral_symbol_tokens = []
     for i in range(tokenizer.eot):
         token = tokenizer.decode([i]).removeprefix(" ")
@@ -28,10 +26,10 @@ def find_numeral_symbol_tokens(tokenizer):
     return numeral_symbol_tokens

 class WhisperModel(faster_whisper.WhisperModel):
-    """
-    Wrapper around faster-whisper's WhisperModel to enable batched inference.
-    Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
-    """
+    '''
+    FasterWhisperModel provides batched inference for faster-whisper.
+    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
+    '''

     def generate_segment_batched(
         self,
@@ -40,45 +38,28 @@ class WhisperModel(faster_whisper.WhisperModel):
         options: TranscriptionOptions,
         encoder_output=None,
     ):
-        """
-        Generates transcription for a batch of audio segments.
-
-        Args:
-            features: The input audio features.
-            tokenizer: The tokenizer used to decode the generated tokens.
-            options: Transcription options.
-            encoder_output: Output from the encoder model.
-
-        Returns:
-            The decoded transcription text.
-        """
         batch_size = features.shape[0]
-        # Initialize tokens and prompt for the generation process.
         all_tokens = []
         prompt_reset_since = 0
-        # Check if an initial prompt is provided and handle it.
         if options.initial_prompt is not None:
             initial_prompt = " " + options.initial_prompt.strip()
             initial_prompt_tokens = tokenizer.encode(initial_prompt)
             all_tokens.extend(initial_prompt_tokens)
-        # Prepare the prompt for the current batch.
         previous_tokens = all_tokens[prompt_reset_since:]
         prompt = self.get_prompt(
             tokenizer,
             previous_tokens,
             without_timestamps=options.without_timestamps,
             prefix=options.prefix,
+            hotwords=options.hotwords
         )
-        # Encode the features to obtain the encoder output.
         encoder_output = self.encode(features)
-        # Determine the maximum initial timestamp index based on the options.
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
         )
-        # Generate the transcription result for the batch.
         result = self.model.generate(
             encoder_output,
             [prompt] * batch_size,
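
The new `hotwords=options.hotwords` argument only does something if a hotwords string actually reaches `TranscriptionOptions`. A hedged sketch of supplying it through `asr_options` at load time; the `"hotwords"` key follows faster-whisper's `TranscriptionOptions` field and the other values are placeholders, so treat this as an assumption rather than the canonical API:

```python
# Sketch: passing hotwords so they flow into get_prompt() via TranscriptionOptions.
import whisperx

model = whisperx.load_model(
    "large-v2",
    device="cpu",
    compute_type="int8",
    asr_options={"hotwords": "WhisperX CTranslate2 pyannote"},  # placeholder hotwords
)
```
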
@@ -90,37 +71,100 @@ class WhisperModel(faster_whisper.WhisperModel):
             suppress_tokens=options.suppress_tokens,
         )

-        # Extract the token sequences from the result.
         tokens_batch = [x.sequences_ids[0] for x in result]

-        # Define an inner function to decode the tokens for each batch.
         def decode_batch(tokens: List[List[int]]) -> str:
             res = []
             for tk in tokens:
                 res.append([token for token in tk if token < tokenizer.eot])
+            # text_tokens = [token for token in tokens if token < self.eot]
             return tokenizer.tokenizer.decode_batch(res)

-        # Decode the tokens to get the transcription text.
         text = decode_batch(tokens_batch)

         return text

     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
-        """
-        Encodes the audio features using the CTranslate2 storage.
-
-        When the model is running on multiple GPUs, the encoder output should be moved
-        to the CPU since we don't know which GPU will handle the next job.
-        """
-        # When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        # If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
+        # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = get_ctranslate2_storage(features)
+        # call the model

         return self.model.encode(features, to_cpu=to_cpu)

+class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+    """
+    # TODO:
+    # - add support for timestamp mode
+    # - add support for custom inference kwargs
+
+    def __init__(
+        self,
+        model: WhisperModel,
+        vad,
+        vad_params: dict,
+        options: TranscriptionOptions,
+        tokenizer: Optional[Tokenizer] = None,
+        device: Union[int, str, "torch.device"] = -1,
+        framework="pt",
+        language: Optional[str] = None,
+        suppress_numerals: bool = False,
+        **kwargs,
+    ):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.options = options
+        self.preset_language = language
+        self.suppress_numerals = suppress_numerals
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = 1
+        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
+        self.call_count = 0
+        self.framework = framework
+        if self.framework == "pt":
+            if isinstance(device, torch.device):
+                self.device = device
+            elif isinstance(device, str):
+                self.device = torch.device(device)
+            elif device < 0:
+                self.device = torch.device("cpu")
+            else:
+                self.device = torch.device(f"cuda:{device}")
+        else:
+            self.device = device
+
+        super(Pipeline, self).__init__()
+        self.vad_model = vad
+        self._vad_params = vad_params
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "tokenizer" in kwargs:
+            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, audio):
+        audio = audio['inputs']
+        model_n_mels = self.model.feat_kwargs.get("feature_size")
+        features = log_mel_spectrogram(
+            audio,
+            n_mels=model_n_mels if model_n_mels is not None else 80,
+            padding=N_SAMPLES - audio.shape[0],
+        )
+        return {'inputs': features}
+
+    def _forward(self, model_inputs):
+        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
+        return {'text': outputs}
+
+    def postprocess(self, model_outputs):
+        return model_outputs
+
     def get_iterator(
         self,
         inputs,
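
For readers unfamiliar with Hugging Face `Pipeline` subclasses: calling the pipeline chains `preprocess` → `_forward` → `postprocess`, which is why the new class only implements those hooks plus `get_iterator`. A dependency-free sketch of that control flow, with stub functions standing in for the real methods (the dummy shapes and return values are invented; the real hooks compute log-mel features and call `generate_segment_batched`):

```python
# Conceptual sketch of the preprocess -> _forward -> postprocess chain driven by
# transformers' Pipeline.__call__; stubs stand in for the real whisperx methods.
import numpy as np


def preprocess(sample: dict) -> dict:
    # Real pipeline: log_mel_spectrogram(audio, n_mels=..., padding=N_SAMPLES - audio.shape[0])
    return {"inputs": np.zeros((80, 3000), dtype=np.float32)}


def _forward(model_inputs: dict) -> dict:
    # Real pipeline: WhisperModel.generate_segment_batched(features, tokenizer, options)
    return {"text": ["<transcribed segment>"]}


def postprocess(model_outputs: dict) -> dict:
    return model_outputs


sample = {"inputs": np.zeros(16000, dtype=np.float32)}  # one 1-second chunk at 16 kHz
print(postprocess(_forward(preprocess(sample)))["text"])
```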

View File

@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from .utils import exact_div
+from whisperx.utils import exact_div

 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000

View File

@@ -4,8 +4,8 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union

 import torch

-from .audio import load_audio, SAMPLE_RATE
-from .types import TranscriptionResult, AlignedTranscriptionResult
+from whisperx.audio import load_audio, SAMPLE_RATE
+from whisperx.types import TranscriptionResult, AlignedTranscriptionResult

 class DiarizationPipeline:

View File

@@ -1,17 +1,20 @@
 import argparse
 import gc
 import os
+import sys
 import warnings
+import importlib.metadata
+import platform

 import numpy as np
 import torch

-from .alignment import align, load_align_model
-from .asr import load_model
-from .audio import load_audio
-from .diarize import DiarizationPipeline, assign_word_speakers
-from .types import AlignedTranscriptionResult, TranscriptionResult
-from .utils import (
+from whisperx.alignment import align, load_align_model
+from whisperx.asr import load_model
+from whisperx.audio import load_audio
+from whisperx.diarize import DiarizationPipeline, assign_word_speakers
+from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
+from whisperx.utils import (
     LANGUAGES,
     TO_LANGUAGE_CODE,
     get_writer,
@@ -85,6 +88,8 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
     parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
+    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
+    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
     # fmt: on

     args = parser.parse_args().__dict__
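
The two new flags rely on argparse's built-in `version` action, which prints its string and exits before normal parsing continues. A standalone sketch (the program name and version string are invented for the example):

```python
# Minimal demonstration of the argparse "version" action used by the new flags.
import argparse
import platform

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--version", "-V", action="version", version="%(prog)s 0.0.0")
parser.add_argument(
    "--python-version", "-P", action="version",
    version=f"Python {platform.python_version()} ({platform.python_implementation()})",
)

# `demo --version` prints "demo 0.0.0" and exits with status 0;
# parse_args() only returns for invocations without these flags.
args = parser.parse_args([])
```
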
@@ -138,7 +143,9 @@ def cli():
                 f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
             )
         args["language"] = "en"
-    align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
+    align_language = (
+        args["language"] if args["language"] is not None else "en"
+    )  # default to loading english if not specified

     temperature = args.pop("temperature")
     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@@ -179,7 +186,24 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
+    model = load_model(
+        model_name,
+        device=device,
+        device_index=device_index,
+        download_root=model_dir,
+        compute_type=compute_type,
+        language=args["language"],
+        asr_options=asr_options,
+        vad_method=vad_method,
+        vad_options={
+            "chunk_size": chunk_size,
+            "vad_onset": vad_onset,
+            "vad_offset": vad_offset,
+        },
+        task=task,
+        local_files_only=model_cache_only,
+        threads=faster_whisper_threads,
+    )

     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
@@ -203,7 +227,9 @@ def cli():
     if not no_align:
         tmp_results = results
         results = []
-        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
+        align_model, align_metadata = load_align_model(
+            align_language, device, model_name=align_model
+        )
         for result, audio_path in tmp_results:
             # >> Align
             if len(tmp_results) > 1:
@@ -215,8 +241,12 @@ def cli():
             if align_model is not None and len(result["segments"]) > 0:
                 if result.get("language", "en") != align_metadata["language"]:
                     # load new language
-                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
-                    align_model, align_metadata = load_align_model(result["language"], device)
+                    print(
+                        f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
+                    )
+                    align_model, align_metadata = load_align_model(
+                        result["language"], device
+                    )
                 print(">>Performing alignment...")
                 result: AlignedTranscriptionResult = align(
                     result["segments"],
@@ -239,13 +269,17 @@ def cli():
     # >> Diarize
     if diarize:
         if hf_token is None:
-            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+            print(
+                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
+            )
         tmp_results = results
         print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
-            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+            diarize_segments = diarize_model(
+                input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
+            )
             result = assign_word_speakers(diarize_segments, result)
             results.append((result, input_audio_path))
     # >> Write
@@ -253,5 +287,6 @@
         result["language"] = align_language
         writer(result, audio_path, writer_args)

+
 if __name__ == "__main__":
     cli()

View File

@@ -106,7 +106,6 @@ LANGUAGES = {
     "jw": "javanese",
     "su": "sundanese",
     "yue": "cantonese",
-    "lv": "latvian",
 }

 # language code lookup by name, with a few language aliases

View File

@@ -1,3 +1,3 @@
-from whisperx.vads.pyannote import Pyannote
-from whisperx.vads.silero import Silero
-from whisperx.vads.vad import Vad
+from whisperx.vads.pyannote import Pyannote as Pyannote
+from whisperx.vads.silero import Silero as Silero
+from whisperx.vads.vad import Vad as Vad

View File

@@ -1,6 +1,4 @@
-import hashlib
 import os
-import urllib
 from typing import Callable, Text, Union
 from typing import Optional

@@ -12,11 +10,11 @@ from pyannote.audio.pipelines import VoiceActivityDetection
 from pyannote.audio.pipelines.utils import PipelineModel
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.core import Segment
-from tqdm import tqdm

 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad

 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
     model_dir = torch.hub._get_torch_home()
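
A small usage sketch for the `load_vad_model` signature shown above; the device and threshold values are illustrative, and the import path assumes the function stays in `whisperx.vads.pyannote` as in this diff:

```python
# Sketch: loading the Pyannote VAD model with the signature from the hunk above.
import torch

from whisperx.vads.pyannote import load_vad_model

vad_model = load_vad_model(torch.device("cpu"), vad_onset=0.500, vad_offset=0.363)
```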