from io import IOBase
from pathlib import Path
from typing import Mapping, Optional, Text, Union

import torch

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad

AudioFile = Union[Text, Path, IOBase, Mapping]
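# The Mapping form of AudioFile is what Silero.__call__ below consumes:
# a dict shaped like {"waveform": torch.Tensor, "sample_rate": int}.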

class Silero(Vad):
    # TODO: double-check the default values for these parameters.
    def __init__(self, **kwargs):
        print(">>Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])

        self.vad_onset = kwargs['vad_onset']
        self.chunk_size = kwargs['chunk_size']
        # Load the pretrained Silero VAD model and its helpers from torch.hub.
        self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      trust_repo=True)
        # vad_utils is (get_speech_timestamps, save_audio, read_audio, VADIterator,
        # collect_chunks); keep only the timestamp extractor and the audio reader.
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
    def __call__(self, audio: AudioFile, **kwargs):
        """Use Silero to get segments of speech."""
        # Only accept 16000 Hz for now.
        # Note: Silero models support both 8000 and 16000 Hz. Although other values are
        # not directly supported, multiples of 16000 (e.g. 32000 or 48000) are cast to
        # 16000 inside of the JIT model!
        sample_rate = audio["sample_rate"]
        if sample_rate != 16000:
            raise ValueError("Only 16000Hz sample rate is allowed")

        timestamps = self.get_speech_timestamps(audio["waveform"],
                                                model=self.vad_pipeline,
                                                sampling_rate=sample_rate,
                                                max_speech_duration_s=self.chunk_size,
                                                threshold=self.vad_onset,
                                                # min_silence_duration_ms=self.min_duration_off / 1000,
                                                # min_speech_duration_ms=self.min_duration_on / 1000,
                                                # ...
                                                # See the Silero documentation for the full option list.
                                                )
        return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
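    # For reference, get_speech_timestamps returns sample indices by default, e.g.
    #   [{'start': 16000, 'end': 48000}]
    # for speech between 1.0 s and 3.0 s at 16 kHz, which is why __call__ divides
    # by sample_rate above. (Illustrative values, not taken from this repo.)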
    @staticmethod
    def preprocess_audio(audio):
        # Silero consumes the raw waveform directly, so no preprocessing is needed.
        return audio
    @staticmethod
    def merge_chunks(segments_list,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        assert chunk_size > 0
        if len(segments_list) == 0:
            print("No active speech found in audio")
            return []
        # Delegate the actual merging to the shared Vad implementation.
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
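
# A minimal usage sketch, not part of the original module. It assumes a 16 kHz
# mono file at the hypothetical path "audio.wav" and reuses the read_audio helper
# that the Silero torch.hub utils expose on the instance; the onset and chunk-size
# values are placeholders, not defaults taken from whisperX.
if __name__ == "__main__":
    vad = Silero(vad_onset=0.5, chunk_size=30)
    waveform = vad.read_audio("audio.wav", sampling_rate=16000)
    segments = vad({"waveform": waveform, "sample_rate": 16000})
    for seg in segments:
        print(f"speech from {seg.start:.2f}s to {seg.end:.2f}s")
    # Optionally merge short segments into transcription-sized chunks via the
    # shared Vad implementation.
    chunks = Silero.merge_chunks(segments, chunk_size=30, onset=0.5)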