9 Commits

10 changed files with 181 additions and 71 deletions

.github/workflows/tmp.yml vendored Normal file (35 lines added)
View File

@@ -0,0 +1,35 @@
name: Python Compatibility Test (PyPI)
on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: # Allows manual triggering from GitHub UI
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install package
run: |
pip install whisperx
- name: Print packages
run: |
pip list
- name: Test import
run: |
python -c "import whisperx; print('Successfully imported whisperx')"
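The workflow is a PyPI install smoke test: for each Python version in the matrix it installs the published whisperx package with pip, lists the resolved dependencies, and checks that the package imports. A minimal local sketch of the same import check, assuming whisperx has already been installed with pip:

# Local sketch of the workflow's final step; assumes `pip install whisperx`
# has already been run in the current environment.
import importlib

module = importlib.import_module("whisperx")
print(f"Successfully imported whisperx from {module.__file__}")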

View File

@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
setup(
name="whisperx",
py_modules=["whisperx"],
version="3.3.0",
version="3.3.1",
description="Time-Accurate Automatic Speech Recognition using Whisper.",
long_description=long_description,
long_description_content_type="text/markdown",

View File

@@ -1,5 +1,5 @@
import math
from conjunctions import get_conjunctions, get_comma
from .conjunctions import get_conjunctions, get_comma
from typing import TextIO
def normal_round(n):
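The only change here is the switch to an explicit relative import. Once this module runs as part of the installed whisperx package, a bare `from conjunctions import ...` is resolved against sys.path as a top-level module and raises ModuleNotFoundError; the `.conjunctions` form resolves inside the package. A small sketch of the package-level import this fix makes reliable ("xx" is a made-up language code that falls back to the defaults):

# Sketch: the helpers are reachable through the installed package; unknown
# language codes fall back to the defaults defined in conjunctions.py.
from whisperx.conjunctions import get_comma, get_conjunctions

print(get_comma("xx"))         # ','  (default comma)
print(get_conjunctions("xx"))  # set() (no entry for unknown codes)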

View File

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
C. Max Bain
"""
from dataclasses import dataclass
from typing import Iterable, Union, List
from typing import Iterable, Optional, Union, List
import numpy as np
import pandas as pd
@@ -65,7 +65,7 @@ DEFAULT_ALIGN_MODELS_HF = {
}
def load_align_model(language_code, device, model_name=None, model_dir=None):
def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
if model_name is None:
# use default model
if language_code in DEFAULT_ALIGN_MODELS_TORCH:
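load_align_model keeps its behaviour and only gains annotations: model_name stays optional, and when it is None the default alignment checkpoint for the language is looked up in the DEFAULT_ALIGN_MODELS_* tables. A hedged usage sketch ("en" and "cpu" are example values, and the package-level re-export follows whisperx's README examples):

# Usage sketch; model_name=None selects the default per-language alignment model.
import whisperx

align_model, align_metadata = whisperx.load_align_model(language_code="en", device="cpu")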

View File

@@ -1,17 +1,21 @@
import os
import warnings
from typing import List, Union, Optional, NamedTuple
from typing import List, NamedTuple, Optional, Union
from dataclasses import replace
import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
from .types import SingleSegment, TranscriptionResult
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@@ -28,7 +32,13 @@ class WhisperModel(faster_whisper.WhisperModel):
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output=None,
):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
@@ -81,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
features = get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
@@ -94,17 +104,17 @@ class FasterWhisperPipeline(Pipeline):
# - add support for custom inference kwargs
def __init__(
self,
model,
vad,
vad_params: dict,
options : NamedTuple,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
framework = "pt",
language : Optional[str] = None,
suppress_numerals: bool = False,
**kwargs
self,
model: WhisperModel,
vad: VoiceActivitySegmentation,
vad_params: dict,
options: TranscriptionOptions,
tokenizer: Optional[Tokenizer] = None,
device: Union[int, str, "torch.device"] = -1,
framework="pt",
language: Optional[str] = None,
suppress_numerals: bool = False,
**kwargs,
):
self.model = model
self.tokenizer = tokenizer
@@ -156,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
return model_outputs
def get_iterator(
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
self,
inputs,
num_workers: int,
batch_size: int,
preprocess_params: dict,
forward_params: dict,
postprocess_params: dict,
):
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -171,7 +187,16 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, verbose=False
self,
audio: Union[str, np.ndarray],
batch_size: Optional[int] = None,
num_workers=0,
language: Optional[str] = None,
task: Optional[str] = None,
chunk_size=30,
print_progress=False,
combined_progress=False,
verbose=False,
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@@ -193,24 +218,30 @@ class FasterWhisperPipeline(Pipeline):
if self.tokenizer is None:
language = language or self.detect_language(audio)
task = task or "transcribe"
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
else:
language = language or self.tokenizer.language_code
task = task or self.tokenizer.task
if task != self.tokenizer.task or language != self.tokenizer.language_code:
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
self.model.model.is_multilingual, task=task,
language=language)
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)
if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
print(f"Suppressing numeral and symbol tokens")
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
new_suppressed_tokens = list(set(new_suppressed_tokens))
self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
@@ -239,12 +270,11 @@ class FasterWhisperPipeline(Pipeline):
# revert suppressed tokens if suppress_numerals is enabled
if self.suppress_numerals:
self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)
return {"segments": segments, "language": language}
def detect_language(self, audio: np.ndarray):
def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
@@ -258,33 +288,36 @@ class FasterWhisperPipeline(Pipeline):
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
def load_model(whisper_arch,
device,
device_index=0,
compute_type="float16",
asr_options=None,
language : Optional[str] = None,
vad_model=None,
vad_options=None,
model : Optional[WhisperModel] = None,
task="transcribe",
download_root=None,
local_files_only=False,
threads=4):
'''Load a Whisper model for inference.
def load_model(
whisper_arch: str,
device: str,
device_index=0,
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[VoiceActivitySegmentation] = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
download_root: Optional[str] = None,
local_files_only=False,
threads=4,
) -> FasterWhisperPipeline:
"""Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
device: str - The device to load the model on.
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
model: Optional[WhisperModel] - The WhisperModel instance to use.
download_root: Optional[str] - The root directory to download the model to.
local_files_only: bool - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
Returns:
A Whisper pipeline.
'''
"""
if whisper_arch.endswith(".en"):
language = "en"
@@ -297,7 +330,7 @@ def load_model(whisper_arch,
local_files_only=local_files_only,
cpu_threads=threads)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will first be detected for each audio file (increases inference time).")
tokenizer = None
@@ -338,7 +371,7 @@ def load_model(whisper_arch,
suppress_numerals = default_asr_options["suppress_numerals"]
del default_asr_options["suppress_numerals"]
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_asr_options = TranscriptionOptions(**default_asr_options)
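A recurring pattern in the asr.py changes: faster-whisper symbols are imported directly (Tokenizer, TranscriptionOptions, get_ctranslate2_storage), and the NamedTuple-style self.options._replace(...) calls become dataclasses.replace(...), consistent with TranscriptionOptions being a dataclass rather than a NamedTuple in the faster-whisper version this targets (dataclasses.replace would fail on a NamedTuple). A minimal sketch of that pattern, using a stand-in class rather than the real TranscriptionOptions:

# Stand-in dataclass for illustration; dataclasses.replace returns a copy
# with the named fields overridden, which is what the suppress_tokens
# update above relies on.
from dataclasses import dataclass, replace
from typing import List

@dataclass
class FakeOptions:
    suppress_tokens: List[int]

options = FakeOptions(suppress_tokens=[-1])
options = replace(options, suppress_tokens=sorted(set(options.suppress_tokens + [7])))
print(options.suppress_tokens)  # [-1, 7]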
default_vad_options = {
"vad_onset": 0.500,

View File

@@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
"""
Open an audio file and read as mono waveform, resampling as necessary

View File

@@ -1,5 +1,8 @@
# conjunctions.py
from typing import Set
conjunctions_by_language = {
'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', '', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@@ -36,8 +39,9 @@ commas_by_language = {
'ur': '،'
}
def get_conjunctions(lang_code):
def get_conjunctions(lang_code: str) -> Set[str]:
return conjunctions_by_language.get(lang_code, set())
def get_comma(lang_code):
return commas_by_language.get(lang_code, ',')
def get_comma(lang_code: str) -> str:
return commas_by_language.get(lang_code, ",")

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
import torch
from .audio import load_audio, SAMPLE_RATE
from .types import TranscriptionResult, AlignedTranscriptionResult
class DiarizationPipeline:
@@ -18,7 +19,13 @@ class DiarizationPipeline:
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
def __call__(
self,
audio: Union[str, np.ndarray],
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
@@ -32,7 +39,11 @@ class DiarizationPipeline:
return diarize_df
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
fill_nearest=False,
) -> dict:
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
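The diarization entry points only gain type hints here; call patterns stay the same. A hedged usage sketch (the Hugging Face token and file name are placeholders, and the returned object is a pandas DataFrame of speaker turns that assign_word_speakers later merges into a transcription result):

# Usage sketch; "hf_xxx" and "audio.wav" are placeholders. The pipeline
# returns a DataFrame of diarization turns (speaker, start, end).
import whisperx

diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device="cpu")
diarize_df = diarize_model("audio.wav", min_speakers=1, max_speakers=2)
print(diarize_df.head())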

View File

@@ -10,8 +10,15 @@ from .alignment import align, load_align_model
from .asr import load_model
from .audio import load_audio
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
optional_int, str2bool)
from .types import AlignedTranscriptionResult, TranscriptionResult
from .utils import (
LANGUAGES,
TO_LANGUAGE_CODE,
get_writer,
optional_float,
optional_int,
str2bool,
)
def cli():
@@ -95,7 +102,7 @@ def cli():
align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align")
task : str = args.pop("task")
task: str = args.pop("task")
if task == "translate":
# translation cannot be aligned
no_align = True
@@ -174,7 +181,13 @@ def cli():
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress, verbose=verbose)
result: TranscriptionResult = model.transcribe(
audio,
batch_size=batch_size,
chunk_size=chunk_size,
print_progress=print_progress,
verbose=verbose,
)
results.append((result, audio_path))
# Unload Whisper and VAD
@@ -201,7 +214,16 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
align_metadata,
input_audio,
device,
interpolate_method=interpolate_method,
return_char_alignments=return_char_alignments,
print_progress=print_progress,
)
results.append((result, audio_path))
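The CLI edits above reformat the transcribe and align calls and annotate their results; the underlying flow is unchanged and mirrors the documented Python API. A condensed sketch of that flow ("audio.wav", "small", and "cpu" are example values):

# Condensed sketch of the transcribe -> align flow the CLI performs.
import whisperx

audio = whisperx.load_audio("audio.wav")
model = whisperx.load_model("small", device="cpu", compute_type="int8")
result = model.transcribe(audio, batch_size=8, print_progress=True)

align_model, align_metadata = whisperx.load_align_model(result["language"], "cpu")
result = whisperx.align(
    result["segments"], align_model, align_metadata, audio, "cpu",
    return_char_alignments=False, print_progress=True,
)
print(result["segments"][0]["text"])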

View File

@@ -214,7 +214,12 @@ class WriteTXT(ResultWriter):
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
speaker = segment.get("speaker")
text = segment["text"].strip()
if speaker is not None:
print(f"[{speaker}]: {text}", file=file, flush=True)
else:
print(text, file=file, flush=True)
class SubtitlesWriter(ResultWriter):
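With the WriteTXT change, a segment that diarization has assigned a speaker to is written as "[SPEAKER]: text", while segments without a speaker keep the old bare-text output. A small sketch of the resulting behaviour (the segment dicts and SPEAKER_xx names are illustrative stand-ins):

# Sketch of what the updated write_result produces; the segments below are
# hand-written stand-ins for the dicts that the pipeline emits.
segments = [
    {"text": " Hello there.", "speaker": "SPEAKER_00"},
    {"text": " General Kenobi."},  # no speaker assigned
]

for segment in segments:
    speaker = segment.get("speaker")
    text = segment["text"].strip()
    print(f"[{speaker}]: {text}" if speaker is not None else text)

# Expected output:
# [SPEAKER_00]: Hello there.
# General Kenobi.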