Pass device to pyannote.audio.Inference

This commit is contained in:
smly
2023-02-22 03:48:20 +09:00
parent f7093e60d3
commit 57f5957e0e

View File

@ -645,9 +645,12 @@ def cli():
if hf_token is None: if hf_token is None:
print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...") print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
from pyannote.audio import Inference from pyannote.audio import Inference
vad_pipeline = Inference("pyannote/segmentation", vad_pipeline = Inference(
pre_aggregation_hook=lambda segmentation: segmentation, "pyannote/segmentation",
use_auth_token=hf_token) pre_aggregation_hook=lambda segmentation: segmentation,
use_auth_token=hf_token,
device=torch.device(device),
)
diarize_pipeline = None diarize_pipeline = None
if diarize: if diarize: