diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 132e359..7f3d586 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -210,7 +210,15 @@ def align( with torch.inference_mode(): if model_type == "torchaudio": - emissions, _ = model(waveform_segment.to(device)) + # Handle the minimum input length for torchaudio wav2vec2 models + if waveform_segment.shape[-1] < 400: + lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device) + waveform_segment = torch.nn.functional.pad( + waveform_segment, (0, 400 - waveform_segment.shape[-1]) + ) + else: + lengths = None + emissions, _ = model(waveform_segment.to(device), lengths=lengths) elif model_type == "huggingface": emissions = model(waveform_segment.to(device)).logits else: