Merge pull request #210 from sorgfresser/v3

Update pyannote and torch version
Max Bain authored on 2023-05-04 20:32:06 +01:00, committed by GitHub
6 changed files with 30 additions and 36 deletions

requirements.txt

@@ -1,6 +1,5 @@
-torch==1.11.0
-torchaudio==0.11.0
-pyannote.audio
+torch==2.0.0
+torchaudio==2.0.1
 faster-whisper
 transformers
 ffmpeg-python==0.2.0

setup.py

@@ -19,7 +19,7 @@ setup(
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
         )
-    ],
+    ] + ["pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"],
     entry_points = {
         'console_scripts': ['whisperx=whisperx.transcribe:cli'],
     },
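Because pyannote.audio is dropped from requirements.txt, setup.py now appends it as a direct Git reference pinned to a specific pyannote-audio commit. A minimal sketch (not part of the commit) of how the resulting dependency list is assembled; pkg_resources ships with setuptools, and the path assumes the snippet sits next to requirements.txt:

# Sketch: requirements.txt entries plus the pinned pyannote.audio Git reference.
import os
import pkg_resources

install_requires = [
    str(r)
    for r in pkg_resources.parse_requirements(
        open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
    )
] + [
    "pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a137a578db9335efc00298f6ec1932e6317"
]

print(install_requires)
# ['torch==2.0.0', 'torchaudio==2.0.1', 'faster-whisper', 'transformers',
#  'ffmpeg-python==0.2.0', 'pyannote.audio @ git+https://github.com/pyannote/pyannote-audio@11b56a...']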

whisperx/asr.py

@@ -181,6 +181,9 @@ class FasterWhisperPipeline(Pipeline):
     def preprocess(self, audio):
         audio = audio['inputs']
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
         features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
         return {'inputs': features}

@@ -253,7 +256,7 @@ class FasterWhisperPipeline(Pipeline):
     def detect_language(self, audio: np.ndarray):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
-        segment = log_mel_spectrogram(audio[: N_SAMPLES],
+        segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
                                       padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
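Both call sites now wrap NumPy audio in torch.from_numpy before calling log_mel_spectrogram, which (after the audio.py changes below) only accepts tensors. The conversion is cheap; a standalone sketch showing that from_numpy shares memory with the NumPy buffer rather than copying it:

# Sketch: torch.from_numpy is a zero-copy view, so the added conversion adds
# essentially no overhead before log_mel_spectrogram.
import numpy as np
import torch

audio = np.zeros(16000, dtype=np.float32)   # placeholder: 1 s of 16 kHz audio
audio_t = torch.from_numpy(audio)           # wraps the same underlying buffer

audio[0] = 1.0
assert audio_t[0].item() == 1.0             # both views see the write: memory is shared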

whisperx/audio.py

@@ -22,6 +22,12 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token

+with np.load(
+    os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+) as f:
+    MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])

 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
@@ -79,27 +85,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     return array


-@lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
-    """
-    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-    Allows decoupling librosa dependency; saved using:
-
-        np.savez_compressed(
-            "mel_filters.npz",
-            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-        )
-    """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(
-        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
-    ) as f:
-        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
-
-
+@torch.compile(fullgraph=True)
 def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, torch.Tensor],
-    n_mels: int = N_MELS,
+    audio: torch.Tensor,
     padding: int = 0,
     device: Optional[Union[str, torch.device]] = None,
 ):
@@ -108,7 +96,7 @@ def log_mel_spectrogram(
     Parameters
     ----------
-    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+    audio: torch.Tensor, shape = (*)
         The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

     n_mels: int
@@ -125,21 +113,19 @@
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
+    global MEL_FILTERS

     if device is not None:
         audio = audio.to(device)
     if padding > 0:
         audio = F.pad(audio, (0, padding))
     window = torch.hann_window(N_FFT).to(audio.device)
-    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
-    magnitudes = stft[..., :-1].abs() ** 2
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
+    # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors
+    magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)

-    filters = mel_filters(audio.device, n_mels)
-    mel_spec = filters @ magnitudes
+    MEL_FILTERS = MEL_FILTERS.to(audio.device)
+    mel_spec = MEL_FILTERS @ magnitudes

     log_spec = torch.clamp(mel_spec, min=1e-10).log10()
     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
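The rewritten magnitude computation avoids complex tensors by taking the real-valued STFT view and summing the squared real and imaginary parts, which is numerically equivalent to abs() ** 2 on the complex STFT (presumably to keep the function friendly to the new @torch.compile decorator). A standalone sanity check of that equivalence, using Whisper's N_FFT=400 and HOP_LENGTH=160; note that return_complex=False is deprecated in recent PyTorch releases:

# Sketch: verify the new real-view magnitude path matches the old complex path.
import torch

N_FFT, HOP_LENGTH = 400, 160
audio = torch.randn(16000)
window = torch.hann_window(N_FFT)

stft_c = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
old_magnitudes = stft_c[..., :-1].abs() ** 2              # previous implementation

stft_r = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
new_magnitudes = (stft_r[:, :-1, :] ** 2).sum(dim=-1)     # implementation above

assert torch.allclose(old_magnitudes, new_magnitudes, atol=1e-5)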

whisperx/diarize.py

@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
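The new device argument lets callers move the pyannote pipeline off the CPU; it relies on the pyannote Pipeline exposing .to(device), which the constructor above now calls. A minimal usage sketch, assuming this module is importable as whisperx.diarize; the Hugging Face token and audio path are placeholders:

# Sketch: build the diarization pipeline on GPU when available.
import torch

from whisperx.diarize import DiarizationPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device=device)  # "hf_xxx": placeholder token
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=2)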

whisperx/transcribe.py

@@ -193,8 +193,9 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])