From a6eb33778b12273ee11572ca72190dd361ac2b7f Mon Sep 17 00:00:00 2001 From: Max Bain Date: Sun, 8 Jan 2023 12:24:35 +0000 Subject: [PATCH] additional waveform segment check --- whisperx/transcribe.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index ed0c2e9..d2303aa 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -275,7 +275,6 @@ def align( start_from_previous: bool = True, drop_non_aligned_words: bool = False, ): - print("Performing alignment...") if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) @@ -308,7 +307,12 @@ def align( f1 = int(t1 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE) + waveform_segment = audio[:, f1:f2] + + if waveform_segment.shape[1] < 10: + print("Failed to align segment: too short in duration, %.3f" % waveform_segment.shape[1]/SAMPLE_RATE) + continue with torch.inference_mode(): if model_type == "torchaudio": emissions, _ = model(waveform_segment.to(device)) @@ -507,6 +511,7 @@ def cli(): align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) for audio_path in args.pop("audio"): + print("Performing transcription...") result = transcribe(model, audio_path, temperature=temperature, **args) if result["language"] != align_metadata["language"]: @@ -538,8 +543,8 @@ def cli(): write_srt(result_aligned["word_segments"], file=srt) # save ASS - with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass: - write_ass(result_aligned["segments"], file=ass) + # with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass: + # write_ass(result_aligned["segments"], file=ass) if __name__ == '__main__':