2 Commits

SHA1         Message                           Date
1843f3553a   Merge c72c627d10 into d700b56c9c  2025-06-17 19:02:36 +02:00
c72c627d10   add on_progress callback          2025-01-25 22:29:55 -03:00
7 changed files with 1591 additions and 1640 deletions

pyproject.toml

@@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.4.2"
version = "3.3.4"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"

uv.lock (generated): 3095 changed lines; diff suppressed because it is too large.

whisperx/__main__.py

@@ -44,7 +44,6 @@ def cli():
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")

whisperx/alignment.py

@@ -5,7 +5,7 @@ C. Max Bain
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
from typing import Iterable, Union, List, Callable, Optional
import numpy as np
import pandas as pd
@@ -120,6 +120,7 @@ def align(
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
on_progress: Callable[[int, int], None] = None
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
@@ -149,6 +150,9 @@ def align(
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
if on_progress:
on_progress(sdx + 1, total_segments)
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
@@ -424,7 +428,7 @@ def get_wildcard_emission(frame_emission, tokens, blank_id):
wildcard_mask = (tokens == -1)
# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
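
To make the new hook concrete, a minimal usage sketch of the Python API with this branch installed; load_audio, load_align_model, and align are the existing whisperx entry points, while the audio path and the stub transcript are placeholders:

import whisperx

device = "cpu"
audio = whisperx.load_audio("example.wav")  # placeholder path

# stand-in for the output of a prior transcription step
result = {"segments": [{"text": "hello world", "start": 0.0, "end": 1.2}], "language": "en"}

def report(done: int, total: int) -> None:
    # invoked once per segment, matching the on_progress(sdx + 1, total_segments) call above
    print(f"aligned {done}/{total} segments")

model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(
    result["segments"], model_a, metadata, audio, device,
    on_progress=report,  # keyword added by this PR
)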

whisperx/asr.py

@@ -1,6 +1,8 @@
import os
from typing import List, Optional, Union
from dataclasses import replace
import warnings
from typing import List, Union, Optional, NamedTuple, Callable
from enum import Enum
import ctranslate2
import faster_whisper
@@ -103,6 +105,12 @@ class FasterWhisperPipeline(Pipeline):
# - add support for timestamp mode
# - add support for custom inference kwargs
class TranscriptionState(Enum):
LOADING_AUDIO = "loading_audio"
GENERATING_VAD_SEGMENTS = "generating_vad_segments"
TRANSCRIBING = "transcribing"
FINISHED = "finished"
def __init__(
self,
model: WhisperModel,
@@ -197,8 +205,12 @@ class FasterWhisperPipeline(Pipeline):
print_progress=False,
combined_progress=False,
verbose=False,
on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
) -> TranscriptionResult:
if isinstance(audio, str):
if on_progress:
on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
audio = load_audio(audio)
def data(audio, segments):
@@ -216,6 +228,8 @@
else:
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
if on_progress:
on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
@@ -255,16 +269,22 @@
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
total_segments = len(vad_segments)
if on_progress:
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
if print_progress:
base_progress = ((idx + 1) / total_segments) * 100
percent_complete = base_progress / 2 if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
if on_progress:
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
@@ -273,6 +293,9 @@
}
)
if on_progress:
on_progress(self.__class__.TranscriptionState.FINISHED)
# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None
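
A hedged sketch of consuming the state-aware callback from transcribe() with this branch installed; load_model and transcribe are the existing whisperx APIs, the (state, current, total) signature and the TranscriptionState values come from the hunks above, and the model size, device, and audio path are placeholders:

import whisperx
from whisperx.asr import FasterWhisperPipeline

State = FasterWhisperPipeline.TranscriptionState

def on_progress(state, current=None, total=None):
    # LOADING_AUDIO and GENERATING_VAD_SEGMENTS arrive without counts;
    # TRANSCRIBING reports (segments done, total segments); FINISHED closes out.
    if state is State.TRANSCRIBING:
        print(f"transcribing: {current}/{total}")
    else:
        print(f"state: {state.value}")

model = whisperx.load_model("small", device="cpu", compute_type="int8")  # placeholder model/device
result = model.transcribe("example.wav", batch_size=8, on_progress=on_progress)  # placeholder path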

whisperx/diarize.py

@@ -26,81 +26,25 @@ class DiarizationPipeline:
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
return_embeddings: bool = False,
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
"""
Perform speaker diarization on audio.
Args:
audio: Path to audio file or audio array
num_speakers: Exact number of speakers (if known)
min_speakers: Minimum number of speakers to detect
max_speakers: Maximum number of speakers to detect
return_embeddings: Whether to return speaker embeddings
Returns:
If return_embeddings is True:
Tuple of (diarization dataframe, speaker embeddings dictionary)
Otherwise:
Just the diarization dataframe
"""
):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
if return_embeddings:
diarization, embeddings = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=True,
)
else:
diarization = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
embeddings = None
diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
if return_embeddings and embeddings is not None:
speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
return diarize_df, speaker_embeddings
# For backwards compatibility
if return_embeddings:
return diarize_df, None
else:
return diarize_df
return diarize_df
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
speaker_embeddings: Optional[dict[str, list[float]]] = None,
fill_nearest: bool = False,
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
"""
Assign speakers to words and segments in the transcript.
Args:
diarize_df: Diarization dataframe from DiarizationPipeline
transcript_result: Transcription result to augment with speaker labels
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
fill_nearest: If True, assign speakers even when there's no direct time overlap
Returns:
Updated transcript_result with speaker assignments and optionally embeddings
"""
fill_nearest=False,
) -> dict:
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
@@ -132,10 +76,6 @@ def assign_word_speakers(
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
# Add speaker embeddings to the result if provided
if speaker_embeddings is not None:
transcript_result["speaker_embeddings"] = speaker_embeddings
return transcript_result
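
For reference, a short sketch of the simpler diarization flow this branch's diarize.py supports (no return_embeddings round-trip); DiarizationPipeline and assign_word_speakers are taken from the file above, while the HF token, audio path, and stub transcript are placeholders:

from whisperx.diarize import DiarizationPipeline, assign_word_speakers

# stand-in for the output of a prior transcribe/align step
result = {"segments": [{"text": "hello world", "start": 0.0, "end": 1.2}]}

diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device="cpu")  # placeholder token
diarize_df = diarize_model("example.wav", min_speakers=1, max_speakers=2)   # returns a DataFrame only
result = assign_word_speakers(diarize_df, result)
for seg in result["segments"]:
    print(seg.get("speaker"), seg["text"])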

whisperx/transcribe.py

@@ -59,10 +59,6 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
max_speakers: int = args.pop("max_speakers")
diarize_model_name: str = args.pop("diarize_model")
print_progress: bool = args.pop("print_progress")
return_speaker_embeddings: bool = args.pop("speaker_embeddings")
if return_speaker_embeddings and not diarize:
warnings.warn("--speaker_embeddings has no effect without --diarize")
if args["language"] is not None:
args["language"] = args["language"].lower()
@@ -213,20 +209,10 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
results = []
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_result = diarize_model(
input_audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=return_speaker_embeddings
diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
)
if return_speaker_embeddings:
diarize_segments, speaker_embeddings = diarize_result
else:
diarize_segments = diarize_result
speaker_embeddings = None
result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
result = assign_word_speakers(diarize_segments, result)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results: