torch2.0, remove compile for now, round to times to 3 decimal

This commit is contained in:
Max Bain
2023-05-04 20:38:13 +01:00
parent d2116b98ca
commit 4e2ac4e4e9
6 changed files with 40 additions and 34 deletions

View File

@ -181,9 +181,6 @@ class FasterWhisperPipeline(Pipeline):
def preprocess(self, audio):
audio = audio['inputs']
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features}
@ -256,7 +253,7 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
segment = log_mel_spectrogram(audio[: N_SAMPLES],
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)