Set diarization device manually

This commit is contained in:
Simon
2023-05-04 16:25:34 +02:00
parent 2d59eb9726
commit d8f0ef4a19
2 changed files with 8 additions and 2 deletions

View File

@ -1,14 +1,19 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from pyannote.audio import Pipeline from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
class DiarizationPipeline: class DiarizationPipeline:
def __init__( def __init__(
self, self,
model_name="pyannote/speaker-diarization@2.1", model_name="pyannote/speaker-diarization@2.1",
use_auth_token=None, use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
): ):
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token) if isinstance(device, str):
device = torch.device(device)
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
def __call__(self, audio, min_speakers=None, max_speakers=None): def __call__(self, audio, min_speakers=None, max_speakers=None):
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers) segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

View File

@ -193,8 +193,9 @@ def cli():
if hf_token is None: if hf_token is None:
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results tmp_results = results
print(">>Performing diarization...")
results = [] results = []
diarize_model = DiarizationPipeline(use_auth_token=hf_token) diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results: for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"]) results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])