From 07361ba1d7e10c218ef30dd465b92e89ddebb5c5 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Fri, 5 May 2023 11:53:51 +0100
Subject: [PATCH] add device to dia pipeline

@sorgfresser
---
 whisperx/diarize.py    | 1 +
 whisperx/transcribe.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 6f8c257..93ff41d 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -11,6 +11,7 @@ class DiarizationPipeline:
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
+        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
         if isinstance(device, str):
             device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 4b5a664..f3f63fe 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -190,7 +190,7 @@ def cli():
         tmp_results = results
         print(">>Performing diarization...")
         results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
             results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
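
A minimal usage sketch (not part of the patch) of the changed constructor, assuming the
patch above is applied; the Hugging Face token and audio path below are placeholders:

    from whisperx.diarize import DiarizationPipeline

    # Placeholders: substitute a real Hugging Face token and audio file.
    hf_token = "YOUR_HF_TOKEN"
    device = "cuda"  # new `device` argument; a plain string or a torch.device works

    # The pyannote pipeline is now moved to the requested device instead of
    # implicitly staying on CPU.
    diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=2)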