9 Commits

10 changed files with 181 additions and 71 deletions

.github/workflows/tmp.yml (new file, +35 lines)
View File

@@ -0,0 +1,35 @@
+name: Python Compatibility Test (PyPi)
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+  workflow_dispatch:  # Allows manual triggering from GitHub UI
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install package
+        run: |
+          pip install whisperx
+
+      - name: Print packages
+        run: |
+          pip list
+
+      - name: Test import
+        run: |
+          python -c "import whisperx; print('Successfully imported whisperx')"
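The workflow is a pure smoke test: install the published whisperx wheel on each Python version, list the environment, and check that the package imports. A rough local equivalent of the final step, useful for reproducing a CI failure outside Actions (assumes whisperx is already installed in the current environment):

# Local sketch of the CI "Test import" step: import whisperx in a fresh subprocess
# so the check does not depend on anything already imported in this session.
import subprocess
import sys

proc = subprocess.run(
    [sys.executable, "-c", "import whisperx; print('Successfully imported whisperx')"],
    capture_output=True,
    text=True,
)
print(proc.stdout or proc.stderr)
print("import OK" if proc.returncode == 0 else "import FAILED")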

View File

@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
 setup(
     name="whisperx",
     py_modules=["whisperx"],
-    version="3.3.0",
+    version="3.3.1",
     description="Time-Accurate Automatic Speech Recognition using Whisper.",
     long_description=long_description,
     long_description_content_type="text/markdown",

View File

@@ -1,5 +1,5 @@
 import math
-from conjunctions import get_conjunctions, get_comma
+from .conjunctions import get_conjunctions, get_comma
 from typing import TextIO

 def normal_round(n):
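The bare `from conjunctions import ...` only works when the file's own directory happens to be on sys.path; once the module is loaded as part of the installed whisperx package it raises ModuleNotFoundError, which is what the relative import fixes. A self-contained sketch of the difference, using a throwaway package written to a temp directory (all names are illustrative only):

# demo_pkg/conjunctions.py defines a helper; demo_pkg/subtitles.py imports it.
# With "from conjunctions import ..." the second import would fail when demo_pkg
# is imported by package name; the relative form resolves against the package.
import importlib
import os
import sys
import tempfile

root = tempfile.mkdtemp()
pkg = os.path.join(root, "demo_pkg")
os.makedirs(pkg)
open(os.path.join(pkg, "__init__.py"), "w").close()
with open(os.path.join(pkg, "conjunctions.py"), "w") as f:
    f.write("def get_comma(lang):\n    return ','\n")
with open(os.path.join(pkg, "subtitles.py"), "w") as f:
    f.write("from .conjunctions import get_comma\n")  # the bare import would raise ModuleNotFoundError here

sys.path.insert(0, root)
mod = importlib.import_module("demo_pkg.subtitles")
print(mod.get_comma("en"))  # -> ','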

View File

@@ -3,7 +3,7 @@ Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Optional, Union, List

 import numpy as np
 import pandas as pd
@@ -65,7 +65,7 @@ DEFAULT_ALIGN_MODELS_HF = {
 }


-def load_align_model(language_code, device, model_name=None, model_dir=None):
+def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None):
     if model_name is None:
         # use default model
         if language_code in DEFAULT_ALIGN_MODELS_TORCH:
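For context, the annotated `load_align_model` is normally paired with `align` roughly as below; the audio path and the English language code are placeholders, and `segments` stands in for the output of a prior transcription:

# Sketch of the typical alignment call pattern (illustrative inputs).
import whisperx

device = "cpu"  # or "cuda"
audio = whisperx.load_audio("audio.wav")  # hypothetical input file

# model_name and model_dir keep their defaults, matching the signature above.
align_model, metadata = whisperx.load_align_model(language_code="en", device=device)

# In real use, segments come from model.transcribe(audio)["segments"].
segments = [{"start": 0.0, "end": 2.5, "text": "hello world"}]
aligned = whisperx.align(segments, align_model, metadata, audio, device, return_char_alignments=False)
print(aligned["word_segments"][:3])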

View File

@@ -1,17 +1,21 @@
 import os
 import warnings
-from typing import List, Union, Optional, NamedTuple
+from typing import List, NamedTuple, Optional, Union
+from dataclasses import replace

 import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator

 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .vad import load_vad_model, merge_chunks
-from .types import TranscriptionResult, SingleSegment
+from .types import SingleSegment, TranscriptionResult
+from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
+

 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
@@ -28,7 +32,13 @@ class WhisperModel(faster_whisper.WhisperModel):
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''

-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
@@ -81,7 +91,7 @@ class WhisperModel(faster_whisper.WhisperModel):

         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
-        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+        features = get_ctranslate2_storage(features)

         return self.model.encode(features, to_cpu=to_cpu)
@@ -95,16 +105,16 @@ class FasterWhisperPipeline(Pipeline):

     def __init__(
         self,
-        model,
-        vad,
+        model: WhisperModel,
+        vad: VoiceActivitySegmentation,
         vad_params: dict,
-        options : NamedTuple,
-        tokenizer=None,
+        options: TranscriptionOptions,
+        tokenizer: Optional[Tokenizer] = None,
         device: Union[int, str, "torch.device"] = -1,
-        framework = "pt",
-        language : Optional[str] = None,
+        framework="pt",
+        language: Optional[str] = None,
         suppress_numerals: bool = False,
-        **kwargs
+        **kwargs,
     ):
         self.model = model
         self.tokenizer = tokenizer
@@ -156,7 +166,13 @@ class FasterWhisperPipeline(Pipeline):
         return model_outputs

     def get_iterator(
-        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
+        self,
+        inputs,
+        num_workers: int,
+        batch_size: int,
+        preprocess_params: dict,
+        forward_params: dict,
+        postprocess_params: dict,
     ):
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
         if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -171,7 +187,16 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, verbose=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size: Optional[int] = None,
+        num_workers=0,
+        language: Optional[str] = None,
+        task: Optional[str] = None,
+        chunk_size=30,
+        print_progress=False,
+        combined_progress=False,
+        verbose=False,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
@@ -193,16 +218,22 @@ class FasterWhisperPipeline(Pipeline):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )

         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
@@ -210,7 +241,7 @@ class FasterWhisperPipeline(Pipeline):
             print(f"Suppressing numeral and symbol tokens")
             new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
             new_suppressed_tokens = list(set(new_suppressed_tokens))
-            self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
+            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)

         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
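Replacing `self.options._replace(...)` with `dataclasses.replace(self.options, ...)` lines up with the new `from dataclasses import replace` import; the apparent motivation is that faster-whisper's TranscriptionOptions is no longer a NamedTuple, and `_replace` only exists on named tuples. A minimal illustration of the two idioms on toy option types:

# NamedTuple options use ._replace(); dataclass options use dataclasses.replace().
from dataclasses import dataclass, replace
from typing import NamedTuple, Tuple


class TupleOpts(NamedTuple):
    suppress_tokens: Tuple[int, ...] = (-1,)


@dataclass
class DataclassOpts:
    suppress_tokens: Tuple[int, ...] = (-1,)


t = TupleOpts()._replace(suppress_tokens=(1, 2))      # NamedTuple API
d = replace(DataclassOpts(), suppress_tokens=(1, 2))  # dataclass API; ._replace would raise AttributeError
print(t.suppress_tokens, d.suppress_tokens)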
@@ -239,12 +270,11 @@

         # revert suppressed tokens if suppress_numerals is enabled
         if self.suppress_numerals:
-            self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
+            self.options = replace(self.options, suppress_tokens=previous_suppress_tokens)

         return {"segments": segments, "language": language}

-
-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: np.ndarray) -> str:
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
@@ -258,33 +288,36 @@
         print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
         return language

-def load_model(whisper_arch,
-               device,
-               device_index=0,
-               compute_type="float16",
-               asr_options=None,
-               language : Optional[str] = None,
-               vad_model=None,
-               vad_options=None,
-               model : Optional[WhisperModel] = None,
-               task="transcribe",
-               download_root=None,
-               local_files_only=False,
-               threads=4):
-    '''Load a Whisper model for inference.
+
+def load_model(
+    whisper_arch: str,
+    device: str,
+    device_index=0,
+    compute_type="float16",
+    asr_options: Optional[dict] = None,
+    language: Optional[str] = None,
+    vad_model: Optional[VoiceActivitySegmentation] = None,
+    vad_options: Optional[dict] = None,
+    model: Optional[WhisperModel] = None,
+    task="transcribe",
+    download_root: Optional[str] = None,
+    local_files_only=False,
+    threads=4,
+) -> FasterWhisperPipeline:
+    """Load a Whisper model for inference.
     Args:
-        whisper_arch: str - The name of the Whisper model to load.
-        device: str - The device to load the model on.
-        compute_type: str - The compute type to use for the model.
-        options: dict - A dictionary of options to use for the model.
-        language: str - The language of the model. (use English for now)
-        model: Optional[WhisperModel] - The WhisperModel instance to use.
-        download_root: Optional[str] - The root directory to download the model to.
-        local_files_only: bool - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
-        threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+        whisper_arch - The name of the Whisper model to load.
+        device - The device to load the model on.
+        compute_type - The compute type to use for the model.
+        options - A dictionary of options to use for the model.
+        language - The language of the model. (use English for now)
+        model - The WhisperModel instance to use.
+        download_root - The root directory to download the model to.
+        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
     Returns:
         A Whisper pipeline.
-    '''
+    """

     if whisper_arch.endswith(".en"):
         language = "en"
@@ -297,7 +330,7 @@ def load_model(whisper_arch,
                             local_files_only=local_files_only,
                             cpu_threads=threads)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
@@ -338,7 +371,7 @@ def load_model(whisper_arch,
         suppress_numerals = default_asr_options["suppress_numerals"]
         del default_asr_options["suppress_numerals"]

-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = TranscriptionOptions(**default_asr_options)

     default_vad_options = {
         "vad_onset": 0.500,

View File

@@ -22,7 +22,7 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str, sr: int = SAMPLE_RATE) -> np.ndarray:
     """
     Open an audio file and read as mono waveform, resampling as necessary

View File

@@ -1,5 +1,8 @@
 # conjunctions.py

+from typing import Set
+
+
 conjunctions_by_language = {
     'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
     'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusquà', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
@@ -36,8 +39,9 @@ commas_by_language = {
     'ur': '،'
 }

-def get_conjunctions(lang_code):
+def get_conjunctions(lang_code: str) -> Set[str]:
     return conjunctions_by_language.get(lang_code, set())

-def get_comma(lang_code):
-    return commas_by_language.get(lang_code, ',')
+
+def get_comma(lang_code: str) -> str:
+    return commas_by_language.get(lang_code, ",")
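The two helpers are plain dictionary lookups with fallbacks, now with their return types spelled out; for example:

# Lookups fall back to an empty set / a plain comma for unknown language codes.
from whisperx.conjunctions import get_comma, get_conjunctions

print("and" in get_conjunctions("en"))  # True
print(get_conjunctions("xx"))           # set() for an unknown language code
print(get_comma("ur"))                  # '،'
print(get_comma("xx"))                  # ',' fallback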

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
 import torch

 from .audio import load_audio, SAMPLE_RATE
+from .types import TranscriptionResult, AlignedTranscriptionResult


 class DiarizationPipeline:
@@ -18,7 +19,13 @@ class DiarizationPipeline:
             device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

-    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
+    def __call__(
+        self,
+        audio: Union[str, np.ndarray],
+        num_speakers: Optional[int] = None,
+        min_speakers: Optional[int] = None,
+        max_speakers: Optional[int] = None,
+    ):
         if isinstance(audio, str):
             audio = load_audio(audio)
         audio_data = {
@@ -32,7 +39,11 @@ class DiarizationPipeline:
         return diarize_df


-def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+def assign_word_speakers(
+    diarize_df: pd.DataFrame,
+    transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
+    fill_nearest=False,
+) -> dict:
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
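For reference, the annotated `__call__` and `assign_word_speakers` are typically chained as below; the Hugging Face token, audio path, and the stub `result` dict are placeholders for real values:

# Diarization sketch: run the pyannote pipeline, then attach speakers to segments.
import whisperx

device = "cpu"
audio = whisperx.load_audio("audio.wav")  # hypothetical input file

diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device=device)
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=2)

# In real use this is the (aligned) transcription result from the earlier steps.
result = {"segments": [{"start": 0.0, "end": 2.0, "text": "hello"}], "language": "en"}
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"][0].get("speaker"))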

View File

@@ -10,8 +10,15 @@ from .alignment import align, load_align_model
 from .asr import load_model
 from .audio import load_audio
 from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
-                    optional_int, str2bool)
+from .types import AlignedTranscriptionResult, TranscriptionResult
+from .utils import (
+    LANGUAGES,
+    TO_LANGUAGE_CODE,
+    get_writer,
+    optional_float,
+    optional_int,
+    str2bool,
+)


 def cli():
@@ -95,7 +102,7 @@ def cli():
     align_model: str = args.pop("align_model")
     interpolate_method: str = args.pop("interpolate_method")
     no_align: bool = args.pop("no_align")
-    task : str = args.pop("task")
+    task: str = args.pop("task")
     if task == "translate":
         # translation cannot be aligned
         no_align = True
@@ -174,7 +181,13 @@ def cli():
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress, verbose=verbose)
+        result: TranscriptionResult = model.transcribe(
+            audio,
+            batch_size=batch_size,
+            chunk_size=chunk_size,
+            print_progress=print_progress,
+            verbose=verbose,
+        )
         results.append((result, audio_path))

         # Unload Whisper and VAD
@@ -201,7 +214,16 @@ def cli():
                 print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                 align_model, align_metadata = load_align_model(result["language"], device)
             print(">>Performing alignment...")
-            result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
+            result: AlignedTranscriptionResult = align(
+                result["segments"],
+                align_model,
+                align_metadata,
+                input_audio,
+                device,
+                interpolate_method=interpolate_method,
+                return_char_alignments=return_char_alignments,
+                print_progress=print_progress,
+            )

             results.append((result, audio_path))
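The new `TranscriptionResult` and `AlignedTranscriptionResult` annotations document the dict shapes that flow through the CLI. Roughly, and with illustrative values only (field names follow whisperx.types, extra keys omitted):

# Shape of the ASR result vs. the aligned result (illustrative values).
transcription_result = {
    "segments": [{"text": " Hello there.", "start": 0.0, "end": 1.2}],
    "language": "en",
}
aligned_result = {
    "segments": [
        {"text": "Hello there.", "start": 0.01, "end": 1.18,
         "words": [{"word": "Hello", "start": 0.01, "end": 0.42, "score": 0.97}]},
    ],
    "word_segments": [{"word": "Hello", "start": 0.01, "end": 0.42, "score": 0.97}],
}
print(aligned_result["word_segments"][0]["word"])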

View File

@@ -214,7 +214,12 @@ class WriteTXT(ResultWriter):

     def write_result(self, result: dict, file: TextIO, options: dict):
         for segment in result["segments"]:
-            print(segment["text"].strip(), file=file, flush=True)
+            speaker = segment.get("speaker")
+            text = segment["text"].strip()
+            if speaker is not None:
+                print(f"[{speaker}]: {text}", file=file, flush=True)
+            else:
+                print(text, file=file, flush=True)


 class SubtitlesWriter(ResultWriter):
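With this change the plain-text writer prefixes each segment with its speaker label when one is present (i.e. after diarization) and falls back to the old bare-text behaviour otherwise. A sketch of the effect, assuming WriteTXT keeps whisper's usual ResultWriter interface (constructor takes an output directory; write_result takes the result dict, an open file, and an options dict):

# Expected output:
#   [SPEAKER_00]: Good morning.
#   [SPEAKER_01]: Hi, thanks for joining.
#   No speaker info on this one.
import sys
from whisperx.utils import WriteTXT

result = {
    "segments": [
        {"text": " Good morning.", "speaker": "SPEAKER_00"},
        {"text": " Hi, thanks for joining.", "speaker": "SPEAKER_01"},
        {"text": " No speaker info on this one."},
    ]
}

writer = WriteTXT(output_dir=".")  # assumed constructor, mirroring whisper's ResultWriter
writer.write_result(result, sys.stdout, {})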