27 Commits

Author SHA1 Message Date
73db39703e chore: update GitHub Actions workflow to use dynamic ref for checkout 2025-04-10 09:34:24 +02:00
db1750fa48 chore: update ctranslate2 version requirement to >=4.5.0 2025-04-10 09:23:17 +02:00
734084cdf6 bump: update version to 3.3.1 2025-01-08 18:00:34 +01:00
9395b0de18 Update tmp.yml 2025-01-08 17:59:28 +01:00
d57f9dc54c Create tmp.yml 2025-01-08 17:59:28 +01:00
a90bd1ce3f dataclasses replace method 2025-01-08 17:59:13 +01:00
10b05fc43f refactor: replace NamedTuple with TranscriptionOptions in FasterWhisperPipeline 2025-01-05 18:56:19 +01:00
26d9b46888 feat: include speaker information in WriteTXT when diarizing 2025-01-05 18:21:34 +01:00
9a8967f27e refactor: add type hints 2025-01-05 11:48:24 +01:00
0f7f9f9f83 refactor: simplify imports for better type inference 2025-01-05 11:48:24 +01:00
c60594fa3b fix: update import statement for conjunctions module 2025-01-05 11:48:24 +01:00
4916192246 chore: bump whisperX to 3.3.0 2025-01-02 14:09:10 +01:00
cbdac53e87 chore: update ctranslate2 version to restrict <4.5.0 2025-01-02 14:09:10 +01:00
940a223219 fix: add UTF-8 encoding when reading README.md 2025-01-02 12:43:59 +01:00
a0eb31019b chore: update license in setup.py 2025-01-02 08:41:04 +01:00
b08ad67a72 docs: update installation instructions in README 2025-01-02 08:35:45 +01:00
c18f9f979b fix: update README image source and enhance setup.py for long description 2025-01-02 08:30:04 +01:00
948b3e368b chore: update gitignore 2025-01-01 18:47:40 +01:00
e9ac5b63bc chore: clean up MANIFEST.in by removing unnecessary asset inclusions 2025-01-01 18:47:40 +01:00
90b45459d9 feat: add build and release workflow 2025-01-01 18:47:40 +01:00
81c4af96a6 feat: add Python compatibility testing workflow
feat: restrict Python versions to 3.9 - 3.12
2025-01-01 15:29:03 +01:00
1c6d9327bc feat: use model_dir as cache_dir for wav2vec2 (#681) 2025-01-01 13:22:27 +01:00
0fdb55d317 feat: add local_files_only option on whisperx.load_model for offline mode (#867)
Adds the parameter local_files_only (default False for consistency) to whisperx.load_model so that the user can avoid downloading the file and return the path to the local cached file if it exists.

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:16:45 +01:00
51da22771f feat: add verbose output (#759)
---------

Co-authored-by: Abhishek Sharma <abhishek@zipteams.com>
Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2025-01-01 13:07:52 +01:00
15ad5bf7df feat: update versions for pyannote:3.3.2 and faster-whisper:1.1.0 (#936)
* chore: bump faster-whisper to 1.1.0

* chore: bump pyannote to 3.3.2

* feat: add multilingual option in load_model function

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:41:09 +01:00
7fdbd21fe3 feat: add support for faster-whisper 1.0.3 (#875)
---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 10:07:42 +01:00
3ff625c561 feat: update faster-whisper to 1.0.2 (#814)
* Update faster-whisper to 1.0.2 to enable model distil-large-v3

* feat: add hotwords option to default_asr_options

---------

Co-authored-by: Barabazs <31799121+Barabazs@users.noreply.github.com>
2024-12-31 09:41:22 +01:00
16 changed files with 468 additions and 94 deletions

37
.github/workflows/build-and-release.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Build and release
on:
release:
types: [published]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install dependencies
run: |
python -m pip install build
- name: Build wheels
run: python -m build --wheel
- name: Release to Github
uses: softprops/action-gh-release@v2
with:
files: dist/*
- name: Publish package to PyPi
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -0,0 +1,32 @@
name: Python Compatibility Test
on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: # Allows manual triggering from GitHub UI
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install package
run: |
python -m pip install --upgrade pip
pip install .
- name: Test import
run: |
python -c "import whisperx; print('Successfully imported whisperx')"

35
.github/workflows/tmp.yml vendored Normal file
View File

@ -0,0 +1,35 @@
name: Python Compatibility Test (PyPi)
on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: # Allows manual triggering from GitHub UI
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install package
run: |
pip install whisperx
- name: Print packages
run: |
pip list
- name: Test import
run: |
python -c "import whisperx; print('Successfully imported whisperx')"

172
.gitignore vendored
View File

@ -1,3 +1,171 @@
whisperx.egg-info/
**/__pycache__/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc

View File

@ -1,6 +1,3 @@
include whisperx/assets/*
include whisperx/assets/gpt2/*
include whisperx/assets/multilingual/*
include whisperx/normalizers/english.json
include LICENSE
include requirements.txt

View File

@ -23,7 +23,7 @@
</p>
<img width="1216" align="center" alt="whisperx-arch" src="figures/pipeline.png">
<img width="1216" align="center" alt="whisperx-arch" src="https://raw.githubusercontent.com/m-bain/whisperX/refs/heads/main/figures/pipeline.png">
<!-- <p align="left">Whisper-Based Automatic Speech Recognition (ASR) with improved timestamp accuracy + quality via forced phoneme alignment and voice-activity based batching for fast inference.</p> -->
@ -80,21 +80,40 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst
See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
### 3. Install this repo
### 3. Install WhisperX
`pip install git+https://github.com/m-bain/whisperx.git`
You have several installation options:
If already installed, update package to most recent commit
#### Option A: Stable Release (recommended)
Install the latest stable version from PyPI:
`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
If wishing to modify this package, clone and install in editable mode:
```bash
pip install whisperx
```
$ git clone https://github.com/m-bain/whisperX.git
$ cd whisperX
$ pip install -e .
#### Option B: Development Version
Install the latest development version directly from GitHub (may be unstable):
```bash
pip install git+https://github.com/m-bain/whisperx.git
```
If already installed, update to the most recent commit:
```bash
pip install git+https://github.com/m-bain/whisperx.git --upgrade
```
#### Option C: Development Mode
If you wish to modify the package, clone and install in editable mode:
```bash
git clone https://github.com/m-bain/whisperX.git
cd whisperX
pip install -e .
```
> **Note**: The development version may contain experimental features and bugs. Use the stable PyPI release for production environments.
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
### Speaker Diarization

View File

@ -1,7 +1,7 @@
torch>=2
torchaudio>=2
faster-whisper==1.0.0
ctranslate2==4.4.0
faster-whisper==1.1.0
ctranslate2>=4.5.0
transformers
pandas
setuptools>=65

View File

@ -1,19 +1,22 @@
import os
import platform
import pkg_resources
from setuptools import find_packages, setup
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.2.0",
version="3.3.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md",
python_requires=">=3.8",
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.9, <3.13",
author="Max Bain",
url="https://github.com/m-bain/whisperx",
license="MIT",
license="BSD-2-Clause",
packages=find_packages(exclude=["tests*"]),
install_requires=[
str(r)
@ -21,7 +24,7 @@ setup(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
]
+ [f"pyannote.audio==3.1.1"],
+ [f"pyannote.audio==3.3.2"],
entry_points={
"console_scripts": ["whisperx=whisperx.transcribe:cli"],
},

View File

@ -1,5 +1,5 @@
import math
from conjunctions import get_conjunctions, get_comma
from .conjunctions import get_conjunctions, get_comma
from typing import TextIO
def normal_round(n):

View File

@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Union, List
from typing import Iterable, Optional, Union, List
import numpy as np
import pandas as pd
@ -65,7 +65,7 @@ DEFAULT_ALIGN_MODELS_HF = {
}
def load_align_model(language_code, device, model_name=None, model_dir=None):
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
@ -85,8 +85,8 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
else:
try:
processor = Wav2Vec2Processor.from_pretrained(model_name)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir)
align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir)
except Exception as e:
print(e)
print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")

View File

@ -1,17 +1,21 @@
import os
import warnings
from typing import List, Union, Optional, NamedTuple
from typing import List, NamedTuple, Optional, Union
from dataclasses import replace
import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
from .types import SingleSegment, TranscriptionResult
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@ -28,7 +32,13 @@ class WhisperModel(faster_whisper.WhisperModel):
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output=None,
):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
@ -81,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
features = get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
@ -94,17 +104,17 @@ class FasterWhisperPipeline(Pipeline):
# - add support for custom inference kwargs
def __init__(
self,
model,
vad,
vad_params: dict,
options : NamedTuple,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
framework = "pt",
language : Optional[str] = None,
suppress_numerals: bool = False,
**kwargs
self,
model: WhisperModel,
vad: VoiceActivitySegmentation,
vad_params: dict,
options: TranscriptionOptions,
tokenizer: Optional[Tokenizer] = None,
device: Union[int, str, "torch.device"] = -1,
framework="pt",
language: Optional[str] = None,
suppress_numerals: bool = False,
**kwargs,
):
self.model = model
self.tokenizer = tokenizer
@ -156,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
return model_outputs
def get_iterator(
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
self,
inputs,
num_workers: int,
batch_size: int,
preprocess_params: dict,
forward_params: dict,
postprocess_params: dict,
):
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
if "TOKENIZERS_PARALLELISM" not in os.environ:
@ -171,7 +187,16 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
self,
audio: Union[str, np.ndarray],
batch_size: Optional[int] = None,
num_workers=0,
language: Optional[str] = None,
task: Optional[str] = None,
chunk_size=30,
print_progress=False,
combined_progress=False,
verbose=False,
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@ -193,24 +218,30 @@ class FasterWhisperPipeline(Pipeline):
if self.tokenizer is None:
language = language or self.detect_language(audio)
task = task or "transcribe"
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
else:
language = language or self.tokenizer.language_code
task = task or self.tokenizer.task
if task != self.tokenizer.task or language != self.tokenizer.language_code:
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
print(f"Suppressing numeral and symbol tokens")
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
new_suppressed_tokens = list(set(new_suppressed_tokens))
self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
@ -223,6 +254,8 @@ class FasterWhisperPipeline(Pipeline):
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
@ -237,12 +270,11 @@ class FasterWhisperPipeline(Pipeline):
# revert suppressed tokens if suppress_numerals is enabled
if self.suppress_numerals:
self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
return {"segments": segments, "language": language}
def detect_language(self, audio: np.ndarray):
def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
@ -256,31 +288,36 @@ class FasterWhisperPipeline(Pipeline):
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
def load_model(whisper_arch,
device,
device_index=0,
compute_type="float16",
asr_options=None,
language : Optional[str] = None,
vad_model=None,
vad_options=None,
model : Optional[WhisperModel] = None,
task="transcribe",
download_root=None,
threads=4):
'''Load a Whisper model for inference.
def load_model(
whisper_arch: str,
device: str,
device_index=0,
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[VoiceActivitySegmentation] = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
download_root: Optional[str] = None,
local_files_only=False,
threads=4,
) -> FasterWhisperPipeline:
"""Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
device: str - The device to load the model on.
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
model: Optional[WhisperModel] - The WhisperModel instance to use.
download_root: Optional[str] - The root directory to download the model to.
threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
Returns:
A Whisper pipeline.
'''
"""
if whisper_arch.endswith(".en"):
language = "en"
@ -290,9 +327,10 @@ def load_model(whisper_arch,
device_index=device_index,
compute_type=compute_type,
download_root=download_root,
local_files_only=local_files_only,
cpu_threads=threads)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
@ -319,10 +357,12 @@ def load_model(whisper_arch,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、",
"multilingual": model.model.is_multilingual,
"suppress_numerals": False,
"max_new_tokens": None,
"clip_timestamps": None,
"hallucination_silence_threshold": None,
"hotwords": None,
}
if asr_options is not None:
@ -331,7 +371,7 @@ def load_model(whisper_arch,
suppress_numerals = default_asr_options["suppress_numerals"]
del default_asr_options["suppress_numerals"]
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_asr_options = TranscriptionOptions(**default_asr_options)
default_vad_options = {
"vad_onset": 0.500,
@ -354,4 +394,4 @@ def load_model(whisper_arch,
language=language,
suppress_numerals=suppress_numerals,
vad_params=default_vad_options,
)
)

View File

@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
"""
Open an audio file and read as mono waveform, resampling as necessary

View File

@ -1,5 +1,8 @@
# conjunctions.py
from typing import Set
conjunctions_by_language = {
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', '', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusquà', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@ -36,8 +39,9 @@ commas_by_language = {
'ur': '،'
}
def get_conjunctions(lang_code):
def get_conjunctions(lang_code: str) -> Set[str]:
return conjunctions_by_language.get(lang_code, set())
def get_comma(lang_code):
return commas_by_language.get(lang_code, ',')
def get_comma(lang_code: str) -> str:
return commas_by_language.get(lang_code, ",")

View File

@ -5,6 +5,7 @@ from typing import Optional, Union
import torch
from .audio import load_audio, SAMPLE_RATE
from .types import TranscriptionResult, AlignedTranscriptionResult
class DiarizationPipeline:
@ -18,7 +19,13 @@ class DiarizationPipeline:
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
def __call__(
self,
audio: Union[str, np.ndarray],
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
@ -32,7 +39,11 @@ class DiarizationPipeline:
return diarize_df
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
fill_nearest=False,
) -> dict:
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)

View File

@ -10,8 +10,15 @@ from .alignment import align, load_align_model
from .asr import load_model
from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
optional_int, str2bool)
from .types import AlignedTranscriptionResult, TranscriptionResult
from .utils import (
LANGUAGES,
TO_LANGUAGE_CODE,
get_writer,
optional_float,
optional_int,
str2bool,
)
def cli():
@ -87,6 +94,7 @@ def cli():
device: str = args.pop("device")
device_index: int = args.pop("device_index")
compute_type: str = args.pop("compute_type")
verbose: bool = args.pop("verbose")
# model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True)
@ -94,7 +102,7 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task : str = args.pop("task")
task: str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True
@ -173,7 +181,13 @@ def cli():
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
result: TranscriptionResult = model.transcribe(
audio,
batch_size=batch_size,
chunk_size=chunk_size,
print_progress=print_progress,
verbose=verbose,
)
results.append((result, audio_path))
# Unload Whisper and VAD
@ -200,7 +214,16 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
align_metadata,
input_audio,
device,
interpolate_method=interpolate_method,
return_char_alignments=return_char_alignments,
print_progress=print_progress,
)
results.append((result, audio_path))

View File

@ -214,7 +214,12 @@ class WriteTXT(ResultWriter):
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
speaker = segment.get("speaker")
text = segment["text"].strip()
if speaker is not None:
print(f"[{speaker}]: {text}", file=file, flush=True)
else:
print(text, file=file, flush=True)
class SubtitlesWriter(ResultWriter):