mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Accept alternative VAD methods. Extend to use Silero VAD.
This commit is contained in:
3
whisperx/vads/__init__.py
Normal file
3
whisperx/vads/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from whisperx.vads.pyannote import Pyannote
|
||||
from whisperx.vads.silero import Silero
|
||||
from whisperx.vads.vad import Vad
|
309
whisperx/vads/pyannote.py
Normal file
309
whisperx/vads/pyannote.py
Normal file
@ -0,0 +1,309 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
from typing import Callable, Text, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.core import Annotation, SlidingWindowFeature
|
||||
from pyannote.core import Segment
|
||||
from tqdm import tqdm
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
# deprecated
|
||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
|
||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
|
||||
vad_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
os.makedirs(model_dir, exist_ok = True)
|
||||
if model_fp is None:
|
||||
# Dynamically resolve the path to the model file
|
||||
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
|
||||
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||
else:
|
||||
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||
|
||||
# Check if the resolved model file exists
|
||||
if not os.path.exists(model_fp):
|
||||
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
"offset": vad_offset,
|
||||
"min_duration_on": 0.1,
|
||||
"min_duration_off": 0.1}
|
||||
vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
|
||||
vad_pipeline.instantiate(hyperparameters)
|
||||
|
||||
return vad_pipeline
|
||||
|
||||
class Binarize:
|
||||
"""Binarize detection scores using hysteresis thresholding, with min-cut operation
|
||||
to ensure not segments are longer than max_duration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
onset : float, optional
|
||||
Onset threshold. Defaults to 0.5.
|
||||
offset : float, optional
|
||||
Offset threshold. Defaults to `onset`.
|
||||
min_duration_on : float, optional
|
||||
Remove active regions shorter than that many seconds. Defaults to 0s.
|
||||
min_duration_off : float, optional
|
||||
Fill inactive regions shorter than that many seconds. Defaults to 0s.
|
||||
pad_onset : float, optional
|
||||
Extend active regions by moving their start time by that many seconds.
|
||||
Defaults to 0s.
|
||||
pad_offset : float, optional
|
||||
Extend active regions by moving their end time by that many seconds.
|
||||
Defaults to 0s.
|
||||
max_duration: float
|
||||
The maximum length of an active segment, divides segment at timestamp with lowest score.
|
||||
Reference
|
||||
---------
|
||||
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
|
||||
Modified by Max Bain to include WhisperX's min-cut operation
|
||||
https://arxiv.org/abs/2303.00747
|
||||
|
||||
Pyannote-audio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
min_duration_on: float = 0.0,
|
||||
min_duration_off: float = 0.0,
|
||||
pad_onset: float = 0.0,
|
||||
pad_offset: float = 0.0,
|
||||
max_duration: float = float('inf')
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.onset = onset
|
||||
self.offset = offset or onset
|
||||
|
||||
self.pad_onset = pad_onset
|
||||
self.pad_offset = pad_offset
|
||||
|
||||
self.min_duration_on = min_duration_on
|
||||
self.min_duration_off = min_duration_off
|
||||
|
||||
self.max_duration = max_duration
|
||||
|
||||
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
|
||||
"""Binarize detection scores
|
||||
Parameters
|
||||
----------
|
||||
scores : SlidingWindowFeature
|
||||
Detection scores.
|
||||
Returns
|
||||
-------
|
||||
active : Annotation
|
||||
Binarized scores.
|
||||
"""
|
||||
|
||||
num_frames, num_classes = scores.data.shape
|
||||
frames = scores.sliding_window
|
||||
timestamps = [frames[i].middle for i in range(num_frames)]
|
||||
|
||||
# annotation meant to store 'active' regions
|
||||
active = Annotation()
|
||||
for k, k_scores in enumerate(scores.data.T):
|
||||
|
||||
label = k if scores.labels is None else scores.labels[k]
|
||||
|
||||
# initial state
|
||||
start = timestamps[0]
|
||||
is_active = k_scores[0] > self.onset
|
||||
curr_scores = [k_scores[0]]
|
||||
curr_timestamps = [start]
|
||||
t = start
|
||||
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||
# currently active
|
||||
if is_active:
|
||||
curr_duration = t - start
|
||||
if curr_duration > self.max_duration:
|
||||
search_after = len(curr_scores) // 2
|
||||
# divide segment
|
||||
min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
|
||||
min_score_t = curr_timestamps[min_score_div_idx]
|
||||
region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
|
||||
active[region, k] = label
|
||||
start = curr_timestamps[min_score_div_idx]
|
||||
curr_scores = curr_scores[min_score_div_idx + 1:]
|
||||
curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
|
||||
# switching from active to inactive
|
||||
elif y < self.offset:
|
||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||
active[region, k] = label
|
||||
start = t
|
||||
is_active = False
|
||||
curr_scores = []
|
||||
curr_timestamps = []
|
||||
curr_scores.append(y)
|
||||
curr_timestamps.append(t)
|
||||
# currently inactive
|
||||
else:
|
||||
# switching from inactive to active
|
||||
if y > self.onset:
|
||||
start = t
|
||||
is_active = True
|
||||
|
||||
# if active at the end, add final region
|
||||
if is_active:
|
||||
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||
active[region, k] = label
|
||||
|
||||
# because of padding, some active regions might be overlapping: merge them.
|
||||
# also: fill same speaker gaps shorter than min_duration_off
|
||||
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
|
||||
if self.max_duration < float("inf"):
|
||||
raise NotImplementedError(f"This would break current max_duration param")
|
||||
active = active.support(collar=self.min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if self.min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < self.min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
return active
|
||||
|
||||
|
||||
class VoiceActivitySegmentation(VoiceActivityDetection):
|
||||
def __init__(
|
||||
self,
|
||||
segmentation: PipelineModel = "pyannote/segmentation",
|
||||
fscore: bool = False,
|
||||
use_auth_token: Union[Text, None] = None,
|
||||
**inference_kwargs,
|
||||
):
|
||||
|
||||
super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
|
||||
|
||||
def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
|
||||
"""Apply voice activity detection
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : AudioFile
|
||||
Processed file.
|
||||
hook : callable, optional
|
||||
Hook called after each major step of the pipeline with the following
|
||||
signature: hook("step_name", step_artefact, file=file)
|
||||
|
||||
Returns
|
||||
-------
|
||||
speech : Annotation
|
||||
Speech regions.
|
||||
"""
|
||||
|
||||
# setup hook (e.g. for debugging purposes)
|
||||
hook = self.setup_hook(file, hook=hook)
|
||||
|
||||
# apply segmentation model (only if needed)
|
||||
# output shape is (num_chunks, num_frames, 1)
|
||||
if self.training:
|
||||
if self.CACHED_SEGMENTATION in file:
|
||||
segmentations = file[self.CACHED_SEGMENTATION]
|
||||
else:
|
||||
segmentations = self._segmentation(file)
|
||||
file[self.CACHED_SEGMENTATION] = segmentations
|
||||
else:
|
||||
segmentations: SlidingWindowFeature = self._segmentation(file)
|
||||
|
||||
return segmentations
|
||||
|
||||
|
||||
class Pyannote(Vad):
|
||||
|
||||
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
|
||||
print(">>Performing voice activity detection using Pyannote...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
if model_fp is None:
|
||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
if not os.path.isfile(model_fp):
|
||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": kwargs['vad_onset'],
|
||||
"offset": kwargs['vad_offset'],
|
||||
"min_duration_on": 0.1,
|
||||
"min_duration_off": 0.1}
|
||||
self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
|
||||
self.vad_pipeline.instantiate(hyperparameters)
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
return self.vad_pipeline(audio)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return torch.from_numpy(audio).unsqueeze(0)
|
||||
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||
segments = binarize(segments)
|
||||
segments_list = []
|
||||
for speech_turn in segments.get_timeline():
|
||||
segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
|
||||
|
||||
if len(segments_list) == 0:
|
||||
print("No active speech found in audio")
|
||||
return []
|
||||
assert segments_list, "segments_list is empty."
|
||||
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|
62
whisperx/vads/silero.py
Normal file
62
whisperx/vads/silero.py
Normal file
@ -0,0 +1,62 @@
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Text
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from whisperx.diarize import Segment as SegmentX
|
||||
from whisperx.vads.vad import Vad
|
||||
|
||||
AudioFile = Union[Text, Path, IOBase, Mapping]
|
||||
|
||||
|
||||
class Silero(Vad):
|
||||
# check again default values
|
||||
def __init__(self, **kwargs):
|
||||
print(">>Performing voice activity detection using Silero...")
|
||||
super().__init__(kwargs['vad_onset'])
|
||||
|
||||
self.vad_onset = kwargs['vad_onset']
|
||||
self.chunk_size = kwargs['chunk_size']
|
||||
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
model='silero_vad',
|
||||
force_reload=False,
|
||||
onnx=False,
|
||||
trust_repo=True)
|
||||
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
|
||||
|
||||
def __call__(self, audio: AudioFile, **kwargs):
|
||||
"""use silero to get segments of speech"""
|
||||
# Only accept 16000 Hz for now.
|
||||
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
|
||||
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
|
||||
sample_rate = audio["sample_rate"]
|
||||
if sample_rate != 16000:
|
||||
raise ValueError("Only 16000Hz sample rate is allowed")
|
||||
|
||||
timestamps = self.get_speech_timestamps(audio["waveform"],
|
||||
model=self.vad_pipeline,
|
||||
sampling_rate=sample_rate,
|
||||
max_speech_duration_s=self.chunk_size,
|
||||
threshold=self.vad_onset
|
||||
# min_silence_duration_ms = self.min_duration_off/1000
|
||||
# min_speech_duration_ms = self.min_duration_on/1000
|
||||
# ...
|
||||
# See silero documentation for full option list
|
||||
)
|
||||
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float = 0.5,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
assert chunk_size > 0
|
||||
return Vad.merge_chunks(segments, chunk_size, onset, offset)
|
74
whisperx/vads/vad.py
Normal file
74
whisperx/vads/vad.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from pyannote.core import Annotation, Segment
|
||||
|
||||
|
||||
class Vad:
|
||||
def __init__(self, vad_onset):
|
||||
if not (0 < vad_onset < 1):
|
||||
raise ValueError(
|
||||
"vad_onset is a decimal value between 0 and 1."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_audio(audio):
|
||||
pass
|
||||
|
||||
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
|
||||
@staticmethod
|
||||
def merge_chunks(segments,
|
||||
chunk_size,
|
||||
onset: float,
|
||||
offset: Optional[float]):
|
||||
"""
|
||||
Merge operation described in paper
|
||||
"""
|
||||
curr_end = 0
|
||||
merged_segments = []
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
|
||||
curr_start = segments[0].start
|
||||
for seg in segments:
|
||||
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
curr_start = seg.start
|
||||
seg_idxs = []
|
||||
speaker_idxs = []
|
||||
curr_end = seg.end
|
||||
seg_idxs.append((seg.start, seg.end))
|
||||
speaker_idxs.append(seg.speaker)
|
||||
# add final
|
||||
merged_segments.append({
|
||||
"start": curr_start,
|
||||
"end": curr_end,
|
||||
"segments": seg_idxs,
|
||||
})
|
||||
|
||||
return merged_segments
|
||||
|
||||
# Unused function
|
||||
@staticmethod
|
||||
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
|
||||
active = Annotation()
|
||||
for k, vad_t in enumerate(vad_arr):
|
||||
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
||||
active[region, k] = 1
|
||||
|
||||
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
||||
active = active.support(collar=min_duration_off)
|
||||
|
||||
# remove tracks shorter than min_duration_on
|
||||
if min_duration_on > 0:
|
||||
for segment, track in list(active.itertracks()):
|
||||
if segment.duration < min_duration_on:
|
||||
del active[segment, track]
|
||||
|
||||
active = active.for_json()
|
||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||
return active_segs
|
Reference in New Issue
Block a user