mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
more
This commit is contained in:
@ -4,7 +4,7 @@ from pyannote.audio import Pipeline
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .audio import SAMPLE_RATE
|
from .audio import load_audio, SAMPLE_RATE
|
||||||
|
|
||||||
class DiarizationPipeline:
|
class DiarizationPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -18,6 +18,8 @@ class DiarizationPipeline:
|
|||||||
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
|
||||||
|
|
||||||
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
def __call__(self, audio, min_speakers=None, max_speakers=None):
|
||||||
|
if isinstance(audio, str):
|
||||||
|
audio = load_audio(audio)
|
||||||
audio_data = {
|
audio_data = {
|
||||||
'waveform': torch.from_numpy(audio[None, :]),
|
'waveform': torch.from_numpy(audio[None, :]),
|
||||||
'sample_rate': SAMPLE_RATE
|
'sample_rate': SAMPLE_RATE
|
||||||
|
Reference in New Issue
Block a user