feat: enhance diarization with optional output of speaker embeddings

- Updated DiarizationPipeline to include a return_embeddings parameter for optional speaker embeddings.
- Modified assign_word_speakers to accept and process speaker embeddings.
- Updated CLI to support --speaker_embeddings flag for JSON output.
- Ensured backward compatibility for existing functionality.
Author: Radu-Sebastian Amarie
Date: 2025-03-21 13:57:47 +00:00
Committed by: Barabazs
Parent: d700b56c9c
Commit: 1631c3040f
3 changed files with 79 additions and 11 deletions
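For illustration, a minimal sketch of the API surface this commit introduces (token, device, and file names are placeholders; the model name is the --diarize_model default from the diff below):

    from whisperx.diarize import DiarizationPipeline, assign_word_speakers

    diarize_model = DiarizationPipeline(
        model_name="pyannote/speaker-diarization-3.1",
        use_auth_token="YOUR_HF_TOKEN",  # placeholder token
        device="cuda",
    )

    # With return_embeddings=True the pipeline returns (dataframe, embeddings dict)
    diarize_df, speaker_embeddings = diarize_model("audio.wav", return_embeddings=True)

    result = {"segments": []}  # stand-in for a real whisperx transcription result
    result = assign_word_speakers(diarize_df, result, speaker_embeddings)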

View File

@@ -44,6 +44,7 @@ def cli():
     parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers in audio file")
     parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers in audio file")
     parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
+    parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")

View File

@@ -1,7 +1,7 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
-from typing import Optional, Union
+from typing import Optional, Union, Tuple, Dict, List, Any
 import torch
 from whisperx.audio import load_audio, SAMPLE_RATE
@@ -26,25 +26,81 @@ class DiarizationPipeline:
         num_speakers: Optional[int] = None,
         min_speakers: Optional[int] = None,
         max_speakers: Optional[int] = None,
-    ):
+        return_embeddings: bool = False,
+    ) -> Union[Tuple[pd.DataFrame, Optional[Dict[str, List[float]]]], pd.DataFrame]:
+        """
+        Perform speaker diarization on audio.
+
+        Args:
+            audio: Path to audio file or audio array
+            num_speakers: Exact number of speakers (if known)
+            min_speakers: Minimum number of speakers to detect
+            max_speakers: Maximum number of speakers to detect
+            return_embeddings: Whether to return speaker embeddings
+
+        Returns:
+            If return_embeddings is True:
+                Tuple of (diarization dataframe, speaker embeddings dictionary)
+            Otherwise:
+                Just the diarization dataframe
+        """
         if isinstance(audio, str):
             audio = load_audio(audio)
         audio_data = {
             'waveform': torch.from_numpy(audio[None, :]),
             'sample_rate': SAMPLE_RATE
         }
-        segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
-        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+
+        if return_embeddings:
+            diarization, embeddings = self.model(
+                audio_data,
+                num_speakers=num_speakers,
+                min_speakers=min_speakers,
+                max_speakers=max_speakers,
+                return_embeddings=True
+            )
+        else:
+            diarization = self.model(
+                audio_data,
+                num_speakers=num_speakers,
+                min_speakers=min_speakers,
+                max_speakers=max_speakers
+            )
+            embeddings = None
+
+        diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
         diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
         diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
-        return diarize_df
+
+        if return_embeddings and embeddings is not None:
+            speaker_embeddings = {speaker: embeddings[s].tolist() for s, speaker in enumerate(diarization.labels())}
+            return diarize_df, speaker_embeddings
+
+        # For backwards compatibility
+        if return_embeddings:
+            return diarize_df, None
+        else:
+            return diarize_df
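The call site thus opts into the tuple return; a short sketch, reusing the diarize_model instance from the first example (speaker labels such as "SPEAKER_00" follow pyannote's convention):

    # Tuple return only when embeddings are requested
    diarize_df, speaker_embeddings = diarize_model("audio.wav", return_embeddings=True)
    if speaker_embeddings is not None:
        for label, vector in speaker_embeddings.items():
            print(label, len(vector))  # e.g. "SPEAKER_00" and its embedding dimension

    # Legacy single-value return is unchanged
    diarize_df = diarize_model("audio.wav")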
 def assign_word_speakers(
     diarize_df: pd.DataFrame,
     transcript_result: Union[AlignedTranscriptionResult, TranscriptionResult],
-    fill_nearest=False,
-) -> dict:
+    speaker_embeddings: Optional[Dict[str, List[float]]] = None,
+    fill_nearest: bool = False,
+) -> Union[AlignedTranscriptionResult, TranscriptionResult]:
+    """
+    Assign speakers to words and segments in the transcript.
+
+    Args:
+        diarize_df: Diarization dataframe from DiarizationPipeline
+        transcript_result: Transcription result to augment with speaker labels
+        speaker_embeddings: Optional dictionary mapping speaker IDs to embedding vectors
+        fill_nearest: If True, assign speakers even when there's no direct time overlap
+
+    Returns:
+        Updated transcript_result with speaker assignments and optionally embeddings
+    """
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
@@ -75,7 +131,11 @@ def assign_word_speakers(
                         # sum over speakers
                         speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
                         word["speaker"] = speaker
+
+    # Add speaker embeddings to the result if provided
+    if speaker_embeddings is not None:
+        transcript_result["speaker_embeddings"] = speaker_embeddings
 
     return transcript_result
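With embeddings supplied, the returned transcript gains a top-level speaker_embeddings key alongside segments; a hypothetical shape (all values illustrative):

    result = {
        "segments": [
            {"start": 0.0, "end": 2.5, "text": "hello there", "speaker": "SPEAKER_00"},
        ],
        "speaker_embeddings": {
            "SPEAKER_00": [0.12, -0.03, 0.57],  # truncated vector, real ones are longer
        },
    }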

View File

@@ -59,6 +59,10 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
     max_speakers: int = args.pop("max_speakers")
     diarize_model_name: str = args.pop("diarize_model")
     print_progress: bool = args.pop("print_progress")
+    return_speaker_embeddings: bool = args.pop("speaker_embeddings")
+
+    if return_speaker_embeddings and not diarize:
+        warnings.warn("--speaker_embeddings has no effect without --diarize")
 
     if args["language"] is not None:
         args["language"] = args["language"].lower()
@@ -209,10 +213,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
         results = []
         diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
-            diarize_segments = diarize_model(
-                input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
+            diarize_segments, speaker_embeddings = diarize_model(
+                input_audio_path,
+                min_speakers=min_speakers,
+                max_speakers=max_speakers,
+                return_embeddings=return_speaker_embeddings
             )
-            result = assign_word_speakers(diarize_segments, result)
+            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
             results.append((result, input_audio_path))
     # >> Write
     for result, audio_path in results:
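Assuming the JSON writer serializes the full result dict, downstream code can recover the embeddings from the written file; a sketch with a placeholder path:

    import json

    with open("audio.json") as f:  # placeholder output path
        data = json.load(f)

    # Present only when --diarize and --speaker_embeddings were both passed
    embeddings = data.get("speaker_embeddings", {})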