From 79801167aca51ea01a735b276dca145d029a6b1a Mon Sep 17 00:00:00 2001 From: Andrew Bettke Date: Thu, 5 Oct 2023 10:06:34 -0400 Subject: [PATCH] Fix: Allow vad options to be configurable by correctly passing down to FasterWhisperPipeline. --- whisperx/asr.py | 10 +++++++++- whisperx/vad.py | 9 +++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index b0dc824..27de6db 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -111,6 +111,7 @@ def load_model(whisper_arch, tokenizer=tokenizer, language=language, suppress_numerals=suppress_numerals, + vad_params=default_vad_options, ) class WhisperModel(faster_whisper.WhisperModel): @@ -186,6 +187,7 @@ class FasterWhisperPipeline(Pipeline): self, model, vad, + vad_params: dict, options : NamedTuple, tokenizer=None, device: Union[int, str, "torch.device"] = -1, @@ -218,6 +220,7 @@ class FasterWhisperPipeline(Pipeline): super(Pipeline, self).__init__() self.vad_model = vad + self._vad_params = vad_params def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} @@ -266,7 +269,12 @@ class FasterWhisperPipeline(Pipeline): yield {'inputs': audio[f1:f2]} vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) - vad_segments = merge_chunks(vad_segments, chunk_size) + vad_segments = merge_chunks( + vad_segments, + chunk_size, + onset=self._vad_params["vad_onset"], + offset=self._vad_params["vad_offset"], + ) if self.tokenizer is None: language = language or self.detect_language(audio) task = task or "transcribe" diff --git a/whisperx/vad.py b/whisperx/vad.py index 15a9e5e..dac0365 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -260,7 +260,12 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_ active_segs = pd.DataFrame([x['segment'] for x in active['content']]) return active_segs -def merge_chunks(segments, chunk_size): +def merge_chunks( + segments, + chunk_size, + onset: float = 0.5, + offset: Optional[float] = None, +): """ Merge operation described in paper """ @@ -270,7 +275,7 @@ def merge_chunks(segments, chunk_size): speaker_idxs = [] assert chunk_size > 0 - binarize = Binarize(max_duration=chunk_size) + binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset) segments = binarize(segments) segments_list = [] for speech_turn in segments.get_timeline():