mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Set diarization device manually
This commit is contained in:
@ -1,14 +1,19 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pyannote.audio import Pipeline
|
from pyannote.audio import Pipeline
|
||||||
|
from typing import Optional, Union
|
||||||
|
import torch
|
||||||
|
|
||||||
class DiarizationPipeline:
|
class DiarizationPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name="pyannote/speaker-diarization@2.1",
|
model_name="pyannote/speaker-diarization@2.1",
|
||||||
use_auth_token=None,
|
use_auth_token=None,
|
||||||
|
device: Optional[Union[str, torch.device]] = "cpu",
|
||||||
):
|
):
|
||||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
|
if isinstance(device, str):
|
||||||
|
device = torch.device(device)
|
||||||
|
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||||
|
|
||||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||||
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||||
|
@ -193,8 +193,9 @@ def cli():
|
|||||||
if hf_token is None:
|
if hf_token is None:
|
||||||
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
||||||
tmp_results = results
|
tmp_results = results
|
||||||
|
print(">>Performing diarization...")
|
||||||
results = []
|
results = []
|
||||||
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
||||||
for result, input_audio_path in tmp_results:
|
for result, input_audio_path in tmp_results:
|
||||||
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
||||||
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
||||||
|
Reference in New Issue
Block a user