mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
.
This commit is contained in:
@ -207,10 +207,7 @@ def align(
|
|||||||
|
|
||||||
# TODO: Probably can get some speedup gain with batched inference here
|
# TODO: Probably can get some speedup gain with batched inference here
|
||||||
waveform_segment = audio[:, f1:f2]
|
waveform_segment = audio[:, f1:f2]
|
||||||
|
# Handle the minimum input length for wav2vec2 models
|
||||||
with torch.inference_mode():
|
|
||||||
if model_type == "torchaudio":
|
|
||||||
# Handle the minimum input length for torchaudio wav2vec2 models
|
|
||||||
if waveform_segment.shape[-1] < 400:
|
if waveform_segment.shape[-1] < 400:
|
||||||
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
|
lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
|
||||||
waveform_segment = torch.nn.functional.pad(
|
waveform_segment = torch.nn.functional.pad(
|
||||||
@ -218,6 +215,9 @@ def align(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
lengths = None
|
lengths = None
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
if model_type == "torchaudio":
|
||||||
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
|
||||||
elif model_type == "huggingface":
|
elif model_type == "huggingface":
|
||||||
emissions = model(waveform_segment.to(device)).logits
|
emissions = model(waveform_segment.to(device)).logits
|
||||||
|
Reference in New Issue
Block a user