feat: enhance diarization with optional output of speaker embeddings

- Updated DiarizationPipeline to include a return_embeddings parameter for optional speaker embeddings.
- Modified assign_word_speakers to accept and process speaker embeddings.
- Updated CLI to support --speaker_embeddings flag for JSON output.
- Ensured backward compatibility for existing functionality.
This commit is contained in:
Radu-Sebastian Amarie
2025-03-21 13:57:47 +00:00
committed by Barabazs
parent d700b56c9c
commit 1631c3040f
3 changed files with 79 additions and 11 deletions

View File

@ -59,6 +59,10 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
max_speakers: int = args.pop("max_speakers")
diarize_model_name: str = args.pop("diarize_model")
print_progress: bool = args.pop("print_progress")
return_speaker_embeddings: bool = args.pop("speaker_embeddings")
if return_speaker_embeddings and not diarize:
warnings.warn("--speaker_embeddings has no effect without --diarize")
if args["language"] is not None:
args["language"] = args["language"].lower()
@ -209,10 +213,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
results = []
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
diarize_segments, speaker_embeddings = diarize_model(
input_audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=return_speaker_embeddings
)
result = assign_word_speakers(diarize_segments, result)
result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
results.append((result, input_audio_path))
# >> Write
for result, audio_path in results: