Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Commit: Accept alternative VAD methods. Extend to use Silero VAD.
README.md
@@ -278,7 +278,7 @@ Bug finding and pull requests are also highly appreciated to keep this project g
 
 * [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
 
-* [ ] Allow silero-vad as alternative VAD option
+* [x] Allow silero-vad as alternative VAD option
 
 * [ ] Improve diarization (word level). *Harder than first thought...*
 

@@ -300,7 +300,9 @@ Borrows important alignment code from [PyTorch tutorial on forced alignment](htt
 And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
 
 
-Valuable VAD & Diarization Models from [pyannote audio](https://github.com/pyannote/pyannote-audio)
+Valuable VAD & Diarization Models from:
+- [pyannote audio](https://github.com/pyannote/pyannote-audio)
+- [silero vad](https://github.com/snakers4/silero-vad)
 
 Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

whisperx/__init__.py
@@ -1,4 +1,4 @@
-from .transcribe import load_model
 from .alignment import load_align_model, align
 from .audio import load_audio
 from .diarize import assign_word_speakers, DiarizationPipeline
+from .asr import load_model

whisperx/asr.py
@@ -1,6 +1,5 @@
 import os
-import warnings
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 
 import ctranslate2
 import faster_whisper

@@ -12,9 +11,8 @@ from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 
 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+import whisperx.vads
 from .types import SingleSegment, TranscriptionResult
-from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
 
 
 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []

@@ -105,7 +103,7 @@ class FasterWhisperPipeline(Pipeline):
     def __init__(
             self,
             model: WhisperModel,
-            vad: VoiceActivitySegmentation,
+            vad,
             vad_params: dict,
             options: TranscriptionOptions,
             tokenizer: Optional[Tokenizer] = None,

@@ -207,7 +205,16 @@ class FasterWhisperPipeline(Pipeline):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        # Pre-process audio and merge chunks as defined by the respective VAD child class
+        # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
+        if issubclass(type(self.vad_model), whisperx.vads.Vad):
+            waveform = self.vad_model.preprocess_audio(audio)
+            merge_chunks = self.vad_model.merge_chunks
+        else:
+            waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
+            merge_chunks = whisperx.vads.Pyannote.merge_chunks
+
+        vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,

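For orientation, a hedged sketch of the contract between the VAD stage and the ASR loop in transcribe(): merge_chunks returns a list of chunk dicts (keys taken from the new Vad.merge_chunks further below), and each chunk is cut out of the waveform and yielded to faster-whisper. The helper name `data` and the seconds-to-samples arithmetic are assumed from context, not shown verbatim in this hunk.

from whisperx.audio import SAMPLE_RATE

def data(audio, vad_segments):
    # Each merged chunk looks like:
    #   {"start": 0.0, "end": 29.8, "segments": [(0.0, 4.1), (4.9, 29.8)]}
    for seg in vad_segments:
        f1 = int(seg["start"] * SAMPLE_RATE)   # assumed: seconds -> sample index
        f2 = int(seg["end"] * SAMPLE_RATE)
        yield {"inputs": audio[f1:f2]}
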
@@ -295,7 +302,8 @@ def load_model(
         compute_type="float16",
         asr_options: Optional[dict] = None,
         language: Optional[str] = None,
-        vad_model: Optional[VoiceActivitySegmentation] = None,
+        vad_model = None,
+        vad_method = None,
         vad_options: Optional[dict] = None,
         model: Optional[WhisperModel] = None,
         task="transcribe",

@@ -308,6 +316,7 @@ def load_model(
         whisper_arch - The name of the Whisper model to load.
         device - The device to load the model on.
         compute_type - The compute type to use for the model.
+        vad_method: str - The VAD method to use. vad_model has higher priority if it is not None.
         options - A dictionary of options to use for the model.
         language - The language of the model. (use English for now)
         model - The WhisperModel instance to use.

@@ -373,6 +382,7 @@ def load_model(
     default_asr_options = TranscriptionOptions(**default_asr_options)
 
     default_vad_options = {
+        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
         "vad_onset": 0.500,
         "vad_offset": 0.363
     }

@@ -380,10 +390,16 @@ def load_model(
     if vad_options is not None:
         default_vad_options.update(vad_options)
 
+    # Note: manually assigned vad_model has higher priority than vad_method!
     if vad_model is not None:
+        print("Use manually assigned vad_model. vad_method is ignored.")
         vad_model = vad_model
     else:
-        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+        match vad_method:
+            case "silero":
+                vad_model = whisperx.vads.Silero(**default_vad_options)
+            case "pyannote" | _:
+                vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
 
     return FasterWhisperPipeline(
         model=model,

@@ -393,4 +409,4 @@ def load_model(
         language=language,
         suppress_numerals=suppress_numerals,
         vad_params=default_vad_options,
     )

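For reference, a minimal usage sketch of the extended load_model call; the model name, device and audio path below are illustrative placeholders, not part of this commit.

import whisperx

# Sketch: select Silero VAD through the new vad_method argument.
# "chunk_size" is needed by silero since binarization happens before merge_chunks.
model = whisperx.load_model(
    "large-v2",                      # placeholder model name
    device="cuda",
    compute_type="float16",
    vad_method="silero",             # "pyannote" (default) or "silero"
    vad_options={"chunk_size": 30, "vad_onset": 0.500, "vad_offset": 0.363},
)

audio = whisperx.load_audio("audio.wav")   # placeholder path
result = model.transcribe(audio)
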
whisperx/transcribe.py
@@ -46,6 +46,7 @@ def cli():
     parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
 
     # vad params
+    parser.add_argument("--vad_method", type=str, default="pyannote", choices=["pyannote", "silero"], help="VAD method to be used")
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
     parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")

@@ -110,6 +111,7 @@ def cli():
     return_char_alignments: bool = args.pop("return_char_alignments")
 
     hf_token: str = args.pop("hf_token")
+    vad_method: str = args.pop("vad_method")
     vad_onset: float = args.pop("vad_onset")
     vad_offset: float = args.pop("vad_offset")
 

@@ -175,7 +177,7 @@ def cli():
     results = []
     tmp_results = []
     # model = load_model(model_name, device=device, download_root=model_dir)
-    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
+    model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size": chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
 
     for audio_path in args.pop("audio"):
         audio = load_audio(audio_path)

whisperx/vads/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from whisperx.vads.pyannote import Pyannote
+from whisperx.vads.silero import Silero
+from whisperx.vads.vad import Vad

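With the new whisperx.vads package in place, a VAD instance can also be constructed directly and handed to load_model via vad_model, which, per the note in load_model above, takes priority over vad_method. A hedged sketch follows; the model name and device are placeholders.

import whisperx
from whisperx.vads import Silero

# Sketch: build the VAD yourself and bypass vad_method.
# Silero reads vad_onset and chunk_size from its keyword arguments.
vad = Silero(vad_onset=0.500, chunk_size=30)

model = whisperx.load_model(
    "large-v2",            # placeholder model name
    device="cuda",
    vad_model=vad,         # manually assigned vad_model wins over vad_method
)
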
whisperx/vads/pyannote.py
@@ -1,19 +1,21 @@
 import hashlib
 import os
 import urllib
-from typing import Callable, Optional, Text, Union
+from typing import Callable, Text, Union
+from typing import Optional
 
 import numpy as np
-import pandas as pd
 import torch
 from pyannote.audio import Model
 from pyannote.audio.core.io import AudioFile
 from pyannote.audio.pipelines import VoiceActivityDetection
 from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.core import Annotation, Segment, SlidingWindowFeature
+from pyannote.core import Annotation, SlidingWindowFeature
+from pyannote.core import Segment
 from tqdm import tqdm
 
-from .diarize import Segment as SegmentX
+from whisperx.diarize import Segment as SegmentX
+from whisperx.vads.vad import Vad
 
 # deprecated
 VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

@@ -30,11 +32,11 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non
         model_fp = os.path.abspath(model_fp)  # Ensure the path is absolute
     else:
         model_fp = os.path.abspath(model_fp)  # Ensure any provided path is absolute
 
     # Check if the resolved model file exists
     if not os.path.exists(model_fp):
         raise FileNotFoundError(f"Model file not found at {model_fp}")
 
     if os.path.exists(model_fp) and not os.path.isfile(model_fp):
         raise RuntimeError(f"{model_fp} exists and is not a regular file")
 

@@ -45,7 +47,7 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non
         )
 
     vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
     hyperparameters = {"onset": vad_onset,
                        "offset": vad_offset,
                        "min_duration_on": 0.1,
                        "min_duration_off": 0.1}

@@ -81,21 +83,21 @@ class Binarize:
     Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
     RNN-based Voice Activity Detection", InterSpeech 2015.
 
     Modified by Max Bain to include WhisperX's min-cut operation
     https://arxiv.org/abs/2303.00747
 
     Pyannote-audio
     """
 
     def __init__(
         self,
         onset: float = 0.5,
         offset: Optional[float] = None,
         min_duration_on: float = 0.0,
         min_duration_off: float = 0.0,
         pad_onset: float = 0.0,
         pad_offset: float = 0.0,
         max_duration: float = float('inf')
     ):
 
         super().__init__()

@@ -141,7 +143,7 @@ class Binarize:
            t = start
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        search_after = len(curr_scores) // 2

@@ -151,8 +153,8 @@ class Binarize:
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
-                        curr_scores = curr_scores[min_score_div_idx+1:]
-                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
+                        curr_scores = curr_scores[min_score_div_idx + 1:]
+                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)

@@ -193,11 +195,11 @@ class Binarize:
 
 class VoiceActivitySegmentation(VoiceActivityDetection):
     def __init__(
         self,
         segmentation: PipelineModel = "pyannote/segmentation",
         fscore: bool = False,
         use_auth_token: Union[Text, None] = None,
         **inference_kwargs,
     ):
 
         super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)

@@ -236,72 +238,72 @@ class VoiceActivitySegmentation(VoiceActivityDetection):
         return segmentations
 
 
-def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
-
-    active = Annotation()
-    for k, vad_t in enumerate(vad_arr):
-        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
-        active[region, k] = 1
-
-    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
-        active = active.support(collar=min_duration_off)
-
-    # remove tracks shorter than min_duration_on
-    if min_duration_on > 0:
-        for segment, track in list(active.itertracks()):
-            if segment.duration < min_duration_on:
-                del active[segment, track]
-
-    active = active.for_json()
-    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
-    return active_segs
-
-
-def merge_chunks(
-    segments,
-    chunk_size,
-    onset: float = 0.5,
-    offset: Optional[float] = None,
-):
-    """
-    Merge operation described in paper
-    """
-    curr_end = 0
-    merged_segments = []
-    seg_idxs = []
-    speaker_idxs = []
-
-    assert chunk_size > 0
-    binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
-    segments = binarize(segments)
-    segments_list = []
-    for speech_turn in segments.get_timeline():
-        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
-
-    if len(segments_list) == 0:
-        print("No active speech found in audio")
-        return []
-    # assert segments_list, "segments_list is empty."
-    # Make sur the starting point is the start of the segment.
-    curr_start = segments_list[0].start
-
-    for seg in segments_list:
-        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
-            merged_segments.append({
-                "start": curr_start,
-                "end": curr_end,
-                "segments": seg_idxs,
-            })
-            curr_start = seg.start
-            seg_idxs = []
-            speaker_idxs = []
-        curr_end = seg.end
-        seg_idxs.append((seg.start, seg.end))
-        speaker_idxs.append(seg.speaker)
-    # add final
-    merged_segments.append({
-        "start": curr_start,
-        "end": curr_end,
-        "segments": seg_idxs,
-    })
-    return merged_segments
+class Pyannote(Vad):
+
+    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
+        print(">>Performing voice activity detection using Pyannote...")
+        super().__init__(kwargs['vad_onset'])
+
+        model_dir = torch.hub._get_torch_home()
+        os.makedirs(model_dir, exist_ok=True)
+        if model_fp is None:
+            model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+        if os.path.exists(model_fp) and not os.path.isfile(model_fp):
+            raise RuntimeError(f"{model_fp} exists and is not a regular file")
+
+        if not os.path.isfile(model_fp):
+            with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
+                with tqdm(
+                    total=int(source.info().get("Content-Length")),
+                    ncols=80,
+                    unit="iB",
+                    unit_scale=True,
+                    unit_divisor=1024,
+                ) as loop:
+                    while True:
+                        buffer = source.read(8192)
+                        if not buffer:
+                            break
+
+                        output.write(buffer)
+                        loop.update(len(buffer))
+
+        model_bytes = open(model_fp, "rb").read()
+        if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
+            raise RuntimeError(
+                "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
+            )
+
+        vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
+        hyperparameters = {"onset": kwargs['vad_onset'],
+                           "offset": kwargs['vad_offset'],
+                           "min_duration_on": 0.1,
+                           "min_duration_off": 0.1}
+        self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
+        self.vad_pipeline.instantiate(hyperparameters)
+
+    def __call__(self, audio: AudioFile, **kwargs):
+        return self.vad_pipeline(audio)
+
+    @staticmethod
+    def preprocess_audio(audio):
+        return torch.from_numpy(audio).unsqueeze(0)
+
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float = 0.5,
+                     offset: Optional[float] = None,
+                     ):
+        assert chunk_size > 0
+        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
+        segments = binarize(segments)
+        segments_list = []
+        for speech_turn in segments.get_timeline():
+            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+        if len(segments_list) == 0:
+            print("No active speech found in audio")
+            return []
+        assert segments_list, "segments_list is empty."
+        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)

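The Pyannote wrapper can also be exercised on its own, outside FasterWhisperPipeline, mirroring the dict input used in asr.py. A hedged sketch; the device, token and audio path are placeholders.

import whisperx
from whisperx.audio import SAMPLE_RATE
from whisperx.vads import Pyannote

# Sketch: run the wrapped pyannote VAD directly and merge its output into chunks.
vad = Pyannote("cuda", use_auth_token=None, vad_onset=0.500, vad_offset=0.363)

audio = whisperx.load_audio("audio.wav")        # placeholder path, 16 kHz float32 array
waveform = Pyannote.preprocess_audio(audio)     # torch tensor of shape (1, n_samples)
segments = vad({"waveform": waveform, "sample_rate": SAMPLE_RATE})
chunks = Pyannote.merge_chunks(segments, chunk_size=30, onset=0.500)
# chunks: [{"start": ..., "end": ..., "segments": [(s0, e0), ...]}, ...]
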
whisperx/vads/silero.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from io import IOBase
+from pathlib import Path
+from typing import Mapping, Text
+from typing import Optional
+from typing import Union
+
+import torch
+
+from whisperx.diarize import Segment as SegmentX
+from whisperx.vads.vad import Vad
+
+AudioFile = Union[Text, Path, IOBase, Mapping]
+
+
+class Silero(Vad):
+    # check again default values
+    def __init__(self, **kwargs):
+        print(">>Performing voice activity detection using Silero...")
+        super().__init__(kwargs['vad_onset'])
+
+        self.vad_onset = kwargs['vad_onset']
+        self.chunk_size = kwargs['chunk_size']
+        self.vad_pipeline, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
+                                                      model='silero_vad',
+                                                      force_reload=False,
+                                                      onnx=False,
+                                                      trust_repo=True)
+        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
+
+    def __call__(self, audio: AudioFile, **kwargs):
+        """use silero to get segments of speech"""
+        # Only accept 16000 Hz for now.
+        # Note: Silero models support both 8000 and 16000 Hz. Although other values are not directly supported,
+        # multiples of 16000 (e.g. 32000 or 48000) are cast to 16000 inside of the JIT model!
+        sample_rate = audio["sample_rate"]
+        if sample_rate != 16000:
+            raise ValueError("Only 16000Hz sample rate is allowed")
+
+        timestamps = self.get_speech_timestamps(audio["waveform"],
+                                                model=self.vad_pipeline,
+                                                sampling_rate=sample_rate,
+                                                max_speech_duration_s=self.chunk_size,
+                                                threshold=self.vad_onset
+                                                # min_silence_duration_ms = self.min_duration_off/1000
+                                                # min_speech_duration_ms = self.min_duration_on/1000
+                                                # ...
+                                                # See silero documentation for full option list
+                                                )
+        return [SegmentX(i['start'] / sample_rate, i['end'] / sample_rate, "UNKNOWN") for i in timestamps]
+
+    @staticmethod
+    def preprocess_audio(audio):
+        return audio
+
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float = 0.5,
+                     offset: Optional[float] = None,
+                     ):
+        assert chunk_size > 0
+        return Vad.merge_chunks(segments, chunk_size, onset, offset)

whisperx/vads/vad.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from typing import Optional
+
+import pandas as pd
+from pyannote.core import Annotation, Segment
+
+
+class Vad:
+    def __init__(self, vad_onset):
+        if not (0 < vad_onset < 1):
+            raise ValueError(
+                "vad_onset is a decimal value between 0 and 1."
+            )
+
+    @staticmethod
+    def preprocess_audio(audio):
+        pass
+
+    # keep merge_chunks as static so it can be also used by manually assigned vad_model (see 'load_model')
+    @staticmethod
+    def merge_chunks(segments,
+                     chunk_size,
+                     onset: float,
+                     offset: Optional[float]):
+        """
+        Merge operation described in paper
+        """
+        curr_end = 0
+        merged_segments = []
+        seg_idxs = []
+        speaker_idxs = []
+
+        curr_start = segments[0].start
+        for seg in segments:
+            if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
+                merged_segments.append({
+                    "start": curr_start,
+                    "end": curr_end,
+                    "segments": seg_idxs,
+                })
+                curr_start = seg.start
+                seg_idxs = []
+                speaker_idxs = []
+            curr_end = seg.end
+            seg_idxs.append((seg.start, seg.end))
+            speaker_idxs.append(seg.speaker)
+        # add final
+        merged_segments.append({
+            "start": curr_start,
+            "end": curr_end,
+            "segments": seg_idxs,
+        })
+
+        return merged_segments
+
+    # Unused function
+    @staticmethod
+    def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
+        active = Annotation()
+        for k, vad_t in enumerate(vad_arr):
+            region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
+            active[region, k] = 1
+
+        if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
+            active = active.support(collar=min_duration_off)
+
+        # remove tracks shorter than min_duration_on
+        if min_duration_on > 0:
+            for segment, track in list(active.itertracks()):
+                if segment.duration < min_duration_on:
+                    del active[segment, track]
+
+        active = active.for_json()
+        active_segs = pd.DataFrame([x['segment'] for x in active['content']])
+        return active_segs

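The commit's stated goal, accepting alternative VAD methods, rests on this small base class: any object that subclasses Vad and provides preprocess_audio, __call__ and merge_chunks can be handed to load_model as vad_model and will be routed through the issubclass(..., whisperx.vads.Vad) branch in asr.py above. A hedged sketch of such an extension; the EnergyVad name and the RMS-threshold logic are illustrative only, not part of this commit.

from typing import Optional

import numpy as np

from whisperx.diarize import Segment as SegmentX
from whisperx.vads.vad import Vad


class EnergyVad(Vad):
    """Illustrative third-party VAD: marks 1-second windows above an RMS threshold as speech."""

    def __init__(self, vad_onset=0.5, chunk_size=30, rms_threshold=0.01, **kwargs):
        super().__init__(vad_onset)
        self.chunk_size = chunk_size
        self.rms_threshold = rms_threshold

    @staticmethod
    def preprocess_audio(audio):
        # FasterWhisperPipeline hands over the raw 16 kHz numpy array; keep it as-is.
        return audio

    def __call__(self, audio, **kwargs):
        waveform, sr = audio["waveform"], audio["sample_rate"]
        window = sr  # 1-second windows
        segments = []
        for i in range(0, len(waveform), window):
            frame = waveform[i:i + window]
            if np.sqrt(np.mean(frame ** 2)) > self.rms_threshold:
                segments.append(SegmentX(i / sr, (i + len(frame)) / sr, "UNKNOWN"))
        return segments

    @staticmethod
    def merge_chunks(segments, chunk_size, onset: float = 0.5, offset: Optional[float] = None):
        assert chunk_size > 0
        return Vad.merge_chunks(segments, chunk_size, onset, offset)

Passing vad_model=EnergyVad(vad_onset=0.5) to load_model would then route preprocessing and chunk merging through this class, exactly as the pipeline does for Silero and Pyannote.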