import pandas as pd
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
from typing import List, Tuple, Optional


class Binarize:
    """Binarize detection scores using hysteresis thresholding.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration : float, optional
        Maximum duration of an active segment; longer segments are split at
        the timestamp with the lowest score. Defaults to infinity (no splitting).

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Adapted from pyannote-audio.
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float("inf"),
    ):
        super().__init__()

        self.onset = onset
        # fall back to the onset threshold only when no offset is given
        # (`offset or onset` would also override an explicit offset of 0.0)
        self.offset = offset if offset is not None else onset

        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off

        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.

        Returns
        -------
        active : Annotation
            Binarized scores.
        """
        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()
        for k, k_scores in enumerate(scores.data.T):

            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        # segment exceeds max_duration: split it at the
                        # lowest-scoring timestamp, searching only the second
                        # half so neither resulting piece is degenerate
                        search_after = len(curr_scores) // 2
                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                        min_score_t = curr_timestamps[min_score_div_idx]
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = min_score_t
                        curr_scores = curr_scores[min_score_div_idx + 1:]
                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True
                curr_scores.append(y)
                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same-speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError("This would break the current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active
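

# A minimal sketch of Binarize on toy scores (hypothetical values, not part
# of the original file), illustrating the hysteresis rule: with onset=0.5 and
# offset=0.3, activity starts once a score rises above 0.5 and ends only when
# a score drops below 0.3, so the dip to 0.4 below does not split the segment.
#
#     from pyannote.core import SlidingWindow
#     window = SlidingWindow(start=0.0, duration=0.02, step=0.02)
#     data = np.array([[0.1], [0.7], [0.4], [0.8], [0.2], [0.1]])
#     print(Binarize(onset=0.5, offset=0.3)(SlidingWindowFeature(data, window)))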
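
# A similar sketch for max_duration (again hypothetical, reusing `window` from
# the sketch above): the active run below exceeds 0.06 s, so it is split at
# the lowest-scoring timestamp in the second half of the accumulated segment,
# yielding two segments instead of one over-long one.
#
#     data = np.array([[0.9], [0.8], [0.6], [0.9], [0.8], [0.2]])
#     print(Binarize(onset=0.5, offset=0.3, max_duration=0.06)(
#         SlidingWindowFeature(data, window)))

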
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
    # vad_arr: iterable of (start, end) times in seconds.
    # because of padding, some active regions might be overlapping: merge them.
    # also: fill gaps shorter than min_duration_off

    active = Annotation()
    for k, vad_t in enumerate(vad_arr):
        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
        active[region, k] = 1

    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
        active = active.support(collar=min_duration_off)

    # remove tracks shorter than min_duration_on
    if min_duration_on > 0:
        for segment, track in list(active.itertracks()):
            if segment.duration < min_duration_on:
                del active[segment, track]

    active = active.for_json()
    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
    return active_segs
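

# A minimal merge_vad sketch (hypothetical input, not from the original file):
# the two regions below overlap once pad_offset is applied, so support() merges
# them into a single row of the returned DataFrame, whose columns come from the
# 'start'/'end' keys of the JSON-serialized segments.
#
#     segs = merge_vad(np.array([[0.0, 1.2], [1.3, 4.0]]), pad_offset=0.2)

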
if __name__ == "__main__":
    from pyannote.audio import Inference

    # identity hook: return per-chunk segmentation scores unchanged
    hook = lambda segmentation: segmentation
    inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)

    audio = "/tmp/11962.wav"
    scores = inference(audio)
    binarize = Binarize(max_duration=15)
    anno = binarize(scores)

    # collect (start, end) pairs and derive each segment's duration
    res = []
    for ann in anno.get_timeline():
        res.append((ann.start, ann.end))

    res = pd.DataFrame(res)
    res[2] = res[1] - res[0]