diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index f18f2bc..3c5b0a7 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -282,8 +282,14 @@ def align(
         f2 = int(t2 * SAMPLE_RATE)
         waveform_segment = audio[:, f1:f2]
 
         with torch.inference_mode():
-            emissions = model(waveform_segment.to(device)).logits
+            out = model(waveform_segment.to(device))
+            # torchaudio pipeline models return a (emissions, lengths) tuple,
+            # while HuggingFace wav2vec2 models return an output object with a
+            # .logits attribute. Dispatch on the actual return type instead of
+            # hard-coding which languages use which backend.
+            if isinstance(out, tuple):
+                emissions = out[0]
+            else:
+                emissions = out.logits
         emissions = torch.log_softmax(emissions, dim=-1)
         emission = emissions[0].cpu().detach()
         transcription = segment['text'].strip()