diff --git a/README.md b/README.md
index b52401b..3be1a3c 100644
--- a/README.md
+++ b/README.md
@@ -177,8 +177,8 @@ print(result["segments"]) # after alignment
 diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
 
 # add min/max number of speakers if known
-diarize_segments = diarize_model(audio_file)
-# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
+diarize_segments = diarize_model(audio)
+# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
 
 result = whisperx.assign_word_speakers(diarize_segments, result)
 print(diarize_segments)
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 320d2a4..e50dc0f 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -4,6 +4,8 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from .audio import load_audio, SAMPLE_RATE
+
 class DiarizationPipeline:
     def __init__(
         self,
@@ -15,8 +17,14 @@ class DiarizationPipeline:
             device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
 
-    def __call__(self, audio, min_speakers=None, max_speakers=None):
-        segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
+    def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio_data = {
+            'waveform': torch.from_numpy(audio[None, :]),
+            'sample_rate': SAMPLE_RATE
+        }
+        segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
         diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
         diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
         diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
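
For reference, a minimal sketch of how the updated call site could look after this change, reusing the README's existing placeholders (`YOUR_HF_TOKEN`, `device`, `audio_file`); this is illustrative only and not part of the diff:

```python
import whisperx

device = "cuda"           # assumption: a CUDA device is available
audio_file = "audio.mp3"  # hypothetical input file

# Load the waveform once (16 kHz mono numpy array) and reuse it across
# transcription, alignment, and diarization.
audio = whisperx.load_audio(audio_file)

diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)

# With this change, the preloaded waveform can be passed directly ...
diarize_segments = diarize_model(audio)

# ... or a file path can still be passed; __call__ now loads it internally.
diarize_segments = diarize_model(audio_file)
```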