diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 2a9bd69..eae6a19 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -4,7 +4,7 @@ from pyannote.audio import Pipeline from typing import Optional, Union import torch -from .audio import SAMPLE_RATE +from .audio import load_audio, SAMPLE_RATE class DiarizationPipeline: def __init__( @@ -18,6 +18,8 @@ class DiarizationPipeline: self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) def __call__(self, audio, min_speakers=None, max_speakers=None): + if isinstance(audio, str): + audio = load_audio(audio) audio_data = { 'waveform': torch.from_numpy(audio[None, :]), 'sample_rate': SAMPLE_RATE