Accept alternative VAD methods. Extend to use Silero VAD.

This commit is contained in:
3manifold
2024-09-26 10:28:52 +02:00
parent 10b05fc43f
commit 79eb8fa53d
8 changed files with 262 additions and 101 deletions

README.md

@@ -278,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
-* [ ] Allow silero-vad as alternative VAD option
+* [x] Allow silero-vad as alternative VAD option
* [ ] Improve diarization (word level). *Harder than first thought...*
@@ -300,7 +300,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
-Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
+Valuable VAD & Diarization Models from:
+- [pyannote audio](https://github.com/pyannote/pyannote-audio)
+- [silero vad](https://github.com/snakers4/silero-vad)
Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

whisperx/__init__.py

@@ -1,4 +1,4 @@
-from .transcribe import load_model
from .alignment import load_align_model, align
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline
+from .asr import load_model

whisperx/asr.py

@@ -1,6 +1,5 @@
import os
import warnings
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
import ctranslate2
import faster_whisper
@@ -12,9 +11,8 @@ from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator

from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+import whisperx.vads
from .types import SingleSegment, TranscriptionResult
-from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@@ -105,7 +103,7 @@ class FasterWhisperPipeline(Pipeline):
    def __init__(
            self,
            model: WhisperModel,
-            vad: VoiceActivitySegmentation,
+            vad,
            vad_params: dict,
            options: TranscriptionOptions,
            tokenizer: Optional[Tokenizer] = None,
@@ -207,7 +205,16 @@ class FasterWhisperPipeline(Pipeline):
                # print(f2-f1)
                yield {'inputs': audio[f1:f2]}

-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        # Pre-process audio and merge chunks as defined by the respective VAD child class.
+        # If vad_model was assigned manually (see 'load_model'), follow the pyannote toolkit behaviour.
+        if issubclass(type(self.vad_model), whisperx.vads.Vad):
+            waveform = self.vad_model.preprocess_audio(audio)
+            merge_chunks = self.vad_model.merge_chunks
+        else:
+            waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
+            merge_chunks = whisperx.vads.Pyannote.merge_chunks
+        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(
            vad_segments,
            chunk_size,
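Because the dispatch above only checks issubclass(type(self.vad_model), whisperx.vads.Vad), any Vad subclass that provides __call__, preprocess_audio, and merge_chunks can be passed to load_model via its vad_model argument. A minimal sketch, assuming a hypothetical WholeClipVad invented here for illustration:

import torch
import whisperx.vads
from whisperx.diarize import Segment as SegmentX


class WholeClipVad(whisperx.vads.Vad):
    """Hypothetical toy VAD: marks the whole clip as one speech segment."""

    def __init__(self, vad_onset=0.5, **kwargs):
        super().__init__(vad_onset)

    def __call__(self, audio, **kwargs):
        # audio is the {"waveform", "sample_rate"} dict built by transcribe()
        duration = audio["waveform"].shape[-1] / audio["sample_rate"]
        return [SegmentX(0.0, duration, "UNKNOWN")]

    @staticmethod
    def preprocess_audio(audio):
        # same (1, n_samples) tensor layout the pipeline feeds to __call__
        return torch.from_numpy(audio).unsqueeze(0)

    @staticmethod
    def merge_chunks(segments, chunk_size, onset=0.5, offset=None):
        return whisperx.vads.Vad.merge_chunks(segments, chunk_size, onset, offset)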
@@ -295,7 +302,8 @@ def load_model(
        compute_type="float16",
        asr_options: Optional[dict] = None,
        language: Optional[str] = None,
-        vad_model: Optional[VoiceActivitySegmentation] = None,
+        vad_model=None,
+        vad_method=None,
        vad_options: Optional[dict] = None,
        model: Optional[WhisperModel] = None,
        task="transcribe",
@@ -308,6 +316,7 @@ def load_model(
        whisper_arch - The name of the Whisper model to load.
        device - The device to load the model on.
        compute_type - The compute type to use for the model.
+        vad_method: str - The VAD method to use. A manually assigned vad_model takes priority if it is not None.
        options - A dictionary of options to use for the model.
        language - The language of the model. (use English for now)
        model - The WhisperModel instance to use.
@@ -373,6 +382,7 @@ def load_model(
    default_asr_options = TranscriptionOptions(**default_asr_options)

    default_vad_options = {
+        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }
@@ -380,10 +390,16 @@ def load_model(
    if vad_options is not None:
        default_vad_options.update(vad_options)

+    # Note: a manually assigned vad_model has higher priority than vad_method!
    if vad_model is not None:
+        print("Use manually assigned vad_model. vad_method is ignored.")
        vad_model = vad_model
    else:
-        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+        match vad_method:
+            case "silero":
+                vad_model = whisperx.vads.Silero(**default_vad_options)
+            case "pyannote" | _:
+                vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
return FasterWhisperPipeline(
model=model,
@@ -393,4 +409,4 @@ def load_model(
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )
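For reference, a minimal usage sketch of the extended load_model signature (model size, device, and audio path below are illustrative placeholders, not part of the diff):

import whisperx

# "silero" selects the new Silero backend; "pyannote" (or any other value)
# falls through to the previous pyannote-based VAD.
model = whisperx.load_model(
    "small",
    device="cuda",
    compute_type="float16",
    vad_method="silero",
    vad_options={"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363},
)

audio = whisperx.load_audio("example.wav")
result = model.transcribe(audio)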

whisperx/transcribe.py

@@ -46,6 +46,7 @@ def cli():
    parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")

    # vad params
+    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
    parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
@@ -110,6 +111,7 @@ def cli():
    return_char_alignments: bool = args.pop("return_char_alignments")

    hf_token: str = args.pop("hf_token")
+    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")
@@ -175,7 +177,7 @@ def cli():
    results = []
    tmp_results = []
    # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
+    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size": chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
for audio_path in args.pop("audio"):
audio = load_audio(audio_path)

whisperx/vads/__init__.py Normal file

@@ -0,0 +1,3 @@
from whisperx.vads.pyannote import Pyannote
from whisperx.vads.silero import Silero
from whisperx.vads.vad import Vad

whisperx/vads/pyannote.py

@@ -1,19 +1,21 @@
import hashlib
import os
import urllib
-from typing import Callable, Optional, Text, Union
+from typing import Callable, Text, Union
+from typing import Optional
import numpy as np
import pandas as pd
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.core import Annotation, Segment, SlidingWindowFeature
+from pyannote.core import Annotation, SlidingWindowFeature
+from pyannote.core import Segment
from tqdm import tqdm
-from .diarize import Segment as SegmentX
+from whisperx.diarize import Segment as SegmentX
+from whisperx.vads.vad import Vad
# deprecated
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
@@ -30,11 +32,11 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non
        model_fp = os.path.abspath(model_fp)  # Ensure the path is absolute
    else:
        model_fp = os.path.abspath(model_fp)  # Ensure any provided path is absolute

-    # Check if the resolved model file exists
-    if not os.path.exists(model_fp):
-        raise FileNotFoundError(f"Model file not found at {model_fp}")
+    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
+        raise RuntimeError(f"{model_fp} exists and is not a regular file")
@@ -45,7 +47,7 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non
    )
    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
    hyperparameters = {"onset": vad_onset,
                       "offset": vad_offset,
                       "min_duration_on": 0.1,
                       "min_duration_off": 0.1}
@@ -81,21 +83,21 @@ class Binarize:
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Modified by Max Bain to include WhisperX's min-cut operation
    https://arxiv.org/abs/2303.00747

    Pyannote-audio
    """

    def __init__(
            self,
            onset: float = 0.5,
            offset: Optional[float] = None,
            min_duration_on: float = 0.0,
            min_duration_off: float = 0.0,
            pad_onset: float = 0.0,
            pad_offset: float = 0.0,
            max_duration: float = float('inf')
    ):
        super().__init__()
@@ -141,7 +143,7 @@ class Binarize:
            t = start
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        search_after = len(curr_scores) // 2
@@ -151,8 +153,8 @@ class Binarize:
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
-                        curr_scores = curr_scores[min_score_div_idx+1:]
-                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
+                        curr_scores = curr_scores[min_score_div_idx + 1:]
+                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
            # switching from active to inactive
            elif y < self.offset:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
@@ -193,11 +195,11 @@ class Binarize:
class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
            self,
            segmentation: PipelineModel = "pyannote/segmentation",
            fscore: bool = False,
            use_auth_token: Union[Text, None] = None,
            **inference_kwargs,
    ):
        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
@@ -236,72 +238,72 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
        return segmentations
-def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
-    active = Annotation()
-    for k, vad_t in enumerate(vad_arr):
-        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
-        active[region, k] = 1
-    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
-        active = active.support(collar=min_duration_off)
-    # remove tracks shorter than min_duration_on
-    if min_duration_on > 0:
-        for segment, track in list(active.itertracks()):
-            if segment.duration < min_duration_on:
-                del active[segment, track]
-    active = active.for_json()
-    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
-    return active_segs

-def merge_chunks(
-    segments,
-    chunk_size,
-    onset: float = 0.5,
-    offset: Optional[float] = None,
-):
-    """
-    Merge operation described in paper
-    """
-    curr_end = 0
-    merged_segments = []
-    seg_idxs = []
-    speaker_idxs = []
-    assert chunk_size > 0
-    binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
-    segments = binarize(segments)
-    segments_list = []
-    for speech_turn in segments.get_timeline():
-        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
-    if len(segments_list) == 0:
-        print("No active speech found in audio")
-        return []
-    # assert segments_list, "segments_list is empty."
-    # Make sure the starting point is the start of the segment.
-    curr_start = segments_list[0].start
-    for seg in segments_list:
-        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
-            merged_segments.append({
-                "start": curr_start,
-                "end": curr_end,
-                "segments": seg_idxs,
-            })
-            curr_start = seg.start
-            seg_idxs = []
-            speaker_idxs = []
-        curr_end = seg.end
-        seg_idxs.append((seg.start, seg.end))
-        speaker_idxs.append(seg.speaker)
-    # add final
-    merged_segments.append({
-        "start": curr_start,
-        "end": curr_end,
-        "segments": seg_idxs,
-    })
-    return merged_segments

+class Pyannote(Vad):
+
+    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
+        print(">>Performing voice activity detection using Pyannote...")
+        super().__init__(kwargs['vad_onset'])
+        model_dir = torch.hub._get_torch_home()
+        os.makedirs(model_dir, exist_ok=True)
+        if model_fp is None:
+            model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+        if os.path.exists(model_fp) and not os.path.isfile(model_fp):
+            raise RuntimeError(f"{model_fp} exists and is not a regular file")
+
+        if not os.path.isfile(model_fp):
+            with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
+                with tqdm(
+                        total=int(source.info().get("Content-Length")),
+                        ncols=80,
+                        unit="iB",
+                        unit_scale=True,
+                        unit_divisor=1024,
+                ) as loop:
+                    while True:
+                        buffer = source.read(8192)
+                        if not buffer:
+                            break
+                        output.write(buffer)
+                        loop.update(len(buffer))
+
+        model_bytes = open(model_fp, "rb").read()
+        if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
+            raise RuntimeError(
+                "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+            )
+
+        vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
+        hyperparameters = {"onset": kwargs['vad_onset'],
+                           "offset": kwargs['vad_offset'],
+                           "min_duration_on": 0.1,
+                           "min_duration_off": 0.1}
+        self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
+        self.vad_pipeline.instantiate(hyperparameters)
+
+    def __call__(self, audio: AudioFile, **kwargs):
+        return self.vad_pipeline(audio)
+
+    @staticmethod
+    def preprocess_audio(audio):
+        return torch.from_numpy(audio).unsqueeze(0)
+
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float = 0.5,
+                     offset: Optional[float] = None,
+                     ):
+        assert chunk_size > 0
+        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
+        segments = binarize(segments)
+        segments_list = []
+        for speech_turn in segments.get_timeline():
+            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+        if len(segments_list) == 0:
+            print("No active speech found in audio")
+            return []
+        assert segments_list, "segments_list is empty."
+        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
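For reference, a minimal sketch of constructing the Pyannote wrapper directly, mirroring what load_model does for the default method (vad_onset and vad_offset are required keyword options; the values here are the defaults from asr.py, and the device choice is a placeholder):

import torch

from whisperx.vads import Pyannote

# Downloads and checksums the segmentation weights on first use, then
# instantiates the VoiceActivitySegmentation pipeline defined above.
vad = Pyannote(torch.device("cpu"), use_auth_token=None,
               vad_onset=0.500, vad_offset=0.363)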

whisperx/vads/silero.py Normal file

@@ -0,0 +1,62 @@
from io import IOBase
from pathlib import Path
from typing import Mapping, Text
from typing import Optional
from typing import Union

import torch

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

AudioFile = Union[Text, Path, IOBase, Mapping]


class Silero(Vad):
    # check again default values
    def __init__(self, **kwargs):
        print(">>Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']
        self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      trust_repo=True)
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def __call__(self, audio: AudioFile, **kwargs):
        """use silero to get segments of speech"""
        # Only accept 16000 Hz for now.
        # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
        # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(audio["waveform"],
                                                model=self.vad_pipeline,
                                                sampling_rate=sample_rate,
                                                max_speech_duration_s=self.chunk_size,
                                                threshold=self.vad_onset
                                                # min_silence_duration_ms = self.min_duration_off/1000
                                                # min_speech_duration_ms = self.min_duration_on/1000
                                                # ...
                                                # See silero documentation for full option list
                                                )

        return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]

    @staticmethod
    def preprocess_audio(audio):
        return audio

    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        assert chunk_size > 0
        return Vad.merge_chunks(segments, chunk_size, onset, offset)
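For comparison, a standalone sketch of the upstream silero-vad utilities this class wraps (the audio path is a placeholder; a 16 kHz mono input is assumed, matching the constraint enforced in __call__ above):

import torch

model, utils = torch.hub.load("snakers4/silero-vad", "silero_vad", trust_repo=True)
get_speech_timestamps, _, read_audio, _, _ = utils

wav = read_audio("example.wav", sampling_rate=16000)  # placeholder path
timestamps = get_speech_timestamps(
    wav, model=model, sampling_rate=16000,
    threshold=0.5, max_speech_duration_s=30,
)
# convert sample indices to seconds, as Silero.__call__ does
print([(t["start"] / 16000, t["end"] / 16000) for t in timestamps])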

whisperx/vads/vad.py Normal file

@@ -0,0 +1,74 @@
from typing import Optional

import pandas as pd
from pyannote.core import Annotation, Segment


class Vad:
    def __init__(self, vad_onset):
        if not (0 < vad_onset < 1):
            raise ValueError(
                "vad_onset must be a decimal value between 0 and 1."
            )

    @staticmethod
    def preprocess_audio(audio):
        pass

    # keep merge_chunks as static so it can also be used by a manually assigned vad_model (see 'load_model')
    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float,
                     offset: Optional[float]):
        """
        Merge operation described in paper
        """
        curr_end = 0
        merged_segments = []
        seg_idxs = []
        speaker_idxs = []

        curr_start = segments[0].start
        for seg in segments:
            if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
                merged_segments.append({
                    "start": curr_start,
                    "end": curr_end,
                    "segments": seg_idxs,
                })
                curr_start = seg.start
                seg_idxs = []
                speaker_idxs = []
            curr_end = seg.end
            seg_idxs.append((seg.start, seg.end))
            speaker_idxs.append(seg.speaker)
        # add final
        merged_segments.append({
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
        })
        return merged_segments

    # Unused function
    @staticmethod
    def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
        active = Annotation()
        for k, vad_t in enumerate(vad_arr):
            region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
            active[region, k] = 1

        if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
            active = active.support(collar=min_duration_off)

        # remove tracks shorter than min_duration_on
        if min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < min_duration_on:
                    del active[segment, track]

        active = active.for_json()
        active_segs = pd.DataFrame([x['segment'] for x in active['content']])
        return active_segs
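A toy, worked illustration of Vad.merge_chunks (segment times are invented; chunk_size=5 keeps the example short):

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

segs = [
    SegmentX(0.0, 2.0, "UNKNOWN"),
    SegmentX(2.5, 4.0, "UNKNOWN"),
    SegmentX(6.0, 7.5, "UNKNOWN"),
]

# The first two segments fit within one 5 s window measured from curr_start,
# so they merge; the third exceeds it and opens a new chunk:
# [{'start': 0.0, 'end': 4.0, 'segments': [(0.0, 2.0), (2.5, 4.0)]},
#  {'start': 6.0, 'end': 7.5, 'segments': [(6.0, 7.5)]}]
print(Vad.merge_chunks(segs, chunk_size=5, onset=0.5, offset=None))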