27 Commits

Author SHA1 Message Date
73db39703e chore: update GitHub Actions workflow to use dynamic ref for checkout 2025-04-10 09:34:24 +02:00
db1750fa48 chore: update ctranslate2 version requirement to >=4.5.0 2025-04-10 09:23:17 +02:00
734084cdf6 bump: update version to 3.3.1 2025-01-08 18:00:34 +01:00
9395b0de18 Update tmp.yml 2025-01-08 17:59:28 +01:00
d57f9dc54c Create tmp.yml 2025-01-08 17:59:28 +01:00
a90bd1ce3f dataclasses replace method 2025-01-08 17:59:13 +01:00
10b05fc43f refactor: replace NamedTuple with TranscriptionOptions in FasterWhisperPipeline 2025-01-05 18:56:19 +01:00
26d9b46888 feat: include speaker information in WriteTXT when diarizing 2025-01-05 18:21:34 +01:00
9a8967f27e refactor: add type hints 2025-01-05 11:48:24 +01:00
0f7f9f9f83 refactor: simplify imports for better type inference 2025-01-05 11:48:24 +01:00
c60594fa3b fix: update import statement for conjunctions module 2025-01-05 11:48:24 +01:00
4916192246 chore: bump whisperX to 3.3.0 2025-01-02 14:09:10 +01:00
cbdac53e87 chore: update ctranslate2 version to restrict <4.5.0 2025-01-02 14:09:10 +01:00
940a223219 fix: add UTF-8 encoding when reading README.md 2025-01-02 12:43:59 +01:00
a0eb31019b chore: update license in setup.py 2025-01-02 08:41:04 +01:00
b08ad67a72 docs: update installation instructions in README 2025-01-02 08:35:45 +01:00
c18f9f979b fix: update README image source and enhance setup.py for long description 2025-01-02 08:30:04 +01:00
948b3e368b chore: update gitignore 2025-01-01 18:47:40 +01:00
e9ac5b63bc chore: clean up MANIFEST.in by removing unnecessary asset inclusions 2025-01-01 18:47:40 +01:00
90b45459d9 feat: add build and release workflow 2025-01-01 18:47:40 +01:00
81c4af96a6 feat: add Python compatibility testing workflow
feat: restrict Python versions to 3.9 - 3.12
2025-01-01 15:29:03 +01:00
1c6d9327bc feat: use model_dir as cache_dir for wav2vec2 (#681) 2025-01-01 13:22:27 +01:00
0fdb55d317 feat: add local_files_only option on whisperx.load_model for offline mode (#867)
Adds the parameter local_files_only (default False for consistency) to whisperx.load_model so that the user can avoid downloading the file and return the path to the local cached file if it exists.

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:16:45 +01:00
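A minimal usage sketch of the new flag (the model name and device below are illustrative, not part of the commit):

```python
import whisperx

# With local_files_only=True, load_model resolves a previously
# downloaded model from the local cache instead of contacting the hub.
model = whisperx.load_model(
    "large-v2",             # illustrative model name
    device="cuda",
    compute_type="float16",
    local_files_only=True,  # fails if no cached copy exists
)
```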
51da22771f feat: add verbose output (#759)
---------

Co-authored-by: Abhishek Sharma <abhishek@zipteams.com>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:07:52 +01:00
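A sketch of the new verbose flag in the Python API (file and model names are illustrative):

```python
import whisperx

model = whisperx.load_model("small", device="cuda")  # illustrative model
audio = whisperx.load_audio("example.wav")           # illustrative path

# verbose=True prints each segment as it is produced, e.g.
#   Transcript: [0.031 --> 6.458] <segment text>
result = model.transcribe(audio, batch_size=16, verbose=True)
```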
15ad5bf7df feat: update versions for pyannote:3.3.2 and faster-whisper:1.1.0 (#936)
* chore: bump faster-whisper to 1.1.0

* chore: bump pyannote to 3.3.2

* feat: add multilingual option in load_model function

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:41:09 +01:00
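The new multilingual flag is seeded into default_asr_options from model.model.is_multilingual, so it should be overridable like the other ASR defaults; a hedged sketch:

```python
import whisperx

# Assumption: entries in asr_options override default_asr_options,
# so "multilingual" can be forced explicitly if needed.
model = whisperx.load_model(
    "large-v2",  # illustrative model name
    device="cuda",
    asr_options={"multilingual": True},
)
```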
7fdbd21fe3 feat: add support for faster-whisper 1.0.3 (#875)
---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:07:42 +01:00
3ff625c561 feat: update faster-whisper to 1.0.2 (#814)
* Update faster-whisper to 1.0.2 to enable model distil-large-v3

* feat: add hotwords option to default_asr_options

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 09:41:22 +01:00
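The hotwords option added here lives in default_asr_options, so it can presumably be set through asr_options; the hotword string below is made up:

```python
import whisperx

# faster-whisper >= 1.0.2: "hotwords" biases decoding toward given
# phrases, and this release also enables models such as distil-large-v3.
model = whisperx.load_model(
    "distil-large-v3",
    device="cuda",
    asr_options={"hotwords": "WhisperX pyannote CTranslate2"},
)
```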
16 changed files with 468 additions and 94 deletions

.github/workflows/build-and-release.yml (new file, 37 lines)

@@ -0,0 +1,37 @@
name: Build and release

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name }}
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Install dependencies
        run: |
          python -m pip install build
      - name: Build wheels
        run: python -m build --wheel
      - name: Release to Github
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*
      - name: Publish package to PyPi
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}

.github/workflows/ (new workflow file, 32 lines)

@@ -0,0 +1,32 @@
name: Python Compatibility Test

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install package
        run: |
          python -m pip install --upgrade pip
          pip install .
      - name: Test import
        run: |
          python -c "import whisperx; print('Successfully imported whisperx')"

.github/workflows/tmp.yml (new file, 35 lines)

@@ -0,0 +1,35 @@
name: Python Compatibility Test (PyPi)

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  workflow_dispatch: # Allows manual triggering from GitHub UI

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install package
        run: |
          pip install whisperx
      - name: Print packages
        run: |
          pip list
      - name: Test import
        run: |
          python -c "import whisperx; print('Successfully imported whisperx')"

.gitignore (172 lines changed)

@@ -1,3 +1,171 @@
-whisperx.egg-info/
-**/__pycache__/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
 .ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# PyPI configuration file
+.pypirc

MANIFEST.in

@@ -1,6 +1,3 @@
 include whisperx/assets/*
-include whisperx/assets/gpt2/*
-include whisperx/assets/multilingual/*
-include whisperx/normalizers/english.json
 include LICENSE
 include requirements.txt

README.md

@@ -23,7 +23,7 @@
 </p>

-<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
+<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">

 <!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->

@@ -80,21 +80,40 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed
 See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)

-### 3. Install this repo
+### 3. Install WhisperX
+
+You have several installation options:

-`pip install git+https://github.com/m-bain/whisperx.git`
+#### Option A: Stable Release (recommended)
+Install the latest stable version from PyPI:

-If already installed, update package to most recent commit
+```bash
+pip install whisperx
+```

-`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
+#### Option B: Development Version
+Install the latest development version directly from GitHub (may be unstable):
+```bash
+pip install git+https://github.com/m-bain/whisperx.git
+```
+If already installed, update to the most recent commit:
+```bash
+pip install git+https://github.com/m-bain/whisperx.git --upgrade
+```

-If wishing to modify this package, clone and install in editable mode:
-```
-$ git clone https://github.com/m-bain/whisperX.git
-$ cd whisperX
-$ pip install -e .
-```
+#### Option C: Development Mode
+If you wish to modify the package, clone and install in editable mode:
+```bash
+git clone https://github.com/m-bain/whisperX.git
+cd whisperX
+pip install -e .
+```
+
+> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.

 You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.

 ### Speaker Diarization

requirements.txt

@@ -1,7 +1,7 @@
 torch>=2
 torchaudio>=2
-faster-whisper==1.0.0
-ctranslate2==4.4.0
+faster-whisper==1.1.0
+ctranslate2>=4.5.0
 transformers
 pandas
 setuptools>=65

setup.py

@@ -1,19 +1,22 @@
 import os
+import platform

 import pkg_resources
 from setuptools import find_packages, setup

+with open("README.md", "r", encoding="utf-8") as f:
+    long_description = f.read()
+
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.2.0",
+    version="3.3.2",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
-    readme="README.md",
-    python_requires=">=3.8",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    python_requires=">=3.9, <3.13",
     author="Max Bain",
     url="https://github.com/m-bain/whisperx",
-    license="MIT",
+    license="BSD-2-Clause",
     packages=find_packages(exclude=["tests*"]),
     install_requires=[
         str(r)
@@ -21,7 +24,7 @@ setup(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
     ]
-    + [f"pyannote.audio==3.1.1"],
+    + [f"pyannote.audio==3.3.2"],
     entry_points={
         "console_scripts": ["whisperx=whisperx.transcribe:cli"],
     },

whisperx/SubtitlesProcessor.py

@@ -1,5 +1,5 @@
 import math
-from conjunctions import get_conjunctions, get_comma
+from .conjunctions import get_conjunctions, get_comma
 from typing import TextIO

 def normal_round(n):

whisperx/alignment.py

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Optional, Union, List

 import numpy as np
 import pandas as pd
@@ -65,7 +65,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 }

-def load_align_model(language_code, device, model_name=None, model_dir=None):
+def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@@ -85,8 +85,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
         align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
     else:
         try:
-            processor = Wav2Vec2Processor.from_pretrained(model_name)
-            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
+            processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
+            align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
         except Exception as e:
             print(e)
             print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")

whisperx/asr.py

@@ -1,17 +1,21 @@
 import os
 import warnings
-from typing import List, Union, Optional, NamedTuple
+from typing import List, NamedTuple, Optional, Union
+from dataclasses import replace

 import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator

 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .vad import load_vad_model, merge_chunks
-from .types import TranscriptionResult, SingleSegment
+from .types import SingleSegment, TranscriptionResult
+from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks

 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
@@ -28,7 +32,13 @@ class WhisperModel(faster_whisper.WhisperModel):
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''

-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
@@ -81,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):
         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
-        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+        features = get_ctranslate2_storage(features)

         return self.model.encode(features, to_cpu=to_cpu)

@@ -94,17 +104,17 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for custom inference kwargs

     def __init__(
         self,
-        model,
-        vad,
+        model: WhisperModel,
+        vad: VoiceActivitySegmentation,
         vad_params: dict,
-        options : NamedTuple,
-        tokenizer=None,
+        options: TranscriptionOptions,
+        tokenizer: Optional[Tokenizer] = None,
         device: Union[int, str, "torch.device"] = -1,
-        framework = "pt",
-        language : Optional[str] = None,
+        framework="pt",
+        language: Optional[str] = None,
         suppress_numerals: bool = False,
-        **kwargs
+        **kwargs,
     ):
         self.model = model
         self.tokenizer = tokenizer
@@ -156,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
         return model_outputs

     def get_iterator(
-        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
+        self,
+        inputs,
+        num_workers: int,
+        batch_size: int,
+        preprocess_params: dict,
+        forward_params: dict,
+        postprocess_params: dict,
     ):
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
         if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -171,7 +187,16 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size: Optional[int] = None,
+        num_workers=0,
+        language: Optional[str] = None,
+        task: Optional[str] = None,
+        chunk_size=30,
+        print_progress=False,
+        combined_progress=False,
+        verbose=False,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -193,16 +218,22 @@ class FasterWhisperPipeline(Pipeline):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )

         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
@@ -210,7 +241,7 @@ class FasterWhisperPipeline(Pipeline):
             print(f"Suppressing numeral and symbol tokens")
             new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
             new_suppressed_tokens = list(set(new_suppressed_tokens))
-            self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
+            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
@@ -223,6 +254,8 @@ class FasterWhisperPipeline(Pipeline):
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
+            if verbose:
+                print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
             segments.append(
                 {
                     "text": text,
@@ -237,12 +270,11 @@ class FasterWhisperPipeline(Pipeline):
         # revert suppressed tokens if suppress_numerals is enabled
         if self.suppress_numerals:
-            self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
+            self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)

         return {"segments": segments, "language": language}

-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: np.ndarray) -> str:
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
@@ -256,31 +288,36 @@ class FasterWhisperPipeline(Pipeline):
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language

-def load_model(whisper_arch,
-               device,
-               device_index=0,
-               compute_type="float16",
-               asr_options=None,
-               language : Optional[str] = None,
-               vad_model=None,
-               vad_options=None,
-               model : Optional[WhisperModel] = None,
-               task="transcribe",
-               download_root=None,
-               threads=4):
-    '''Load a Whisper model for inference.
+def load_model(
+    whisper_arch: str,
+    device: str,
+    device_index=0,
+    compute_type="float16",
+    asr_options: Optional[dict] = None,
+    language: Optional[str] = None,
+    vad_model: Optional[VoiceActivitySegmentation] = None,
+    vad_options: Optional[dict] = None,
+    model: Optional[WhisperModel] = None,
+    task="transcribe",
+    download_root: Optional[str] = None,
+    local_files_only=False,
+    threads=4,
+) -> FasterWhisperPipeline:
+    """Load a Whisper model for inference.
     Args:
-        whisper_arch: str - The name of the Whisper model to load.
-        device: str - The device to load the model on.
-        compute_type: str - The compute type to use for the model.
-        options: dict - A dictionary of options to use for the model.
-        language: str - The language of the model. (use English for now)
-        model: Optional[WhisperModel] - The WhisperModel instance to use.
-        download_root: Optional[str] - The root directory to download the model to.
-        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+        whisper_arch - The name of the Whisper model to load.
+        device - The device to load the model on.
+        compute_type - The compute type to use for the model.
+        options - A dictionary of options to use for the model.
+        language - The language of the model. (use English for now)
+        model - The WhisperModel instance to use.
+        download_root - The root directory to download the model to.
+        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
     Returns:
         A Whisper pipeline.
-    '''
+    """

     if whisper_arch.endswith(".en"):
         language = "en"
@@ -290,9 +327,10 @@ def load_model(whisper_arch,
                             device_index=device_index,
                             compute_type=compute_type,
                             download_root=download_root,
+                            local_files_only=local_files_only,
                             cpu_threads=threads)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
@@ -319,10 +357,12 @@ def load_model(whisper_arch,
         "word_timestamps": False,
         "prepend_punctuations": "\"'“¿([{-",
         "append_punctuations": "\"'.。,!?::”)]}、",
+        "multilingual": model.model.is_multilingual,
         "suppress_numerals": False,
         "max_new_tokens": None,
         "clip_timestamps": None,
         "hallucination_silence_threshold": None,
+        "hotwords": None,
     }

     if asr_options is not None:
@@ -331,7 +371,7 @@ def load_model(whisper_arch,
     suppress_numerals = default_asr_options["suppress_numerals"]
     del default_asr_options["suppress_numerals"]

-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = TranscriptionOptions(**default_asr_options)

     default_vad_options = {
         "vad_onset": 0.500,

whisperx/audio.py

@@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
     """
     Open an audio file and read as mono waveform, resampling as necessary

whisperx/conjunctions.py

@@ -1,5 +1,8 @@
 # conjunctions.py

+from typing import Set
+
+
 conjunctions_by_language = {
     'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
     'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', "jusqu'à", 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@@ -36,8 +39,9 @@ commas_by_language = {
     'ur': '،'
 }

-def get_conjunctions(lang_code):
+def get_conjunctions(lang_code: str) -> Set[str]:
     return conjunctions_by_language.get(lang_code, set())

-def get_comma(lang_code):
-    return commas_by_language.get(lang_code, ',')
+
+def get_comma(lang_code: str) -> str:
+    return commas_by_language.get(lang_code, ",")

whisperx/diarize.py

@@ -5,6 +5,7 @@ from typing import Optional, Union
 import torch

 from .audio import load_audio, SAMPLE_RATE
+from .types import TranscriptionResult, AlignedTranscriptionResult

 class DiarizationPipeline:
@@ -18,7 +19,13 @@ class DiarizationPipeline:
             device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

-    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
+    def __call__(
+        self,
+        audio: Union[str, np.ndarray],
+        num_speakers: Optional[int] = None,
+        min_speakers: Optional[int] = None,
+        max_speakers: Optional[int] = None,
+    ):
         if isinstance(audio, str):
             audio = load_audio(audio)
         audio_data = {
@@ -32,7 +39,11 @@ class DiarizationPipeline:
         return diarize_df

-def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+def assign_word_speakers(
+    diarize_df: pd.DataFrame,
+    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
+    fill_nearest=False,
+) -> dict:
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
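Putting the newly typed pipeline together, a typical diarization run looks roughly like this (model name, file path, token, and speaker bounds are all illustrative):

```python
import whisperx

device = "cuda"
audio = whisperx.load_audio("example.wav")    # illustrative path

model = whisperx.load_model("small", device)  # illustrative model
result = model.transcribe(audio, batch_size=16)

diarize_model = whisperx.DiarizationPipeline(
    use_auth_token="hf_...",                  # hypothetical HF token
    device=device,
)
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=3)

# Attaches a "speaker" key to each segment/word in the result.
result = whisperx.assign_word_speakers(diarize_segments, result)
```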

whisperx/transcribe.py

@@ -10,8 +10,15 @@ from .alignment import align, load_align_model
 from .asr import load_model
 from .audio import load_audio
 from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
-                    optional_int, str2bool)
+from .types import AlignedTranscriptionResult, TranscriptionResult
+from .utils import (
+    LANGUAGES,
+    TO_LANGUAGE_CODE,
+    get_writer,
+    optional_float,
+    optional_int,
+    str2bool,
+)

 def cli():
@@ -87,6 +94,7 @@ def cli():
     device: str = args.pop("device")
     device_index: int = args.pop("device_index")
     compute_type: str = args.pop("compute_type")
+    verbose: bool = args.pop("verbose")
     # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)

@@ -94,7 +102,7 @@ def cli():
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
-    task : str = args.pop("task")
+    task: str = args.pop("task")
     if task == "translate":
         # translation cannot be aligned
         no_align = True
@@ -173,7 +181,13 @@ def cli():
             audio = load_audio(audio_path)
             # >> VAD & ASR
             print(">>Performing transcription...")
-            result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
+            result: TranscriptionResult = model.transcribe(
+                audio,
+                batch_size=batch_size,
+                chunk_size=chunk_size,
+                print_progress=print_progress,
+                verbose=verbose,
+            )
             results.append((result, audio_path))

     # Unload Whisper and VAD
@@ -200,7 +214,16 @@ def cli():
                     print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                     align_model, align_metadata = load_align_model(result["language"], device)
                 print(">>Performing alignment...")
-                result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
+                result: AlignedTranscriptionResult = align(
+                    result["segments"],
+                    align_model,
+                    align_metadata,
+                    input_audio,
+                    device,
+                    interpolate_method=interpolate_method,
+                    return_char_alignments=return_char_alignments,
+                    print_progress=print_progress,
+                )

             results.append((result, audio_path))

whisperx/utils.py

@@ -214,7 +214,12 @@ class WriteTXT(ResultWriter):
     def write_result(self, result: dict, file: TextIO, options: dict):
         for segment in result["segments"]:
-            print(segment["text"].strip(), file=file, flush=True)
+            speaker = segment.get("speaker")
+            text = segment["text"].strip()
+            if speaker is not None:
+                print(f"[{speaker}]: {text}", file=file, flush=True)
+            else:
+                print(text, file=file, flush=True)

 class SubtitlesWriter(ResultWriter):
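With this change, diarized results carry their speaker labels into the plain-text writer, while undiarized segments print unchanged. Output from a diarized run would look roughly like this (labels are pyannote's defaults; the text is illustrative):

```
[SPEAKER_00]: Thanks everyone for joining the call.
[SPEAKER_01]: Happy to be here, let's get started.
```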