Merge pull request #888 from 3manifold/silero-vad

Silero VAD support
This commit is contained in:
Max Bain
2025-01-11 17:15:27 +00:00
committed by GitHub
8 changed files with 262 additions and 99 deletions

View File

@ -278,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
* [ ] Allow silero-vad as alternative VAD option
* [x] Allow silero-vad as alternative VAD option
* [ ] Improve diarization (word level). *Harder than first thought...*
@ -300,7 +300,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
Valuable VAD & Diarization Models from:
- [pyannote audio][https://github.com/pyannote/pyannote-audio]
- [silero vad][https://github.com/snakers4/silero-vad]
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

View File

@ -1,4 +1,4 @@
from .transcribe import load_model
from .alignment import load_align_model, align
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline
from .asr import load_model

View File

@ -13,9 +13,8 @@ from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
import whisperx.vads
from .types import SingleSegment, TranscriptionResult
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@ -106,7 +105,7 @@ class FasterWhisperPipeline(Pipeline):
def __init__(
self,
model: WhisperModel,
vad: VoiceActivitySegmentation,
vad,
vad_params: dict,
options: TranscriptionOptions,
tokenizer: Optional[Tokenizer] = None,
@ -208,7 +207,16 @@ class FasterWhisperPipeline(Pipeline):
# print(f2-f1)
yield {'inputs': audio[f1:f2]}
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
# Pre-process audio and merge chunks as defined by the respective VAD child class
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
if issubclass(type(self.vad_model), whisperx.vads.Vad):
waveform = self.vad_model.preprocess_audio(audio)
merge_chunks = self.vad_model.merge_chunks
else:
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
merge_chunks = whisperx.vads.Pyannote.merge_chunks
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
vad_segments,
chunk_size,
@ -296,7 +304,8 @@ def load_model(
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[VoiceActivitySegmentation] = None,
vad_model = None,
vad_method = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
@ -309,6 +318,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
@ -374,6 +384,7 @@ def load_model(
default_asr_options = TranscriptionOptions(**default_asr_options)
default_vad_options = {
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
"vad_onset": 0.500,
"vad_offset": 0.363
}
@ -381,10 +392,17 @@ def load_model(
if vad_options is not None:
default_vad_options.update(vad_options)
# Note: manually assigned vad_model has higher priority than vad_method!
if vad_model is not None:
print("Use manually assigned vad_model. vad_method is ignored.")
vad_model = vad_model
else:
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
if vad_method == "silero":
vad_model = whisperx.vads.Silero(**default_vad_options)
elif vad_method == "pyannote":
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")
return FasterWhisperPipeline(
model=model,

View File

@ -46,6 +46,7 @@ def cli():
parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
# vad params
parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
@ -110,6 +111,7 @@ def cli():
return_char_alignments: bool = args.pop("return_char_alignments")
hf_token: str = args.pop("hf_token")
vad_method: str = args.pop("vad_method")
vad_onset: float = args.pop("vad_onset")
vad_offset: float = args.pop("vad_offset")
@ -175,7 +177,7 @@ def cli():
results = []
tmp_results = []
# model = load_model(model_name, device=device, download_root=model_dir)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)

View File

@ -0,0 +1,3 @@
from whisperx.vads.pyannote import Pyannote
from whisperx.vads.silero import Silero
from whisperx.vads.vad import Vad

View File

@ -1,19 +1,21 @@
import hashlib
import os
import urllib
from typing import Callable, Optional, Text, Union
from typing import Callable, Text, Union
from typing import Optional
import numpy as np
import pandas as pd
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.core import Segment
from tqdm import tqdm
from .diarize import Segment as SegmentX
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
# deprecated
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
@ -236,41 +238,63 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
return segmentations
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
class Pyannote(Vad):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
print(">>Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok=True)
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
output.write(buffer)
loop.update(len(buffer))
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
def merge_chunks(
segments,
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": kwargs['vad_onset'],
"offset": kwargs['vad_offset'],
"min_duration_on": 0.1,
"min_duration_off": 0.1}
self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
self.vad_pipeline.instantiate(hyperparameters)
def __call__(self, audio: AudioFile, **kwargs):
return self.vad_pipeline(audio)
@staticmethod
def preprocess_audio(audio):
return torch.from_numpy(audio).unsqueeze(0)
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
assert chunk_size > 0
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
segments = binarize(segments)
@ -281,27 +305,5 @@ def merge_chunks(
if len(segments_list) == 0:
print("No active speech found in audio")
return []
# assert segments_list, "segments_list is empty."
# Make sur the starting point is the start of the segment.
curr_start = segments_list[0].start
for seg in segments_list:
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
assert segments_list, "segments_list is empty."
return Vad.merge_chunks(segments_list, chunk_size, onset, offset)

62
whisperx/vads/silero.py Normal file
View File

@ -0,0 +1,62 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union
import torch
from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad
AudioFile = Union[Text, Path, IOBase, Mapping]
class Silero(Vad):
# check again default values
def __init__(self, **kwargs):
print(">>Performing voice activity detection using Silero...")
super().__init__(kwargs['vad_onset'])
self.vad_onset = kwargs['vad_onset']
self.chunk_size = kwargs['chunk_size']
self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
trust_repo=True)
(self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
def __call__(self, audio: AudioFile, **kwargs):
"""use silero to get segments of speech"""
# Only accept 16000 Hz for now.
# Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
# multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
sample_rate = audio["sample_rate"]
if sample_rate != 16000:
raise ValueError("Only 16000Hz sample rate is allowed")
timestamps = self.get_speech_timestamps(audio["waveform"],
model=self.vad_pipeline,
sampling_rate=sample_rate,
max_speech_duration_s=self.chunk_size,
threshold=self.vad_onset
# min_silence_duration_ms = self.min_duration_off/1000
# min_speech_duration_ms = self.min_duration_on/1000
# ...
# See silero documentation for full option list
)
return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
@staticmethod
def preprocess_audio(audio):
return audio
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float = 0.5,
offset: Optional[float] = None,
):
assert chunk_size > 0
return Vad.merge_chunks(segments, chunk_size, onset, offset)

74
whisperx/vads/vad.py Normal file
View File

@ -0,0 +1,74 @@
from typing import Optional
import pandas as pd
from pyannote.core import Annotation, Segment
class Vad:
def __init__(self, vad_onset):
if not (0 < vad_onset < 1):
raise ValueError(
"vad_onset is a decimal value between 0 and 1."
)
@staticmethod
def preprocess_audio(audio):
pass
# keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
@staticmethod
def merge_chunks(segments,
chunk_size,
onset: float,
offset: Optional[float]):
"""
Merge operation described in paper
"""
curr_end = 0
merged_segments = []
seg_idxs = []
speaker_idxs = []
curr_start = segments[0].start
for seg in segments:
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
curr_start = seg.start
seg_idxs = []
speaker_idxs = []
curr_end = seg.end
seg_idxs.append((seg.start, seg.end))
speaker_idxs.append(seg.speaker)
# add final
merged_segments.append({
"start": curr_start,
"end": curr_end,
"segments": seg_idxs,
})
return merged_segments
# Unused function
@staticmethod
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
active = Annotation()
for k, vad_t in enumerate(vad_arr):
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
active[region, k] = 1
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
active = active.support(collar=min_duration_off)
# remove tracks shorter than min_duration_on
if min_duration_on > 0:
for segment, track in list(active.itertracks()):
if segment.duration < min_duration_on:
del active[segment, track]
active = active.for_json()
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
return active_segs