diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 34dfc63..6f8c257 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -1,14 +1,19 @@
 import numpy as np
 import pandas as pd
 from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch

 class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization@2.1",
         use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
     ):
-        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

     def __call__(self, audio, min_speakers=None, max_speakers=None):
         segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index dab9e12..e284e83 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -193,8 +193,9 @@ def cli():
         if hf_token is None:
             print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
         tmp_results = results
+        print(">>Performing diarization...")
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
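
For reference, a minimal usage sketch of the constructor as patched above. This is not part of the diff: the HF_TOKEN environment variable and the "audio.wav" path are hypothetical stand-ins, and it assumes the patch is applied and that access to the gated pyannote/speaker-diarization model has been granted for the token used.

# Usage sketch for the patched DiarizationPipeline (not part of the change).
import os

import torch

from whisperx.diarize import DiarizationPipeline

# The new `device` parameter accepts a string or a torch.device; strings are
# normalized via torch.device(device) before the pipeline is moved with .to().
device = "cuda" if torch.cuda.is_available() else "cpu"

diarize_model = DiarizationPipeline(
    use_auth_token=os.environ.get("HF_TOKEN"),  # hypothetical env var
    device=device,
)

# min_speakers/max_speakers are forwarded to the pyannote pipeline unchanged.
diarize_segments = diarize_model("audio.wav", min_speakers=2, max_speakers=4)
print(diarize_segments)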