mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
1 Commits
v3.3.3
...
improve-co
Author | SHA1 | Date | |
---|---|---|---|
88939b9e8a |
22
.github/workflows/build-and-release.yml
vendored
22
.github/workflows/build-and-release.yml
vendored
@ -11,21 +11,25 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
version: "0.5.14"
|
||||
python-version: "3.9"
|
||||
|
||||
- name: Build package
|
||||
run: uv build
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install build
|
||||
|
||||
- name: Build wheels
|
||||
run: python -m build --wheel
|
||||
|
||||
- name: Release to Github
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: dist/*.whl
|
||||
files: dist/*
|
||||
|
||||
- name: Publish package to PyPi
|
||||
run: uv publish
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
|
15
.github/workflows/python-compatibility.yml
vendored
15
.github/workflows/python-compatibility.yml
vendored
@ -5,7 +5,7 @@ on:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
workflow_dispatch: # Allows manual triggering from GitHub UI
|
||||
workflow_dispatch: # Allows manual triggering from GitHub UI
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@ -17,15 +17,16 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
version: "0.5.14"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras
|
||||
- name: Install package
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .
|
||||
|
||||
- name: Test import
|
||||
run: |
|
||||
uv run python -c "import whisperx; print('Successfully imported whisperx')"
|
||||
python -c "import whisperx; print('Successfully imported whisperx')"
|
||||
|
35
.github/workflows/tmp.yml
vendored
Normal file
35
.github/workflows/tmp.yml
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
name: Python Compatibility Test (PyPi)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
workflow_dispatch: # Allows manual triggering from GitHub UI
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install package
|
||||
run: |
|
||||
pip install whisperx
|
||||
|
||||
- name: Print packages
|
||||
run: |
|
||||
pip list
|
||||
|
||||
- name: Test import
|
||||
run: |
|
||||
python -c "import whisperx; print('Successfully imported whisperx')"
|
49
README.md
49
README.md
@ -62,41 +62,54 @@ This repository provides fast automatic speech recognition (70x realtime with la
|
||||
- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
|
||||
|
||||
<h2 align="left" id="setup">Setup ⚙️</h2>
|
||||
Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
|
||||
|
||||
### 1. Simple Installation (Recommended)
|
||||
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
|
||||
|
||||
The easiest way to install WhisperX is through PyPi:
|
||||
|
||||
### 1. Create Python3.10 environment
|
||||
|
||||
`conda create --name whisperx python=3.10`
|
||||
|
||||
`conda activate whisperx`
|
||||
|
||||
|
||||
### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
|
||||
|
||||
`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
|
||||
|
||||
See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
|
||||
|
||||
### 3. Install WhisperX
|
||||
|
||||
You have several installation options:
|
||||
|
||||
#### Option A: Stable Release (recommended)
|
||||
Install the latest stable version from PyPI:
|
||||
|
||||
```bash
|
||||
pip install whisperx
|
||||
```
|
||||
|
||||
Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
|
||||
#### Option B: Development Version
|
||||
Install the latest development version directly from GitHub (may be unstable):
|
||||
|
||||
```bash
|
||||
uvx whisperx
|
||||
pip install git+https://github.com/m-bain/whisperx.git
|
||||
```
|
||||
|
||||
### 2. Advanced Installation Options
|
||||
|
||||
These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
|
||||
|
||||
#### Option A: Install from GitHub
|
||||
|
||||
To install directly from the GitHub repository:
|
||||
If already installed, update to the most recent commit:
|
||||
|
||||
```bash
|
||||
uvx git+https://github.com/m-bain/whisperX.git
|
||||
pip install git+https://github.com/m-bain/whisperx.git --upgrade
|
||||
```
|
||||
|
||||
#### Option B: Developer Installation
|
||||
|
||||
If you want to modify the code or contribute to the project:
|
||||
|
||||
#### Option C: Development Mode
|
||||
If you wish to modify the package, clone and install in editable mode:
|
||||
```bash
|
||||
git clone https://github.com/m-bain/whisperX.git
|
||||
cd whisperX
|
||||
uv sync --all-extras --dev
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
|
||||
@ -104,12 +117,12 @@ uv sync --all-extras --dev
|
||||
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
|
||||
|
||||
### Speaker Diarization
|
||||
|
||||
To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
|
||||
|
||||
> **Note**<br>
|
||||
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
|
||||
|
||||
|
||||
<h2 align="left" id="example">Usage 💬 (command line)</h2>
|
||||
|
||||
### English
|
||||
|
@ -1,36 +0,0 @@
|
||||
[project]
|
||||
urls = { repository = "https://github.com/m-bain/whisperx" }
|
||||
authors = [{ name = "Max Bain" }]
|
||||
name = "whisperx"
|
||||
version = "3.3.3"
|
||||
description = "Time-Accurate Automatic Speech Recognition using Whisper."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9, <3.13"
|
||||
license = { text = "BSD-2-Clause" }
|
||||
|
||||
dependencies = [
|
||||
"ctranslate2<4.5.0",
|
||||
"faster-whisper>=1.1.1",
|
||||
"nltk>=3.9.1",
|
||||
"numpy>=2.0.2",
|
||||
"onnxruntime>=1.19",
|
||||
"pandas>=2.2.3",
|
||||
"pyannote-audio>=3.3.2",
|
||||
"torch>=2.5.1",
|
||||
"torchaudio>=2.5.1",
|
||||
"transformers>=4.48.0",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
whisperx = "whisperx.transcribe:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
|
||||
[tool.setuptools]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["whisperx*"]
|
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
torch>=2
|
||||
torchaudio>=2
|
||||
faster-whisper==1.1.0
|
||||
ctranslate2<4.5.0
|
||||
transformers
|
||||
pandas
|
||||
setuptools>=65
|
||||
nltk
|
33
setup.py
Normal file
33
setup.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name="whisperx",
|
||||
py_modules=["whisperx"],
|
||||
version="3.3.1",
|
||||
description="Time-Accurate Automatic Speech Recognition using Whisper.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
python_requires=">=3.9, <3.13",
|
||||
author="Max Bain",
|
||||
url="https://github.com/m-bain/whisperx",
|
||||
license="BSD-2-Clause",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
)
|
||||
]
|
||||
+ [f"pyannote.audio==3.3.2"],
|
||||
entry_points={
|
||||
"console_scripts": ["whisperx=whisperx.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={"dev": ["pytest"]},
|
||||
)
|
@ -1,5 +1,6 @@
|
||||
import math
|
||||
from whisperx.conjunctions import get_conjunctions, get_comma
|
||||
from .conjunctions import get_conjunctions, get_comma
|
||||
from typing import TextIO
|
||||
|
||||
def normal_round(n):
|
||||
if n - math.floor(n) < 0.5:
|
||||
|
@ -1,7 +1,4 @@
|
||||
from whisperx.alignment import load_align_model as load_align_model, align as align
|
||||
from whisperx.asr import load_model as load_model
|
||||
from whisperx.audio import load_audio as load_audio
|
||||
from whisperx.diarize import (
|
||||
assign_word_speakers as assign_word_speakers,
|
||||
DiarizationPipeline as DiarizationPipeline,
|
||||
)
|
||||
from .alignment import load_align_model, align
|
||||
from .audio import load_audio
|
||||
from .diarize import assign_word_speakers, DiarizationPipeline
|
||||
from .asr import load_model
|
||||
|
@ -1,4 +1,4 @@
|
||||
from whisperx.transcribe import cli
|
||||
from .transcribe import cli
|
||||
|
||||
|
||||
cli()
|
||||
|
@ -13,9 +13,9 @@ import torch
|
||||
import torchaudio
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
||||
|
||||
from whisperx.audio import SAMPLE_RATE, load_audio
|
||||
from whisperx.utils import interpolate_nans
|
||||
from whisperx.types import (
|
||||
from .audio import SAMPLE_RATE, load_audio
|
||||
from .utils import interpolate_nans
|
||||
from .types import (
|
||||
AlignedTranscriptionResult,
|
||||
SingleSegment,
|
||||
SingleAlignedSegment,
|
||||
|
144
whisperx/asr.py
144
whisperx/asr.py
@ -11,12 +11,14 @@ from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_stor
|
||||
from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from whisperx.types import SingleSegment, TranscriptionResult
|
||||
from whisperx.vads import Vad, Silero, Pyannote
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .types import SingleSegment, TranscriptionResult
|
||||
from .vads import Vad, Silero, Pyannote
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
"""
|
||||
Finds tokens that represent numeral and symbols.
|
||||
"""
|
||||
numeral_symbol_tokens = []
|
||||
for i in range(tokenizer.eot):
|
||||
token = tokenizer.decode([i]).removeprefix(" ")
|
||||
@ -26,10 +28,10 @@ def find_numeral_symbol_tokens(tokenizer):
|
||||
return numeral_symbol_tokens
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
"""
|
||||
Wrapper around faster-whisper's WhisperModel to enable batched inference.
|
||||
Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
|
||||
"""
|
||||
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
@ -38,133 +40,87 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
"""
|
||||
Generates transcription for a batch of audio segments.
|
||||
|
||||
Args:
|
||||
features: The input audio features.
|
||||
tokenizer: The tokenizer used to decode the generated tokens.
|
||||
options: Transcription options.
|
||||
encoder_output: Output from the encoder model.
|
||||
|
||||
Returns:
|
||||
The decoded transcription text.
|
||||
"""
|
||||
batch_size = features.shape[0]
|
||||
# Initialize tokens and prompt for the generation process.
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
# Check if an initial prompt is provided and handle it.
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
# Prepare the prompt for the current batch.
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
tokenizer,
|
||||
previous_tokens,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
hotwords=options.hotwords
|
||||
)
|
||||
|
||||
|
||||
# Encode the features to obtain the encoder output.
|
||||
encoder_output = self.encode(features)
|
||||
|
||||
# Determine the maximum initial timestamp index based on the options.
|
||||
max_initial_timestamp_index = int(
|
||||
round(options.max_initial_timestamp / self.time_precision)
|
||||
)
|
||||
|
||||
# Generate the transcription result for the batch.
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
# Extract the token sequences from the result.
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
# Define an inner function to decode the tokens for each batch.
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
res = []
|
||||
for tk in tokens:
|
||||
res.append([token for token in tk if token < tokenizer.eot])
|
||||
# text_tokens = [token for token in tokens if token < self.eot]
|
||||
return tokenizer.tokenizer.decode_batch(res)
|
||||
|
||||
# Decode the tokens to get the transcription text.
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
Encodes the audio features using the CTranslate2 storage.
|
||||
|
||||
When the model is running on multiple GPUs, the encoder output should be moved
|
||||
to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
|
||||
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
||||
# unsqueeze if batch size = 1
|
||||
# If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
# call the model
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework="pt",
|
||||
language: Optional[str] = None,
|
||||
suppress_numerals: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self.preset_language = language
|
||||
self.suppress_numerals = suppress_numerals
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
self.call_count = 0
|
||||
self.framework = framework
|
||||
if self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
elif isinstance(device, str):
|
||||
self.device = torch.device(device)
|
||||
elif device < 0:
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
self._vad_params = vad_params
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "tokenizer" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
features = log_mel_spectrogram(
|
||||
audio,
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=N_SAMPLES - audio.shape[0],
|
||||
)
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
inputs,
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from whisperx.utils import exact_div
|
||||
from .utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
|
@ -4,8 +4,8 @@ from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from whisperx.audio import load_audio, SAMPLE_RATE
|
||||
from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
|
||||
from .audio import load_audio, SAMPLE_RATE
|
||||
from .types import TranscriptionResult, AlignedTranscriptionResult
|
||||
|
||||
|
||||
class DiarizationPipeline:
|
||||
|
@ -1,20 +1,17 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import importlib.metadata
|
||||
import platform
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from whisperx.alignment import align, load_align_model
|
||||
from whisperx.asr import load_model
|
||||
from whisperx.audio import load_audio
|
||||
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
|
||||
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
|
||||
from whisperx.utils import (
|
||||
from .alignment import align, load_align_model
|
||||
from .asr import load_model
|
||||
from .audio import load_audio
|
||||
from .diarize import DiarizationPipeline, assign_word_speakers
|
||||
from .types import AlignedTranscriptionResult, TranscriptionResult
|
||||
from .utils import (
|
||||
LANGUAGES,
|
||||
TO_LANGUAGE_CODE,
|
||||
get_writer,
|
||||
@ -88,8 +85,6 @@ def cli():
|
||||
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
||||
|
||||
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
|
||||
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
|
||||
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
@ -143,9 +138,7 @@ def cli():
|
||||
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
align_language = (
|
||||
args["language"] if args["language"] is not None else "en"
|
||||
) # default to loading english if not specified
|
||||
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
@ -181,29 +174,12 @@ def cli():
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
|
||||
|
||||
# Part 1: VAD & ASR Loop
|
||||
results = []
|
||||
tmp_results = []
|
||||
# model = load_model(model_name, device=device, download_root=model_dir)
|
||||
model = load_model(
|
||||
model_name,
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
download_root=model_dir,
|
||||
compute_type=compute_type,
|
||||
language=args["language"],
|
||||
asr_options=asr_options,
|
||||
vad_method=vad_method,
|
||||
vad_options={
|
||||
"chunk_size": chunk_size,
|
||||
"vad_onset": vad_onset,
|
||||
"vad_offset": vad_offset,
|
||||
},
|
||||
task=task,
|
||||
local_files_only=model_cache_only,
|
||||
threads=faster_whisper_threads,
|
||||
)
|
||||
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
audio = load_audio(audio_path)
|
||||
@ -227,9 +203,7 @@ def cli():
|
||||
if not no_align:
|
||||
tmp_results = results
|
||||
results = []
|
||||
align_model, align_metadata = load_align_model(
|
||||
align_language, device, model_name=align_model
|
||||
)
|
||||
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
||||
for result, audio_path in tmp_results:
|
||||
# >> Align
|
||||
if len(tmp_results) > 1:
|
||||
@ -241,12 +215,8 @@ def cli():
|
||||
if align_model is not None and len(result["segments"]) > 0:
|
||||
if result.get("language", "en") != align_metadata["language"]:
|
||||
# load new language
|
||||
print(
|
||||
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
|
||||
)
|
||||
align_model, align_metadata = load_align_model(
|
||||
result["language"], device
|
||||
)
|
||||
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
||||
align_model, align_metadata = load_align_model(result["language"], device)
|
||||
print(">>Performing alignment...")
|
||||
result: AlignedTranscriptionResult = align(
|
||||
result["segments"],
|
||||
@ -269,17 +239,13 @@ def cli():
|
||||
# >> Diarize
|
||||
if diarize:
|
||||
if hf_token is None:
|
||||
print(
|
||||
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
|
||||
)
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(
|
||||
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
||||
)
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
result = assign_word_speakers(diarize_segments, result)
|
||||
results.append((result, input_audio_path))
|
||||
# >> Write
|
||||
@ -287,6 +253,5 @@ def cli():
|
||||
result["language"] = align_language
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
@ -106,6 +106,7 @@ LANGUAGES = {
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
"lv": "latvian",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
|
@ -1,3 +1,3 @@
|
||||
from whisperx.vads.pyannote import Pyannote as Pyannote
|
||||
from whisperx.vads.silero import Silero as Silero
|
||||
from whisperx.vads.vad import Vad as Vad
|
||||
from whisperx.vads.pyannote import Pyannote
|
||||
from whisperx.vads.silero import Silero
|
||||
from whisperx.vads.vad import Vad
|
@ -1,4 +1,6 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
from typing import Callable, Text, Union
|
||||
from typing import Optional
|
||||
|
||||
@ -10,11 +12,11 @@ from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.core import Annotation, SlidingWindowFeature
|
||||
from pyannote.core import Segment
|
||||
from tqdm import tqdm
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
|
||||
|
Reference in New Issue
Block a user