# whisperX/whisperx/vads/pyannote.py
import os
from typing import Callable, Optional, Text, Union

import numpy as np
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad


def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
    model_dir = torch.hub._get_torch_home()
    main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    os.makedirs(model_dir, exist_ok=True)

    if model_fp is None:
        # Dynamically resolve the path to the bundled model file
        model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
    model_fp = os.path.abspath(model_fp)  # Ensure the path is absolute

    # Check that the resolved model file exists and is a regular file
    if not os.path.exists(model_fp):
        raise FileNotFoundError(f"Model file not found at {model_fp}")
    if not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")

    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
    hyperparameters = {
        "onset": vad_onset,
        "offset": vad_offset,
        "min_duration_on": 0.1,
        "min_duration_off": 0.1,
    }
    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
    vad_pipeline.instantiate(hyperparameters)
    return vad_pipeline
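
# A minimal usage sketch for the loader above; the device string and 16 kHz
# sample rate are assumptions for the example, not requirements of this file:
#
#     pipeline = load_vad_model("cpu", vad_onset=0.5, vad_offset=0.363)
#     waveform = torch.from_numpy(audio).unsqueeze(0)  # (channel, time)
#     scores = pipeline({"waveform": waveform, "sample_rate": 16000})
#
# The call returns frame-level speech scores as a SlidingWindowFeature, which
# Pyannote.merge_chunks (below) binarizes into speech segments.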


class Binarize:
    """Binarize detection scores using hysteresis thresholding, with a min-cut
    operation to ensure that no segment is longer than max_duration.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration : float
        The maximum length of an active segment; longer segments are divided
        at the timestamp with the lowest score.

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Adapted from pyannote-audio; modified by Max Bain to include WhisperX's
    min-cut operation (https://arxiv.org/abs/2303.00747).
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float("inf"),
    ):
        super().__init__()
        self.onset = onset
        self.offset = offset or onset
        self.pad_onset = pad_onset
        self.pad_offset = pad_offset
        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off
        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.

        Returns
        -------
        active : Annotation
            Binarized scores.
        """
        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()
        for k, k_scores in enumerate(scores.data.T):
            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            t = start
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        # min-cut: divide the over-long segment at the
                        # lowest-scoring frame in its second half
                        search_after = len(curr_scores) // 2
                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                        min_score_t = curr_timestamps[min_score_div_idx]
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
                        curr_scores = curr_scores[min_score_div_idx + 1:]
                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
                    curr_scores.append(y)
                    curr_timestamps.append(t)
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same-speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError("This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active
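
# A toy illustration of the hysteresis thresholding above; the scores and
# window geometry are invented for the example:
#
#     from pyannote.core import SlidingWindow
#     data = np.array([[0.1], [0.9], [0.9], [0.2], [0.9], [0.1]])
#     scores = SlidingWindowFeature(data, SlidingWindow(start=0.0, duration=0.02, step=0.02))
#     annotation = Binarize(onset=0.5, offset=0.35)(scores)
#     # annotation.get_timeline() now holds the active speech regions; with the
#     # lower offset threshold, the dip to 0.2 does not end the first region.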


class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):
        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> SlidingWindowFeature:
        """Apply voice activity detection

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Hook called after each major step of the pipeline with the following
            signature: hook("step_name", step_artefact, file=file)

        Returns
        -------
        segmentations : SlidingWindowFeature
            Raw frame-level speech scores, returned without binarization.
        """
        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        # apply segmentation model (only if needed)
        # output shape is (num_chunks, num_frames, 1)
        if self.training:
            if self.CACHED_SEGMENTATION in file:
                segmentations = file[self.CACHED_SEGMENTATION]
            else:
                segmentations = self._segmentation(file)
                file[self.CACHED_SEGMENTATION] = segmentations
        else:
            segmentations: SlidingWindowFeature = self._segmentation(file)

        return segmentations
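
# Design note: the parent VoiceActivityDetection.apply binarizes the scores and
# returns an Annotation; this override returns the raw scores instead, so that
# WhisperX can apply its own Binarize (with the max_duration min-cut) when
# merging chunks for Whisper's roughly 30-second input windows.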


class Pyannote(Vad):
    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
        print(">>Performing voice activity detection using Pyannote...")
        super().__init__(kwargs['vad_onset'])
        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)

    def __call__(self, audio: AudioFile, **kwargs):
        return self.vad_pipeline(audio)

    @staticmethod
    def preprocess_audio(audio):
        return torch.from_numpy(audio).unsqueeze(0)

    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        assert chunk_size > 0
        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
        segments = binarize(segments)
        segments_list = []
        for speech_turn in segments.get_timeline():
            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

        if len(segments_list) == 0:
            print("No active speech found in audio")
            return []

        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
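
# End-to-end sketch, assuming a 16 kHz mono float32 numpy array `audio` and a
# chunk size matching Whisper's window (both assumptions, not fixed here):
#
#     vad = Pyannote("cpu", vad_onset=0.5)
#     waveform = Pyannote.preprocess_audio(audio)
#     scores = vad({"waveform": waveform, "sample_rate": 16000})
#     chunks = Pyannote.merge_chunks(scores, chunk_size=30, onset=0.5)
#     # chunks: speech spans of at most chunk_size seconds, ready to transcribe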