From b3432412530ecb0cc5ac923f161da281e41d23d2 Mon Sep 17 00:00:00 2001 From: bog Date: Sat, 31 May 2025 13:32:31 +0200 Subject: [PATCH] feat: add diarize_model arg to CLI (#1101) --- whisperx/__main__.py | 1 + whisperx/diarize.py | 5 +++-- whisperx/transcribe.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/whisperx/__main__.py b/whisperx/__main__.py index 6da7b87..e7f80be 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -43,6 +43,7 @@ def cli(): parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word") parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") + parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use") parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 97a7813..26f33e4 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -11,13 +11,14 @@ from whisperx.types import TranscriptionResult, AlignedTranscriptionResult class DiarizationPipeline: def __init__( self, - model_name="pyannote/speaker-diarization-3.1", + model_name=None, use_auth_token=None, device: Optional[Union[str, torch.device]] = "cpu", ): if isinstance(device, str): device = torch.device(device) - self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) + model_config = model_name or "pyannote/speaker-diarization-3.1" + self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device) def __call__( self, diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index f567824..867b378 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -57,6 +57,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): diarize: bool = args.pop("diarize") min_speakers: int = args.pop("min_speakers") max_speakers: int = args.pop("max_speakers") + diarize_model_name: str = args.pop("diarize_model") print_progress: bool = args.pop("print_progress") if args["language"] is not None: @@ -204,8 +205,9 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): ) tmp_results = results print(">>Performing diarization...") + print(">>Using model:", diarize_model_name) results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: diarize_segments = diarize_model( input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers