5 Commits

SHA1 Message Date
ffedc5cdf0 fix: speaker embedding bug (#1178)
* fix: improve handling of speaker embeddings in transcribe_task

* chore: bump version to 3.4.1
2025-06-25 13:55:20 +02:00
b93e9b6f57 chore: bump version to 3.4.0 2025-06-24 16:21:23 +02:00
844736e4e4 style: minor code formatting 2025-06-24 15:01:09 +02:00
220fec9aea refactor: update type hints in diarization module (PEP 585) 2025-06-24 15:01:09 +02:00
1631c3040f feat: enhance diarization with optional output of speaker embeddings
- Updated DiarizationPipeline to include a return_embeddings parameter for optional speaker embeddings.
- Modified assign_word_speakers to accept and process speaker embeddings.
- Updated CLI to support --speaker_embeddings flag for JSON output.
- Ensured backward compatibility for existing functionality.
2025-06-24 15:01:09 +02:00
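A minimal end-to-end sketch of the feature this commit introduces, built from the DiarizationPipeline and assign_word_speakers signatures shown in the diffs below; the model size, audio path and HF token are placeholders, and whisperx.load_model / whisperx.load_audio are assumed to behave as in current releases:

import whisperx
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

device = "cuda"
audio = whisperx.load_audio("audio.wav")  # placeholder input file

model = whisperx.load_model("small", device)
result = model.transcribe(audio)

diarize_model = DiarizationPipeline(use_auth_token="HF_TOKEN", device=device)

# return_embeddings=True switches the return value to a (DataFrame, dict) tuple
diarize_df, speaker_embeddings = diarize_model(audio, return_embeddings=True)

# Passing the embeddings adds a top-level "speaker_embeddings" key to the result
result = assign_word_speakers(diarize_df, result, speaker_embeddings)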
7 changed files with 1639 additions and 1590 deletions

View File

@@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.3.4"
version = "3.4.1"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"

uv.lock (generated, 3095 changed lines): file diff suppressed because it is too large

View File

@@ -44,6 +44,7 @@ def cli():
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")

View File

@@ -5,7 +5,7 @@ C. Max Bain
import math
from dataclasses import dataclass
from typing import Iterable, Union, List, Callable, Optional
from typing import Iterable, Optional, Union, List
import numpy as np
import pandas as pd
@@ -120,7 +120,6 @@ def align(
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
on_progress: Callable[[int, int], None] = None
) -> AlignedTranscriptionResult:
"""
Align phoneme recognition predictions to known transcription.
@@ -150,9 +149,6 @@
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
if on_progress:
on_progress(sdx + 1, total_segments)
num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]

View File

@@ -1,8 +1,6 @@
import os
from typing import List, Optional, Union
from dataclasses import replace
import warnings
from typing import List, Union, Optional, NamedTuple, Callable
from enum import Enum
import ctranslate2
import faster_whisper
@@ -105,12 +103,6 @@ class FasterWhisperPipeline(Pipeline):
# - add support for timestamp mode
# - add support for custom inference kwargs
class TranscriptionState(Enum):
LOADING_AUDIO = "loading_audio"
GENERATING_VAD_SEGMENTS = "generating_vad_segments"
TRANSCRIBING = "transcribing"
FINISHED = "finished"
def __init__(
self,
model: WhisperModel,
@@ -205,12 +197,8 @@
print_progress=False,
combined_progress=False,
verbose=False,
on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
) -> TranscriptionResult:
if isinstance(audio, str):
if on_progress:
on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
audio = load_audio(audio)
def data(audio, segments):
@@ -228,8 +216,6 @@
else:
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
if on_progress:
on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
@@ -269,22 +255,16 @@
segments: List[SingleSegment] = []
batch_size = batch_size or self._batch_size
total_segments = len(vad_segments)
if on_progress:
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
if print_progress:
base_progress = ((idx + 1) / total_segments) * 100
percent_complete = base_progress / 2 if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")
if on_progress:
on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
text = out['text']
if batch_size in [0, 1, None]:
text = text[0]
if verbose:
print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
segments.append(
{
"text": text,
@@ -293,9 +273,6 @@
}
)
if on_progress:
on_progress(self.__class__.TranscriptionState.FINISHED)
# revert the tokenizer if multilingual inference is enabled
if self.preset_language is None:
self.tokenizer = None

View File

@@ -26,25 +26,81 @@ class DiarizationPipeline:
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
return_embeddings: bool = False,
) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]:
"""
Perform speaker diarization on audio.
Args:
audio: Path to audio file or audio array
num_speakers: Exact number of speakers (if known)
min_speakers: Minimum number of speakers to detect
max_speakers: Maximum number of speakers to detect
return_embeddings: Whether to return speaker embeddings
Returns:
If return_embeddings is True:
Tuple of (diarization dataframe, speaker embeddings dictionary)
Otherwise:
Just the diarization dataframe
"""
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
if return_embeddings:
diarization, embeddings = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=True,
)
else:
diarization = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
embeddings = None
diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
return diarize_df
if return_embeddings and embeddings is not None:
speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
return diarize_df, speaker_embeddings
# For backwards compatibility
if return_embeddings:
return diarize_df, None
else:
return diarize_df
def assign_word_speakers(
diarize_df: pd.DataFrame,
transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
fill_nearest=False,
) -> dict:
speaker_embeddings: Optional[dict[str, list[float]]] = None,
fill_nearest: bool = False,
) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
"""
Assign speakers to words and segments in the transcript.
Args:
diarize_df: Diarization dataframe from DiarizationPipeline
transcript_result: Transcription result to augment with speaker labels
speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
fill_nearest: If True, assign speakers even when there's no direct time overlap
Returns:
Updated transcript_result with speaker assignments and optionally embeddings
"""
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
# assign speaker to segment (if any)
@@ -76,6 +132,10 @@ def assign_word_speakers(
speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
word["speaker"] = speaker
# Add speaker embeddings to the result if provided
if speaker_embeddings is not None:
transcript_result["speaker_embeddings"] = speaker_embeddings
return transcript_result
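To make the two return shapes described in the docstrings above concrete, a short sketch with placeholder arguments:

# Default call keeps the pre-3.4 behaviour: a single DataFrame is returned.
diarize_df = diarize_model("audio.wav", min_speakers=2, max_speakers=4)

# Opting in returns a (DataFrame, dict) tuple; the dict maps each speaker
# label to its embedding vector as a plain list of floats.
diarize_df, speaker_embeddings = diarize_model("audio.wav", return_embeddings=True)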

View File

@@ -59,6 +59,10 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
max_speakers: int = args.pop("max_speakers")
diarize_model_name: str = args.pop("diarize_model")
print_progress: bool = args.pop("print_progress")
return_speaker_embeddings: bool = args.pop("speaker_embeddings")
if return_speaker_embeddings and not diarize:
warnings.warn("--speaker_embeddings has no effect without --diarize")
if args["language"] is not None:
args["language"] = args["language"].lower()
@@ -209,10 +213,20 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
results = []
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
diarize_result = diarize_model(
input_audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=return_speaker_embeddings
)
result = assign_word_speakers(diarize_segments, result)
if return_speaker_embeddings:
diarize_segments, speaker_embeddings = diarize_result
else:
diarize_segments = diarize_result
speaker_embeddings = None
result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results:
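Assuming the downstream writers simply serialize the result dict that assign_word_speakers returns, a run with --diarize --speaker_embeddings would hand them a transcript of roughly this shape (speaker labels follow pyannote's SPEAKER_NN convention; the vectors are illustrative and truncated):

# Sketch of the result dict passed on to the output writers.
result = {
    "segments": [...],                    # transcript segments, each now carrying a "speaker" field
    "speaker_embeddings": {               # added by assign_word_speakers when embeddings are supplied
        "SPEAKER_00": [0.0123, -0.0456],  # real embedding vectors are much longer
        "SPEAKER_01": [0.0789, 0.0012],
    },
}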