mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge pull request #529 from MahmoudAshraf97/main
This commit is contained in:
@ -194,8 +194,8 @@ def align(
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
if t1 >= MAX_DURATION or t2 - t1 < 0.02:
|
||||
print("Failed to align segment: original start time longer than audio duration, skipping...")
|
||||
if t1 >= MAX_DURATION:
|
||||
print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
|
||||
aligned_segments.append(aligned_seg)
|
||||
continue
|
||||
|
||||
@ -207,17 +207,17 @@ def align(
|
||||
|
||||
# TODO: Probably can get some speedup gain with batched inference here
|
||||
waveform_segment = audio[:, f1:f2]
|
||||
|
||||
# Handle the minimum input length for wav2vec2 models
|
||||
if waveform_segment.shape[-1] < 400:
|
||||
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
|
||||
waveform_segment = torch.nn.functional.pad(
|
||||
waveform_segment, (0, 400 - waveform_segment.shape[-1])
|
||||
)
|
||||
else:
|
||||
lengths = None
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "torchaudio":
|
||||
# Handle the minimum input length for torchaudio wav2vec2 models
|
||||
if waveform_segment.shape[-1] < 400:
|
||||
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
|
||||
waveform_segment = torch.nn.functional.pad(
|
||||
waveform_segment, (0, 400 - waveform_segment.shape[-1])
|
||||
)
|
||||
else:
|
||||
lengths = None
|
||||
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
||||
elif model_type == "huggingface":
|
||||
emissions = model(waveform_segment.to(device)).logits
|
||||
|
Reference in New Issue
Block a user