additional waveform segment check

This commit is contained in:
Max Bain
2023-01-08 12:24:35 +00:00
parent 857bcca238
commit a6eb33778b

View File

@ -275,7 +275,6 @@ def align(
start_from_previous: bool = True,
drop_non_aligned_words: bool = False,
):
print("Performing alignment...")
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
@ -308,7 +307,12 @@ def align(
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
if waveform_segment.shape[1] < 10:
print("Failed to align segment: too short in duration, %.3f" % waveform_segment.shape[1]/SAMPLE_RATE)
continue
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
@ -507,6 +511,7 @@ def cli():
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
for audio_path in args.pop("audio"):
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)
if result["language"] != align_metadata["language"]:
@ -538,8 +543,8 @@ def cli():
write_srt(result_aligned["word_segments"], file=srt)
# save ASS
with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
write_ass(result_aligned["segments"], file=ass)
# with open(os.path.join(output_dir, audio_basename + ".ass"), "w", encoding="utf-8") as ass:
# write_ass(result_aligned["segments"], file=ass)
if __name__ == '__main__':