From cb3ed4ab9d92937703993e2a653d70dfa420c73a Mon Sep 17 00:00:00 2001
From: awerks
Date: Wed, 16 Aug 2023 16:22:29 +0200
Subject: [PATCH] Update transcribe.py

---
 whisperx/transcribe.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 1cc144e..49788bd 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -73,6 +73,8 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
+
+    parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
     # fmt: on
 
     args = parser.parse_args().__dict__
@@ -104,6 +106,7 @@ def cli():
     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
     max_speakers: int = args.pop("max_speakers")
+    print_progress: bool = args.pop("print_progress")
 
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@@ -156,7 +159,7 @@ def cli():
         audio = load_audio(audio_path)
         # >> VAD & ASR
         print(">>Performing transcription...")
-        result = model.transcribe(audio, batch_size=batch_size)
+        result = model.transcribe(audio, batch_size=batch_size, print_progress=print_progress)
         results.append((result, audio_path))
 
         # Unload Whisper and VAD
@@ -184,7 +187,7 @@ def cli():
                 print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                 align_model, align_metadata = load_align_model(result["language"], device)
             print(">>Performing alignment...")
-            result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
+            result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
 
             results.append((result, audio_path))