mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
feat: add diarize_model arg to CLI (#1101)
This commit is contained in:
@ -57,6 +57,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
diarize: bool = args.pop("diarize")
|
||||
min_speakers: int = args.pop("min_speakers")
|
||||
max_speakers: int = args.pop("max_speakers")
|
||||
diarize_model_name: str = args.pop("diarize_model")
|
||||
print_progress: bool = args.pop("print_progress")
|
||||
|
||||
if args["language"] is not None:
|
||||
@ -204,8 +205,9 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
|
||||
)
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
print(">>Using model:", diarize_model_name)
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(
|
||||
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
|
||||
|
Reference in New Issue
Block a user