mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
skeleton v2
This commit is contained in:
128
whisperx/vad.py
128
whisperx/vad.py
@ -1,10 +1,32 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature, Timeline
|
||||
import torch
|
||||
from typing import Optional, Callable, Union, Text
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.audio import Model, Pipeline
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from .diarize import Segment as SegmentX
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
    """Create the voice-activity pipeline backed by the pretrained segmentation model.

    Parameters
    ----------
    device :
        Passed to ``torch.device`` to select where the pipeline runs.
    vad_onset :
        Value stored under the ``"onset"`` hyperparameter.
    vad_offset :
        Value stored under the ``"offset"`` hyperparameter.
    use_auth_token : optional
        Forwarded to ``Model.from_pretrained`` when fetching
        ``"pyannote/segmentation"``.

    Returns
    -------
    VoiceActivitySegmentation
        The pipeline after ``instantiate`` has been called with the
        hyperparameters assembled here.
    """
    segmentation_model = Model.from_pretrained(
        "pyannote/segmentation",
        use_auth_token=use_auth_token,
    )

    pipeline = VoiceActivitySegmentation(
        segmentation=segmentation_model,
        device=torch.device(device),
    )

    # Onset/offset come from the caller; both minimum durations are fixed at 0.1.
    pipeline.instantiate(
        dict(
            onset=vad_onset,
            offset=vad_offset,
            min_duration_on=0.1,
            min_duration_off=0.1,
        )
    )

    return pipeline
|
||||
|
||||
class Binarize:
|
||||
"""Binarize detection scores using hysteresis thresholding
|
||||
"""Binarize detection scores using hysteresis thresholding, with min-cut operation
|
||||
to ensure no segments are longer than max_duration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
onset : float, optional
|
||||
@ -28,6 +50,9 @@ class Binarize:
|
||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
https://arxiv.org/abs/2303.00747
|
||||
|
||||
Pyannote-audio
|
||||
"""
|
||||
|
||||
@ -136,6 +161,51 @@ class Binarize:
|
||||
return active
|
||||
|
||||
|
||||
class VoiceActivitySegmentation(VoiceActivityDetection):
    """Voice-activity pipeline whose ``apply`` returns raw segmentation scores.

    Construction is delegated unchanged to ``VoiceActivityDetection``; only
    ``apply`` is overridden so that it stops after running the segmentation
    model and hands back that output as-is.
    """

    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):
        """Forward every argument untouched to the parent pipeline constructor."""
        super().__init__(
            segmentation=segmentation,
            fscore=fscore,
            use_auth_token=use_auth_token,
            **inference_kwargs,
        )

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> SlidingWindowFeature:
        """Run the segmentation model on ``file`` and return its scores.

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Passed through ``self.setup_hook``. The wrapped hook is not
            invoked anywhere in this method.

        Returns
        -------
        segmentations : SlidingWindowFeature
            Whatever ``self._segmentation(file)`` produced (or the copy cached
            on ``file`` while training). No thresholding is applied here.
        """

        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        # apply segmentation model (only if needed)
        # output shape is (num_chunks, num_frames, 1)
        if self.training:
            # While training, reuse scores already stored on the file so the
            # model is run at most once per file.
            if self.CACHED_SEGMENTATION in file:
                segmentations = file[self.CACHED_SEGMENTATION]
            else:
                segmentations = self._segmentation(file)
                file[self.CACHED_SEGMENTATION] = segmentations
        else:
            segmentations: SlidingWindowFeature = self._segmentation(file)

        return segmentations
|
||||
|
||||
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
|
||||
active = Annotation()
|
||||
@ -157,21 +227,49 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
||||
|
||||
def merge_chunks(segments, chunk_size):
    """Merge operation described in paper.

    The input is first passed through ``Binarize(max_duration=chunk_size)``;
    the resulting speech turns are then greedily packed, in order, into
    groups whose span from first start to last end does not exceed
    ``chunk_size`` (a single turn longer than ``chunk_size`` still forms its
    own group).

    Parameters
    ----------
    segments :
        Object accepted by ``Binarize.__call__``.
    chunk_size :
        Maximum span of a merged group; must be greater than zero.

    Returns
    -------
    list of dict
        One dict per merged group with keys ``"start"``, ``"end"`` and
        ``"segments"`` (a list of ``(start, end)`` tuples of the turns it
        contains). An empty list is returned when binarization yields no
        speech turns.
    """
    assert chunk_size > 0

    binarize = Binarize(max_duration=chunk_size)
    segments = binarize(segments)
    segments_list = [
        SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")
        for speech_turn in segments.get_timeline()
    ]

    # No speech detected: nothing to merge. Returning early avoids relying on
    # an assert (stripped under -O) and the IndexError that would follow.
    if not segments_list:
        return []

    merged_segments = []
    seg_idxs = []
    curr_end = 0
    # Make sure the starting point is the start of the first segment.
    curr_start = segments_list[0].start

    for seg in segments_list:
        # Close the current group when adding this turn would exceed the
        # chunk size, provided the group already holds at least one turn.
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged_segments.append({
                "start": curr_start,
                "end": curr_end,
                "segments": seg_idxs,
            })
            curr_start = seg.start
            seg_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))

    # add final
    merged_segments.append({
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
    })
    return merged_segments
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc debugging scaffold; only runs when this module is executed
    # directly as a script.
    #
    # The commented-out snippet below is kept for reference: it ran pyannote
    # inference on a local wav file, binarized the scores with
    # max_duration=15 and tabulated the start/end/duration of each segment.
    # from pyannote.audio import Inference
    # hook = lambda segmentation: segmentation
    # inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
    # audio = "/tmp/11962.wav"
    # scores = inference(audio)
    # binarize = Binarize(max_duration=15)
    # anno = binarize(scores)
    # res = []
    # for ann in anno.get_timeline():
    # res.append((ann.start, ann.end))

    # res = pd.DataFrame(res)
    # res[2] = res[1] - res[0]
    import pandas as pd
    input_fp = "tt298650_sync.wav"
    # Loads a space-separated, header-less ".sad" file from a hard-coded
    # developer path; `df` is not used further in the lines visible here.
    df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
|
||||
|
Reference in New Issue
Block a user