support for large-v3

This commit is contained in:
MahmoudAshraf97
2023-11-25 12:09:00 +00:00
parent d97cdb7bcf
commit 71a5281bde
3 changed files with 11 additions and 5 deletions

View File

@ -140,7 +140,12 @@ class FasterWhisperPipeline(Pipeline):
def preprocess(self, audio):
audio = audio['inputs']
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
model_n_mels = self.model.feat_kwargs.get("feature_size")
features = log_mel_spectrogram(
audio,
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=N_SAMPLES - audio.shape[0],
)
return {'inputs': features}
def _forward(self, model_inputs):
@ -240,7 +245,9 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
segment = log_mel_spectrogram(audio[: N_SAMPLES],
n_mels=model_n_mels if model_n_mels is not None else 80,
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output)