diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 6f8c257..93ff41d 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -11,6 +11,7 @@ class DiarizationPipeline: use_auth_token=None, device: Optional[Union[str, torch.device]] = "cpu", ): + self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token) if isinstance(device, str): device = torch.device(device) self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 4b5a664..f3f63fe 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -190,7 +190,7 @@ def cli(): tmp_results = results print(">>Performing diarization...") results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token) + diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])