Merge pull request #646 from santialferez/diarize-patch-1

Update pyannote to v3.1.1 to fix a diarization problem (and diarize.py)
This commit is contained in:
Max Bain
2024-01-03 02:35:53 +00:00
committed by GitHub
15 changed files with 2366 additions and 3 deletions

View File

@ -18,14 +18,14 @@ class DiarizationPipeline:
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
if isinstance(audio, str):
audio = load_audio(audio)
audio_data = {
'waveform': torch.from_numpy(audio[None, :]),
'sample_rate': SAMPLE_RATE
}
segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)