fix minimum input length for torch wav2vec2 models

This commit is contained in:
Mahmoud Ashraf
2023-10-06 00:41:23 +03:00
committed by GitHub
parent d077abdbdf
commit 8049dba2f7

View File

@ -210,7 +210,15 @@ def align(
with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device))
# Handle the minimum input length for torchaudio wav2vec2 models
if waveform_segment.shape[-1] < 400:
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
waveform_segment = torch.nn.functional.pad(
waveform_segment, (0, 400 - waveform_segment.shape[-1])
)
else:
lengths = None
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
elif model_type == "huggingface":
emissions = model(waveform_segment.to(device)).logits
else: