diff --git a/requirements.txt b/requirements.txt
index bc0455a..a5c1d73 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 torch>=2
 torchaudio>=2
-faster-whisper>=0.8
+git+https://github.com/SYSTRAN/faster-whisper.git@0.10.0
 transformers
 pandas
 setuptools>=65
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 94e0311..dba8271 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -140,7 +140,12 @@ class FasterWhisperPipeline(Pipeline):
 
     def preprocess(self, audio):
         audio = audio['inputs']
-        features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
+        model_n_mels = self.model.feat_kwargs.get("feature_size")
+        features = log_mel_spectrogram(
+            audio,
+            n_mels=model_n_mels if model_n_mels is not None else 80,
+            padding=N_SAMPLES - audio.shape[0],
+        )
         return {'inputs': features}
 
     def _forward(self, model_inputs):
@@ -240,7 +245,9 @@ class FasterWhisperPipeline(Pipeline):
     def detect_language(self, audio: np.ndarray):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+        model_n_mels = self.model.feat_kwargs.get("feature_size")
         segment = log_mel_spectrogram(audio[: N_SAMPLES],
+                                      n_mels=model_n_mels if model_n_mels is not None else 80,
                                       padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
diff --git a/whisperx/assets/mel_filters.npz b/whisperx/assets/mel_filters.npz
index 1a78392..28ea269 100644
Binary files a/whisperx/assets/mel_filters.npz and b/whisperx/assets/mel_filters.npz differ
diff --git a/whisperx/audio.py b/whisperx/audio.py
index 48fe3c1..db210fb 100644
--- a/whisperx/audio.py
+++ b/whisperx/audio.py
@@ -12,7 +12,6 @@ from .utils import exact_div
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -93,7 +92,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 
 
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+def mel_filters(device, n_mels: int) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -103,7 +102,7 @@
             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
         )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
     with np.load(
         os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     ) as f:
@@ -112,7 +111,7 @@
 
 def log_mel_spectrogram(
     audio: Union[str, np.ndarray, torch.Tensor],
-    n_mels: int = N_MELS,
+    n_mels: int,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):
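
Note on the binary change: since mel_filters() now accepts n_mels in [80, 128], the bundled whisperx/assets/mel_filters.npz must carry a 128-bin filterbank alongside the original 80-bin one. A minimal sketch of how such an asset could be regenerated, following the np.savez_compressed recipe quoted in the mel_filters docstring (the "mel_128" key name is an assumption based on the "mel_80" naming convention, not taken from this patch):

# Minimal sketch, not the authors' script: rebuild mel_filters.npz with
# both the 80- and 128-bin mel filterbanks so mel_filters() can load either.
# Assumes librosa is installed; the "mel_128" key mirrors the docstring's
# "mel_80" naming and is an assumption about the lookup convention.
import librosa
import numpy as np

np.savez_compressed(
    "mel_filters.npz",
    mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
    mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)

With this in place, checkpoints whose faster-whisper feature extractor reports feature_size == 128 (such as large-v3) get matching 128-mel features from preprocess() and detect_language(), while older checkpoints fall back to 80.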