From 8081ef2dcd5f57eab5f71ba838db01a29a8f1476 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Sat, 28 Jan 2023 00:22:33 +0000 Subject: [PATCH] add custom vad binarization for vad cut --- whisperx/vad.py | 176 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 whisperx/vad.py diff --git a/whisperx/vad.py b/whisperx/vad.py new file mode 100644 index 0000000..eb8bd2c --- /dev/null +++ b/whisperx/vad.py @@ -0,0 +1,176 @@ +import pandas as pd +import numpy as np +from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline +from typing import List, Tuple, Optional + +class Binarize: + """Binarize detection scores using hysteresis thresholding + Parameters + ---------- + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + min_duration_on : float, optional + Remove active regions shorter than that many seconds. Defaults to 0s. + min_duration_off : float, optional + Fill inactive regions shorter than that many seconds. Defaults to 0s. + pad_onset : float, optional + Extend active regions by moving their start time by that many seconds. + Defaults to 0s. + pad_offset : float, optional + Extend active regions by moving their end time by that many seconds. + Defaults to 0s. + max_duration: float + The maximum length of an active segment, divides segment at timestamp with lowest score. + Reference + --------- + Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of + RNN-based Voice Activity Detection", InterSpeech 2015. + + Pyannote-audio + """ + + def __init__( + self, + onset: float = 0.5, + offset: Optional[float] = None, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + pad_onset: float = 0.0, + pad_offset: float = 0.0, + max_duration: float = float('inf') + ): + + super().__init__() + + self.onset = onset + self.offset = offset or onset + + self.pad_onset = pad_onset + self.pad_offset = pad_offset + + self.min_duration_on = min_duration_on + self.min_duration_off = min_duration_off + + self.max_duration = max_duration + + def __call__(self, scores: SlidingWindowFeature) -> Annotation: + """Binarize detection scores + Parameters + ---------- + scores : SlidingWindowFeature + Detection scores. + Returns + ------- + active : Annotation + Binarized scores. + """ + + num_frames, num_classes = scores.data.shape + frames = scores.sliding_window + timestamps = [frames[i].middle for i in range(num_frames)] + + # annotation meant to store 'active' regions + active = Annotation() + for k, k_scores in enumerate(scores.data.T): + + label = k if scores.labels is None else scores.labels[k] + + # initial state + start = timestamps[0] + is_active = k_scores[0] > self.onset + curr_scores = [k_scores[0]] + curr_timestamps = [start] + for t, y in zip(timestamps[1:], k_scores[1:]): + # currently active + if is_active: + curr_duration = t - start + if curr_duration > self.max_duration: + # if curr_duration > 15: + # import pdb; pdb.set_trace() + search_after = len(curr_scores) // 2 + # divide segment + min_score_div_idx = search_after + np.argmin(curr_scores[search_after:]) + min_score_t = curr_timestamps[min_score_div_idx] + region = Segment(start - self.pad_onset, min_score_t + self.pad_offset) + active[region, k] = label + start = curr_timestamps[min_score_div_idx] + curr_scores = curr_scores[min_score_div_idx+1:] + curr_timestamps = curr_timestamps[min_score_div_idx+1:] + # switching from active to inactive + elif y < self.offset: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + start = t + is_active = False + curr_scores = [] + curr_timestamps = [] + # currently inactive + else: + # switching from inactive to active + if y > self.onset: + start = t + is_active = True + curr_scores.append(y) + curr_timestamps.append(t) + + # if active at the end, add final region + if is_active: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + + # because of padding, some active regions might be overlapping: merge them. + # also: fill same speaker gaps shorter than min_duration_off + if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: + if self.max_duration < float("inf"): + raise NotImplementedError(f"This would break current max_duration param") + active = active.support(collar=self.min_duration_off) + + # remove tracks shorter than min_duration_on + if self.min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < self.min_duration_on: + del active[segment, track] + + return active + + +def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0): + # because of padding, some active regions might be overlapping: merge them. + # also: fill same speaker gaps shorter than min_duration_off + + active = Annotation() + for k, vad_t in enumerate(vad_arr): + region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) + active[region, k] = 1 + + + if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: + active = active.support(collar=min_duration_off) + + # remove tracks shorter than min_duration_on + if min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < min_duration_on: + del active[segment, track] + + active = active.for_json() + active_segs = pd.DataFrame([x['segment'] for x in active['content']]) + return active_segs + + +if __name__ == "__main__": + from pyannote.audio import Inference + hook = lambda segmentation: segmentation + inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook) + audio = "/tmp/11962.wav" + scores = inference(audio) + binarize = Binarize(max_duration=15) + anno = binarize(scores) + res = [] + for ann in anno.get_timeline(): + res.append((ann.start, ann.end)) + + res = pd.DataFrame(res) + res[2] = res[1] - res[0] \ No newline at end of file