mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Set diarization device manually
This commit is contained in:
@ -1,14 +1,19 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
class DiarizationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_name="pyannote/speaker-diarization@2.1",
|
||||
use_auth_token=None,
|
||||
device: Optional[Union[str, torch.device]] = "cpu",
|
||||
):
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||
|
||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
|
@ -193,8 +193,9 @@ def cli():
|
||||
if hf_token is None:
|
||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||
tmp_results = results
|
||||
print(">>Performing diarization...")
|
||||
results = []
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||
for result, input_audio_path in tmp_results:
|
||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||
|
Reference in New Issue
Block a user