mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
add device to dia pipeline @sorgfresser
This commit is contained in:
@ -11,6 +11,7 @@ class DiarizationPipeline:
|
|||||||
use_auth_token=None,
|
use_auth_token=None,
|
||||||
device: Optional[Union[str, torch.device]] = "cpu",
|
device: Optional[Union[str, torch.device]] = "cpu",
|
||||||
):
|
):
|
||||||
|
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
|
||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||||
|
@ -190,7 +190,7 @@ def cli():
|
|||||||
tmp_results = results
|
tmp_results = results
|
||||||
print(">>Performing diarization...")
|
print(">>Performing diarization...")
|
||||||
results = []
|
results = []
|
||||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||||
for result, input_audio_path in tmp_results:
|
for result, input_audio_path in tmp_results:
|
||||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||||
|
Reference in New Issue
Block a user