Merge pull request #400 from davidas1/fast-diarize

make diarization faster
This commit is contained in:
Max Bain
2023-08-02 13:43:20 +01:00
committed by GitHub
2 changed files with 12 additions and 4 deletions

View File

@@ -177,8 +177,8 @@ print(result["segments"]) # after alignment
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
diarize_segments = diarize_model(audio_file)
# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(diarize_segments)

View File

@@ -4,6 +4,8 @@ from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
from .audio import load_audio, SAMPLE_RATE
class DiarizationPipeline:
def __init__(
self,
@@ -15,8 +17,14 @@ class DiarizationPipeline:
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)