Mirror of https://github.com/m-bain/whisperX.git
Synced 2025-07-01 18:17:27 -04:00

Compare commits: v3.3.3 ... improve-co (1 commit)
Commit: 88939b9e8a
.github/workflows/build-and-release.yml (vendored; 22 lines changed)

@@ -11,21 +11,25 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Install uv
-        uses: astral-sh/setup-uv@v5
+      - name: Set up Python
+        uses: actions/setup-python@v5
         with:
-          version: "0.5.14"
           python-version: "3.9"

-      - name: Build package
-        run: uv build
+      - name: Install dependencies
+        run: |
+          python -m pip install build
+
+      - name: Build wheels
+        run: python -m build --wheel

       - name: Release to Github
         uses: softprops/action-gh-release@v2
         with:
-          files: dist/*.whl
+          files: dist/*

       - name: Publish package to PyPi
-        run: uv publish
-        env:
-          UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
+        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
.github/workflows/python-compatibility.yml (vendored; 13 lines changed)

@@ -17,15 +17,16 @@ jobs:
     steps:
       - uses: actions/checkout@v4

-      - name: Install uv
-        uses: astral-sh/setup-uv@v5
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
         with:
-          version: "0.5.14"
           python-version: ${{ matrix.python-version }}

-      - name: Install the project
-        run: uv sync --all-extras
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install .

       - name: Test import
         run: |
-          uv run python -c "import whisperx; print('Successfully imported whisperx')"
+          python -c "import whisperx; print('Successfully imported whisperx')"
.github/workflows/tmp.yml (vendored; new file, 35 lines)

@@ -0,0 +1,35 @@
+name: Python Compatibility Test (PyPi)
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+  workflow_dispatch: # Allows manual triggering from GitHub UI
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install package
+        run: |
+          pip install whisperx
+
+      - name: Print packages
+        run: |
+          pip list
+
+      - name: Test import
+        run: |
+          python -c "import whisperx; print('Successfully imported whisperx')"
README.md (49 lines changed)

@@ -62,41 +62,54 @@ This repository provides fast automatic speech recognition (70x realtime with la
 - Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.

 <h2 align="left" id="setup">Setup ⚙️</h2>
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

-### 1. Simple Installation (Recommended)
+GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).

-The easiest way to install WhisperX is through PyPi:
+### 1. Create Python3.10 environment
+
+`conda create --name whisperx python=3.10`
+
+`conda activate whisperx`
+
+
+### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
+
+`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
+
+See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
+
+### 3. Install WhisperX
+
+You have several installation options:
+
+#### Option A: Stable Release (recommended)
+Install the latest stable version from PyPI:

 ```bash
 pip install whisperx
 ```

-Or if using [uvx](https://docs.astral.sh/uv/guides/tools/#running-tools):
+#### Option B: Development Version
+Install the latest development version directly from GitHub (may be unstable):

 ```bash
-uvx whisperx
+pip install git+https://github.com/m-bain/whisperx.git
 ```

-### 2. Advanced Installation Options
+If already installed, update to the most recent commit:

-These installation methods are for developers or users with specific needs. If you're not sure, stick with the simple installation above.
-
-#### Option A: Install from GitHub
-
-To install directly from the GitHub repository:
-
 ```bash
-uvx git+https://github.com/m-bain/whisperX.git
+pip install git+https://github.com/m-bain/whisperx.git --upgrade
 ```

-#### Option B: Developer Installation
-If you want to modify the code or contribute to the project:
+#### Option C: Development Mode
+If you wish to modify the package, clone and install in editable mode:

 ```bash
 git clone https://github.com/m-bain/whisperX.git
 cd whisperX
-uv sync --all-extras --dev
+pip install -e .
 ```

 > **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

@@ -104,12 +117,12 @@ uv sync --all-extras --dev
 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization

 To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

 > **Note**<br>
 > As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.


 <h2 align="left" id="example">Usage 💬 (command line)</h2>

 ### English
pyproject.toml (file deleted; 36 lines removed)

@@ -1,36 +0,0 @@
-[project]
-urls = { repository = "https://github.com/m-bain/whisperx" }
-authors = [{ name = "Max Bain" }]
-name = "whisperx"
-version = "3.3.3"
-description = "Time-Accurate Automatic Speech Recognition using Whisper."
-readme = "README.md"
-requires-python = ">=3.9, <3.13"
-license = { text = "BSD-2-Clause" }
-
-dependencies = [
-    "ctranslate2<4.5.0",
-    "faster-whisper>=1.1.1",
-    "nltk>=3.9.1",
-    "numpy>=2.0.2",
-    "onnxruntime>=1.19",
-    "pandas>=2.2.3",
-    "pyannote-audio>=3.3.2",
-    "torch>=2.5.1",
-    "torchaudio>=2.5.1",
-    "transformers>=4.48.0",
-]
-
-
-[project.scripts]
-whisperx = "whisperx.transcribe:cli"
-
-[build-system]
-requires = ["setuptools"]
-
-[tool.setuptools]
-include-package-data = true
-
-[tool.setuptools.packages.find]
-where = ["."]
-include = ["whisperx*"]
requirements.txt (new file; 8 lines)

@@ -0,0 +1,8 @@
+torch>=2
+torchaudio>=2
+faster-whisper==1.1.0
+ctranslate2<4.5.0
+transformers
+pandas
+setuptools>=65
+nltk
setup.py (new file; 33 lines)

@@ -0,0 +1,33 @@
+import os
+
+import pkg_resources
+from setuptools import find_packages, setup
+
+with open("README.md", "r", encoding="utf-8") as f:
+    long_description = f.read()
+
+setup(
+    name="whisperx",
+    py_modules=["whisperx"],
+    version="3.3.1",
+    description="Time-Accurate Automatic Speech Recognition using Whisper.",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    python_requires=">=3.9, <3.13",
+    author="Max Bain",
+    url="https://github.com/m-bain/whisperx",
+    license="BSD-2-Clause",
+    packages=find_packages(exclude=["tests*"]),
+    install_requires=[
+        str(r)
+        for r in pkg_resources.parse_requirements(
+            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
+        )
+    ]
+    + [f"pyannote.audio==3.3.2"],
+    entry_points={
+        "console_scripts": ["whisperx=whisperx.transcribe:cli"],
+    },
+    include_package_data=True,
+    extras_require={"dev": ["pytest"]},
+)
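The install_requires block in the new setup.py builds its dependency list by parsing requirements.txt at build time and then appending a pinned pyannote.audio. A minimal stand-alone sketch of that parsing idiom follows; the requirement strings are placeholders taken from requirements.txt, and note that pkg_resources is deprecated in recent setuptools, so this mirrors the committed code rather than recommending it.

```python
# Sketch of the requirements-parsing idiom used in the setup.py shown above.
# The requirement strings here are placeholder examples, not the full list.
import pkg_resources

lines = ["torch>=2", "faster-whisper==1.1.0", "ctranslate2<4.5.0"]
install_requires = [str(r) for r in pkg_resources.parse_requirements(lines)]
install_requires += ["pyannote.audio==3.3.2"]

print(install_requires)
# ['torch>=2', 'faster-whisper==1.1.0', 'ctranslate2<4.5.0', 'pyannote.audio==3.3.2']
```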
@@ -1,5 +1,6 @@
 import math
-from whisperx.conjunctions import get_conjunctions, get_comma
+from .conjunctions import get_conjunctions, get_comma
+from typing import TextIO

 def normal_round(n):
     if n - math.floor(n) < 0.5:
@@ -1,7 +1,4 @@
-from whisperx.alignment import load_align_model as load_align_model, align as align
-from whisperx.asr import load_model as load_model
-from whisperx.audio import load_audio as load_audio
-from whisperx.diarize import (
-    assign_word_speakers as assign_word_speakers,
-    DiarizationPipeline as DiarizationPipeline,
-)
+from .alignment import load_align_model, align
+from .audio import load_audio
+from .diarize import assign_word_speakers, DiarizationPipeline
+from .asr import load_model
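These relative imports appear to be the package-level re-exports (the package __init__), so the names above form the public surface users import from whisperx. The sketch below shows how those names are typically combined; it is illustrative only, and the model size, file name, and the exact transcribe()/align() call signatures are assumptions not confirmed by this diff.

```python
# Illustrative use of the names re-exported above.
# "example.wav", "small" and the transcribe()/align() signatures are assumptions.
import whisperx

device = "cuda"
audio = whisperx.load_audio("example.wav")           # from .audio import load_audio
model = whisperx.load_model("small", device)         # from .asr import load_model
result = model.transcribe(audio)                     # assumed pipeline API

align_model, metadata = whisperx.load_align_model(result["language"], device)
result = whisperx.align(result["segments"], align_model, metadata, audio, device)

diarize_model = whisperx.DiarizationPipeline(use_auth_token=None, device=device)
result = whisperx.assign_word_speakers(diarize_model(audio), result)
```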
@@ -1,4 +1,4 @@
-from whisperx.transcribe import cli
+from .transcribe import cli


 cli()
@@ -13,9 +13,9 @@ import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

-from whisperx.audio import SAMPLE_RATE, load_audio
-from whisperx.utils import interpolate_nans
-from whisperx.types import (
+from .audio import SAMPLE_RATE, load_audio
+from .utils import interpolate_nans
+from .types import (
     AlignedTranscriptionResult,
     SingleSegment,
     SingleAlignedSegment,
whisperx/asr.py (124 lines changed)

@@ -11,12 +11,14 @@ from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_stor
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator

-from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from whisperx.types import SingleSegment, TranscriptionResult
-from whisperx.vads import Vad, Silero, Pyannote
+from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from .types import SingleSegment, TranscriptionResult
+from .vads import Vad, Silero, Pyannote


 def find_numeral_symbol_tokens(tokenizer):
+    """
+    Finds tokens that represent numeral and symbols.
+    """
     numeral_symbol_tokens = []
     for i in range(tokenizer.eot):
         token = tokenizer.decode([i]).removeprefix(" ")
@@ -26,10 +28,10 @@ def find_numeral_symbol_tokens(tokenizer):
     return numeral_symbol_tokens

 class WhisperModel(faster_whisper.WhisperModel):
-    '''
-    FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
-    '''
+    """
+    Wrapper around faster-whisper's WhisperModel to enable batched inference.
+    Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
+    """

     def generate_segment_batched(
         self,
@@ -38,28 +40,45 @@ class WhisperModel(faster_whisper.WhisperModel):
         options: TranscriptionOptions,
         encoder_output=None,
     ):
+        """
+        Generates transcription for a batch of audio segments.
+
+        Args:
+            features: The input audio features.
+            tokenizer: The tokenizer used to decode the generated tokens.
+            options: Transcription options.
+            encoder_output: Output from the encoder model.
+
+        Returns:
+            The decoded transcription text.
+        """
         batch_size = features.shape[0]
+        # Initialize tokens and prompt for the generation process.
         all_tokens = []
         prompt_reset_since = 0
+        # Check if an initial prompt is provided and handle it.
         if options.initial_prompt is not None:
             initial_prompt = " " + options.initial_prompt.strip()
             initial_prompt_tokens = tokenizer.encode(initial_prompt)
             all_tokens.extend(initial_prompt_tokens)
+        # Prepare the prompt for the current batch.
         previous_tokens = all_tokens[prompt_reset_since:]
         prompt = self.get_prompt(
             tokenizer,
             previous_tokens,
             without_timestamps=options.without_timestamps,
             prefix=options.prefix,
-            hotwords=options.hotwords
         )

+        # Encode the features to obtain the encoder output.
         encoder_output = self.encode(features)

+        # Determine the maximum initial timestamp index based on the options.
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
         )

+        # Generate the transcription result for the batch.
         result = self.model.generate(
             encoder_output,
             [prompt] * batch_size,
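The new comments in this hunk describe how an optional initial prompt is folded into a running token list before get_prompt() builds the per-batch prompt. Below is a tiny, self-contained sketch of just that step; the encode() helper is a placeholder, not the real Whisper tokenizer, and the prompt text is invented.

```python
# Toy sketch of the initial-prompt handling commented in the hunk above.
# encode() is a stand-in tokenizer; the real code uses tokenizer.encode().
from typing import List

def encode(text: str) -> List[int]:
    return [ord(c) for c in text]  # placeholder tokenization

all_tokens: List[int] = []
prompt_reset_since = 0

initial_prompt = "Names: WhisperX, CTranslate2."
initial_prompt = " " + initial_prompt.strip()       # normalize leading space
all_tokens.extend(encode(initial_prompt))           # keep prompt tokens around

previous_tokens = all_tokens[prompt_reset_since:]   # fed to get_prompt() in the real code
print(len(previous_tokens))
```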
@@ -71,100 +90,37 @@ class WhisperModel(faster_whisper.WhisperModel):
             suppress_tokens=options.suppress_tokens,
         )

+        # Extract the token sequences from the result.
         tokens_batch = [x.sequences_ids[0] for x in result]

+        # Define an inner function to decode the tokens for each batch.
         def decode_batch(tokens: List[List[int]]) -> str:
             res = []
             for tk in tokens:
                 res.append([token for token in tk if token < tokenizer.eot])
-            # text_tokens = [token for token in tokens if token < self.eot]
             return tokenizer.tokenizer.decode_batch(res)

+        # Decode the tokens to get the transcription text.
         text = decode_batch(tokens_batch)

         return text

     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
-        # When the model is running on multiple GPUs, the encoder output should be moved
-        # to the CPU since we don't know which GPU will handle the next job.
+        """
+        Encodes the audio features using the CTranslate2 storage.
+
+        When the model is running on multiple GPUs, the encoder output should be moved
+        to the CPU since we don't know which GPU will handle the next job.
+        """
+        # When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
-        # unsqueeze if batch size = 1
+        # If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = get_ctranslate2_storage(features)
+        # call the model
         return self.model.encode(features, to_cpu=to_cpu)

-class FasterWhisperPipeline(Pipeline):
-    """
-    Huggingface Pipeline wrapper for FasterWhisperModel.
-    """
-    # TODO:
-    # - add support for timestamp mode
-    # - add support for custom inference kwargs
-
-    def __init__(
-        self,
-        model: WhisperModel,
-        vad,
-        vad_params: dict,
-        options: TranscriptionOptions,
-        tokenizer: Optional[Tokenizer] = None,
-        device: Union[int, str, "torch.device"] = -1,
-        framework="pt",
-        language: Optional[str] = None,
-        suppress_numerals: bool = False,
-        **kwargs,
-    ):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.options = options
-        self.preset_language = language
-        self.suppress_numerals = suppress_numerals
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 1
-        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            if isinstance(device, torch.device):
-                self.device = device
-            elif isinstance(device, str):
-                self.device = torch.device(device)
-            elif device < 0:
-                self.device = torch.device("cpu")
-            else:
-                self.device = torch.device(f"cuda:{device}")
-        else:
-            self.device = device
-
-        super(Pipeline, self).__init__()
-        self.vad_model = vad
-        self._vad_params = vad_params
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}
-
-    def preprocess(self, audio):
-        audio = audio['inputs']
-        model_n_mels = self.model.feat_kwargs.get("feature_size")
-        features = log_mel_spectrogram(
-            audio,
-            n_mels=model_n_mels if model_n_mels is not None else 80,
-            padding=N_SAMPLES - audio.shape[0],
-        )
-        return {'inputs': features}
-
-    def _forward(self, model_inputs):
-        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
-        return {'text': outputs}
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
     def get_iterator(
         self,
         inputs,
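The decode_batch helper kept in the hunk above drops every token id at or above tokenizer.eot before detokenizing, since ids from eot upward correspond to special tokens (end-of-transcript, language and timestamp markers) rather than text. A self-contained sketch of that filtering is below; the id values and the eot constant are made up for illustration.

```python
# Minimal sketch of the token filtering performed by decode_batch above.
# The ids and `eot` are placeholders, not real Whisper tokenizer constants.
from typing import List

def filter_text_tokens(tokens_batch: List[List[int]], eot: int) -> List[List[int]]:
    # Keep only plain text tokens; ids >= eot are special tokens.
    return [[t for t in tokens if t < eot] for tokens in tokens_batch]

print(filter_text_tokens([[50, 99, 50257, 51000]], eot=50257))  # [[50, 99]]
```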
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from whisperx.utils import exact_div
+from .utils import exact_div

 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
@@ -4,8 +4,8 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch

-from whisperx.audio import load_audio, SAMPLE_RATE
-from whisperx.types import TranscriptionResult, AlignedTranscriptionResult
+from .audio import load_audio, SAMPLE_RATE
+from .types import TranscriptionResult, AlignedTranscriptionResult


 class DiarizationPipeline:
@@ -1,20 +1,17 @@
 import argparse
 import gc
 import os
-import sys
 import warnings
-import importlib.metadata
-import platform

 import numpy as np
 import torch

-from whisperx.alignment import align, load_align_model
-from whisperx.asr import load_model
-from whisperx.audio import load_audio
-from whisperx.diarize import DiarizationPipeline, assign_word_speakers
-from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
-from whisperx.utils import (
+from .alignment import align, load_align_model
+from .asr import load_model
+from .audio import load_audio
+from .diarize import DiarizationPipeline, assign_word_speakers
+from .types import AlignedTranscriptionResult, TranscriptionResult
+from .utils import (
     LANGUAGES,
     TO_LANGUAGE_CODE,
     get_writer,
@@ -88,8 +85,6 @@ def cli():
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")

     parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
-    parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
-    parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
     # fmt: on

     args = parser.parse_args().__dict__
@@ -143,9 +138,7 @@ def cli():
                 f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
             )
         args["language"] = "en"
-    align_language = (
-        args["language"] if args["language"] is not None else "en"
-    )  # default to loading english if not specified
+    align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified

     temperature = args.pop("temperature")
     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@@ -186,24 +179,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(
-        model_name,
-        device=device,
-        device_index=device_index,
-        download_root=model_dir,
-        compute_type=compute_type,
-        language=args["language"],
-        asr_options=asr_options,
-        vad_method=vad_method,
-        vad_options={
-            "chunk_size": chunk_size,
-            "vad_onset": vad_onset,
-            "vad_offset": vad_offset,
-        },
-        task=task,
-        local_files_only=model_cache_only,
-        threads=faster_whisper_threads,
-    )
+    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)

     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)
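For readability, here is the same load_model call with each keyword on its own line and illustrative values filled in. The keyword names come from the diff line above; the model name, chunk size, thread count and other literal values are placeholders, while the vad_onset/vad_offset defaults match the load_vad_model signature further down this diff.

```python
# Spread-out version of the single-line load_model(...) call shown above.
# "large-v2", chunk_size=30 and threads=4 are illustrative values only.
from whisperx.asr import load_model

model = load_model(
    "large-v2",                      # model_name
    device="cuda",
    device_index=0,
    download_root=None,
    compute_type="float16",
    language="en",
    asr_options=None,
    vad_method="pyannote",
    vad_options={"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363},
    task="transcribe",
    local_files_only=False,
    threads=4,
)
```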
@@ -227,9 +203,7 @@ def cli():
     if not no_align:
         tmp_results = results
         results = []
-        align_model, align_metadata = load_align_model(
-            align_language, device, model_name=align_model
-        )
+        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
         for result, audio_path in tmp_results:
             # >> Align
             if len(tmp_results) > 1:
@@ -241,12 +215,8 @@ def cli():
             if align_model is not None and len(result["segments"]) > 0:
                 if result.get("language", "en") != align_metadata["language"]:
                     # load new language
-                    print(
-                        f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
-                    )
-                    align_model, align_metadata = load_align_model(
-                        result["language"], device
-                    )
+                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
+                    align_model, align_metadata = load_align_model(result["language"], device)
             print(">>Performing alignment...")
             result: AlignedTranscriptionResult = align(
                 result["segments"],
@@ -269,17 +239,13 @@ def cli():
     # >> Diarize
     if diarize:
         if hf_token is None:
-            print(
-                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
-            )
+            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
        tmp_results = results
         print(">>Performing diarization...")
         results = []
         diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
-            diarize_segments = diarize_model(
-                input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
-            )
+            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             result = assign_word_speakers(diarize_segments, result)
             results.append((result, input_audio_path))
     # >> Write
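The diarization block above reduces to a short sequence of calls. The sketch below isolates them with placeholder token, path and speaker counts, and a stub transcription result standing in for the output of the alignment stage.

```python
# Stand-alone sketch of the diarization flow in the hunk above.
# The HF token, audio path, speaker counts and the stub `result` are placeholders.
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

result = {"segments": [], "word_segments": []}  # stub aligned transcription

diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device="cuda")
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=4)
result = assign_word_speakers(diarize_segments, result)
```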
@@ -287,6 +253,5 @@ def cli():
         result["language"] = align_language
         writer(result, audio_path, writer_args)

-
 if __name__ == "__main__":
     cli()
@@ -106,6 +106,7 @@ LANGUAGES = {
     "jw": "javanese",
     "su": "sundanese",
     "yue": "cantonese",
+    "lv": "latvian",
 }

 # language code lookup by name, with a few language aliases
@@ -1,3 +1,3 @@
-from whisperx.vads.pyannote import Pyannote as Pyannote
-from whisperx.vads.silero import Silero as Silero
-from whisperx.vads.vad import Vad as Vad
+from whisperx.vads.pyannote import Pyannote
+from whisperx.vads.silero import Silero
+from whisperx.vads.vad import Vad
@@ -1,4 +1,6 @@
+import hashlib
 import os
+import urllib
 from typing import Callable, Text, Union
 from typing import Optional

@@ -10,11 +12,11 @@ from pyannote.audio.pipelines import VoiceActivityDetection
 from pyannote.audio.pipelines.utils import PipelineModel
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.core import Segment
+from tqdm import tqdm

 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad


 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
     model_dir = torch.hub._get_torch_home()