torch2.0, remove compile for now, round to times to 3 decimal

This commit is contained in:
Max Bain
2023-05-04 20:38:13 +01:00
parent d2116b98ca
commit 4e2ac4e4e9
6 changed files with 40 additions and 34 deletions

View File

@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
<h2 align="left" id="setup">Setup ⚙️</h2> <h2 align="left" id="setup">Setup ⚙️</h2>
Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!) Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html). GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
### 1. Create Python3.8 environment ### 1. Create Python3.10 environment
`conda create --name whisperx python=3.8` `conda create --name whisperx python=3.10`
`conda activate whisperx` `conda activate whisperx`
### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows: ### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:
`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113` `pip3 install torch torchvision torchaudio`
See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4) See other methods [here.](https://pytorch.org/get-started/locally/)
### 3. Install this repo ### 3. Install this repo

View File

@ -6,7 +6,7 @@ from setuptools import setup, find_packages
setup( setup(
name="whisperx", name="whisperx",
py_modules=["whisperx"], py_modules=["whisperx"],
version="3.0.0", version="3.0.2",
description="Time-Accurate Automatic Speech Recognition using Whisper.", description="Time-Accurate Automatic Speech Recognition using Whisper.",
readme="README.md", readme="README.md",
python_requires=">=3.8", python_requires=">=3.8",

View File

@ -268,8 +268,8 @@ def align(
start, end, score = None, None, None start, end, score = None, None, None
if cdx in clean_cdx: if cdx in clean_cdx:
char_seg = char_segments[clean_cdx.index(cdx)] char_seg = char_segments[clean_cdx.index(cdx)]
start = char_seg.start * ratio + t1 start = round(char_seg.start * ratio + t1, 3)
end = char_seg.end * ratio + t1 end = round(char_seg.end * ratio + t1, 3)
score = char_seg.score score = char_seg.score
char_segments_arr["char"].append(char) char_segments_arr["char"].append(char)

View File

@ -181,9 +181,6 @@ class FasterWhisperPipeline(Pipeline):
def preprocess(self, audio): def preprocess(self, audio):
audio = audio['inputs'] audio = audio['inputs']
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0]) features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
return {'inputs': features} return {'inputs': features}
@ -256,7 +253,7 @@ class FasterWhisperPipeline(Pipeline):
def detect_language(self, audio: np.ndarray): def detect_language(self, audio: np.ndarray):
if audio.shape[0] < N_SAMPLES: if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.") print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]), segment = log_mel_spectrogram(audio[: N_SAMPLES],
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
encoder_output = self.model.encode(segment) encoder_output = self.model.encode(segment)
results = self.model.model.detect_language(encoder_output) results = self.model.model.detect_language(encoder_output)

View File

@ -22,12 +22,6 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: str, sr: int = SAMPLE_RATE):
""" """
@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
return array return array
@torch.compile(fullgraph=True) @lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram( def log_mel_spectrogram(
audio: torch.Tensor, audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = N_MELS,
padding: int = 0, padding: int = 0,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
): ):
@ -96,7 +108,7 @@ def log_mel_spectrogram(
Parameters Parameters
---------- ----------
audio: torch.Tensor, shape = (*) audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int n_mels: int
@ -113,19 +125,21 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames) torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram A Tensor that contains the Mel spectrogram
""" """
global MEL_FILTERS if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None: if device is not None:
audio = audio.to(device) audio = audio.to(device)
if padding > 0: if padding > 0:
audio = F.pad(audio, (0, padding)) audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device) window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False) stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
# Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors magnitudes = stft[..., :-1].abs() ** 2
magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)
MEL_FILTERS = MEL_FILTERS.to(audio.device) filters = mel_filters(audio.device, n_mels)
mel_spec = MEL_FILTERS @ magnitudes mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)

View File

@ -72,7 +72,6 @@ def cli():
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.") # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
@ -86,10 +85,6 @@ def cli():
# model_flush: bool = args.pop("model_flush") # model_flush: bool = args.pop("model_flush")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
tmp_dir: str = args.pop("tmp_dir")
if tmp_dir is not None:
os.makedirs(tmp_dir, exist_ok=True)
align_model: str = args.pop("align_model") align_model: str = args.pop("align_model")
interpolate_method: str = args.pop("interpolate_method") interpolate_method: str = args.pop("interpolate_method")
no_align: bool = args.pop("no_align") no_align: bool = args.pop("no_align")
@ -195,7 +190,7 @@ def cli():
tmp_results = results tmp_results = results
print(">>Performing diarization...") print(">>Performing diarization...")
results = [] results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) diarize_model = DiarizationPipeline(use_auth_token=hf_token)
for result, input_audio_path in tmp_results: for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"]) results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])