Merge branch 'main' into add-merge-chunk-size-as-argument

This commit is contained in:
Max Bain
2023-08-29 10:05:05 -06:00
committed by GitHub
5 changed files with 30 additions and 5 deletions

View File

@ -74,6 +74,8 @@ def cli():
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
# fmt: on
args = parser.parse_args().__dict__
@ -107,6 +109,7 @@ def cli():
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
print_progress: bool = args.pop("print_progress")
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
@ -159,7 +162,7 @@ def cli():
audio = load_audio(audio_path)
# >> VAD & ASR
print(">>Performing transcription...")
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size)
result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
results.append((result, audio_path))
# Unload Whisper and VAD
@ -187,7 +190,7 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print(">>Performing alignment...")
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments)
result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
results.append((result, audio_path))