diff --git a/README.md b/README.md index c9951ce..1f41bb9 100644 --- a/README.md +++ b/README.md @@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig

Setup ⚙️

-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!) +Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!) GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html). -### 1. Create Python3.8 environment +### 1. Create Python3.10 environment -`conda create --name whisperx python=3.8` +`conda create --name whisperx python=3.10` `conda activate whisperx` -### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows: +### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7: -`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113` +`pip3 install torch torchvision torchaudio` -See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4) +See other methods [here.](https://pytorch.org/get-started/locally/) ### 3. Install this repo diff --git a/setup.py b/setup.py index 859d171..66f22cd 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ from setuptools import setup, find_packages setup( name="whisperx", py_modules=["whisperx"], - version="3.0.0", + version="3.0.2", description="Time-Accurate Automatic Speech Recognition using Whisper.", readme="README.md", python_requires=">=3.8", diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 2ae77f3..e63e6e5 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -268,8 +268,8 @@ def align( start, end, score = None, None, None if cdx in clean_cdx: char_seg = char_segments[clean_cdx.index(cdx)] - start = char_seg.start * ratio + t1 - end = char_seg.end * ratio + t1 + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) score = char_seg.score char_segments_arr["char"].append(char) diff --git a/whisperx/asr.py b/whisperx/asr.py index 1ca12ce..ba6220b 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -181,9 +181,6 @@ class FasterWhisperPipeline(Pipeline): def preprocess(self, audio): audio = audio['inputs'] - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0]) return {'inputs': features} @@ -256,7 +253,7 @@ class FasterWhisperPipeline(Pipeline): def detect_language(self, audio: np.ndarray): if audio.shape[0] < N_SAMPLES: print("Warning: audio is shorter than 30s, language detection may be inaccurate.") - segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]), + segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) encoder_output = self.model.encode(segment) results = self.model.model.detect_language(encoder_output) diff --git a/whisperx/audio.py b/whisperx/audio.py index 8ac0674..513ab7c 100644 --- a/whisperx/audio.py +++ b/whisperx/audio.py @@ -22,12 +22,6 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token -with np.load( - os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") -) as f: - MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"]) - - def load_audio(file: str, sr: int = SAMPLE_RATE): """ @@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): return array -@torch.compile(fullgraph=True) +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + def log_mel_spectrogram( - audio: torch.Tensor, + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int = N_MELS, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): @@ -96,7 +108,7 @@ def log_mel_spectrogram( Parameters ---------- - audio: torch.Tensor, shape = (*) + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz n_mels: int @@ -113,19 +125,21 @@ def log_mel_spectrogram( torch.Tensor, shape = (80, n_frames) A Tensor that contains the Mel spectrogram """ - global MEL_FILTERS + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) if device is not None: audio = audio.to(device) if padding > 0: audio = F.pad(audio, (0, padding)) window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False) - # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors - magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 - MEL_FILTERS = MEL_FILTERS.to(audio.device) - mel_spec = MEL_FILTERS @ magnitudes + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index e284e83..4b5a664 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -72,7 +72,6 @@ def cli(): parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.") - parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).") # fmt: on args = parser.parse_args().__dict__ @@ -86,10 +85,6 @@ def cli(): # model_flush: bool = args.pop("model_flush") os.makedirs(output_dir, exist_ok=True) - tmp_dir: str = args.pop("tmp_dir") - if tmp_dir is not None: - os.makedirs(tmp_dir, exist_ok=True) - align_model: str = args.pop("align_model") interpolate_method: str = args.pop("interpolate_method") no_align: bool = args.pop("no_align") @@ -195,7 +190,7 @@ def cli(): tmp_results = results print(">>Performing diarization...") results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + diarize_model = DiarizationPipeline(use_auth_token=hf_token) for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])