commit b69956d725
parent 02c0323777
Author: Mahmoud Ashraf
Date:   2023-10-16 20:43:37 +03:00
Committed by: GitHub

@@ -207,10 +207,7 @@ def align(
         # TODO: Probably can get some speedup gain with batched inference here
         waveform_segment = audio[:, f1:f2]
-        with torch.inference_mode():
-            if model_type == "torchaudio":
-                # Handle the minimum input length for torchaudio wav2vec2 models
-                if waveform_segment.shape[-1] < 400:
-                    lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
-                    waveform_segment = torch.nn.functional.pad(
+        # Handle the minimum input length for wav2vec2 models
+        if waveform_segment.shape[-1] < 400:
+            lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
+            waveform_segment = torch.nn.functional.pad(
@@ -218,6 +215,9 @@ def align(
-                    )
-                else:
-                    lengths = None
+            )
+        else:
+            lengths = None
+        with torch.inference_mode():
+            if model_type == "torchaudio":
             emissions, _ = model(waveform_segment.to(device), lengths=lengths)
         elif model_type == "huggingface":
             emissions = model(waveform_segment.to(device)).logits
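
For context on the diff above: the minimum-length padding is moved out of the torchaudio-only branch, so segments shorter than 400 samples are now padded before inference for both the torchaudio and huggingface backends. 400 samples (25 ms at 16 kHz) is the receptive field of the wav2vec2 convolutional feature encoder, so shorter inputs would otherwise produce no output frames. Below is a minimal sketch of that padding step as it stands after this change, factored into a hypothetical helper (pad_to_min_input is not part of whisperx, and device placement is omitted):

    import torch

    def pad_to_min_input(waveform_segment: torch.Tensor, min_samples: int = 400):
        # Right-pad a (channels, samples) segment up to the wav2vec2 minimum length.
        # Returns the (possibly padded) segment plus the original-length tensor that
        # torchaudio-style alignment models expect; lengths stays None when no
        # padding was needed.
        if waveform_segment.shape[-1] < min_samples:
            lengths = torch.as_tensor([waveform_segment.shape[-1]])
            waveform_segment = torch.nn.functional.pad(
                waveform_segment, (0, min_samples - waveform_segment.shape[-1])
            )
        else:
            lengths = None
        return waveform_segment, lengths

    # Example: a 10 ms segment at 16 kHz (160 samples) is padded up to 400 samples.
    segment = torch.zeros(1, 160)
    padded, lengths = pad_to_min_input(segment)
    print(padded.shape, lengths)  # torch.Size([1, 400]) tensor([160])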