handle negative / tiny duration segments, final

This commit is contained in:
Max Bain
2023-01-08 14:01:10 +00:00
parent a6eb33778b
commit 78c87d3bfd
2 changed files with 30 additions and 14 deletions

View File

@ -223,6 +223,10 @@ def transcribe(
end_timestamp_position = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
# clamp end-time to at least be 1 frame after start-time
end_timestamp_position = max(end_timestamp_position, start_timestamp_position + time_precision)
add_segment(
start=timestamp_offset + start_timestamp_position * time_precision,
end=timestamp_offset + end_timestamp_position * time_precision,
@ -291,28 +295,27 @@ def align(
prev_t2 = 0
word_segments_list = []
for idx, segment in enumerate(transcript):
if int(segment['start'] * SAMPLE_RATE) >= audio.shape[1]:
print("Failed to align segment: original start time longer than audio duration, skipping...")
continue
if int(segment['start']) >= int(segment['end']):
print("Failed to align segment: original end time is not after start time, skipping...")
continue
# first we pad
t1 = max(segment['start'] - extend_duration, 0)
t2 = min(segment['end'] + extend_duration, MAX_DURATION)
# use prev_t2 as current t1 if it's later
if start_from_previous and t1 < prev_t2:
t1 = prev_t2
# check if timestamp range is still valid
if t1 >= MAX_DURATION:
print("Failed to align segment: original start time longer than audio duration, skipping...")
continue
if t2 - t1 < 0.02:
print("Failed to align segment: duration smaller than 0.02s time precision")
continue
f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
waveform_segment = audio[:, f1:f2]
if waveform_segment.shape[1] < 10:
print("Failed to align segment: too short in duration, %.3f" % waveform_segment.shape[1]/SAMPLE_RATE)
continue
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
@ -321,6 +324,7 @@ def align(
else:
raise NotImplementedError(f"Align model of type {model_type} not supported.")
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
transcription = segment['text'].strip()
if model_lang not in LANGUAGES_WITHOUT_SPACES:
@ -519,6 +523,7 @@ def cli():
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
align_model, align_metadata = load_align_model(result["language"], device)
print("Performing alignment...")
result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
extend_duration=align_extend, start_from_previous=align_from_prev, drop_non_aligned_words=drop_non_aligned)
audio_basename = os.path.basename(audio_path)