diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 874502b..68465f9 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -207,17 +207,17 @@ def align(
 
             # TODO: Probably can get some speedup gain with batched inference here
             waveform_segment = audio[:, f1:f2]
-
+            # Handle the minimum input length for wav2vec2 models
+            if waveform_segment.shape[-1] < 400:
+                lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
+                waveform_segment = torch.nn.functional.pad(
+                    waveform_segment, (0, 400 - waveform_segment.shape[-1])
+                )
+            else:
+                lengths = None
+
             with torch.inference_mode():
                 if model_type == "torchaudio":
-                    # Handle the minimum input length for torchaudio wav2vec2 models
-                    if waveform_segment.shape[-1] < 400:
-                        lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
-                        waveform_segment = torch.nn.functional.pad(
-                            waveform_segment, (0, 400 - waveform_segment.shape[-1])
-                        )
-                    else:
-                        lengths = None
                     emissions, _ = model(waveform_segment.to(device), lengths=lengths)
                 elif model_type == "huggingface":
                     emissions = model(waveform_segment.to(device)).logits
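For reviewers, a minimal standalone sketch of the logic this patch hoists out of the `torchaudio` branch so it also covers the `huggingface` branch. The helper name `pad_segment` and the usage below are illustrative, not part of the patch; the 400-sample floor reflects the receptive field of wav2vec2's convolutional feature extractor (25 ms at 16 kHz), below which the model yields no output frames.

```python
import torch

def pad_segment(waveform_segment: torch.Tensor, min_len: int = 400):
    """Right-pad a (channels, samples) segment to at least `min_len` samples.

    Returns the (possibly padded) segment and a `lengths` tensor holding the
    original sample count, or None when no padding was required. (Hypothetical
    helper for illustration; the patch inlines this logic in `align`.)
    """
    if waveform_segment.shape[-1] < min_len:
        lengths = torch.as_tensor([waveform_segment.shape[-1]])
        waveform_segment = torch.nn.functional.pad(
            waveform_segment, (0, min_len - waveform_segment.shape[-1])
        )
    else:
        lengths = None
    return waveform_segment, lengths

# Usage: a 150-sample segment is padded to 400 samples; `lengths`
# preserves the true length so downstream frames can be masked.
seg = torch.randn(1, 150)
padded, lengths = pad_segment(seg)
assert padded.shape[-1] == 400 and lengths.item() == 150
```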