mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Add torch compile to log mel spectrogram
This commit is contained in:
@ -181,6 +181,9 @@ class FasterWhisperPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
|
||||
return {'inputs': features}
|
||||
|
||||
@ -253,7 +256,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
def detect_language(self, audio: np.ndarray):
|
||||
if audio.shape[0] < N_SAMPLES:
|
||||
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
|
||||
segment = log_mel_spectrogram(audio[: N_SAMPLES],
|
||||
segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
|
||||
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
|
||||
encoder_output = self.model.encode(segment)
|
||||
results = self.model.model.detect_language(encoder_output)
|
||||
|
Reference in New Issue
Block a user