mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Fix: Allow VAD options to be configurable by correctly passing them down to FasterWhisperPipeline.
This commit is contained in:
@ -111,6 +111,7 @@ def load_model(whisper_arch,
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
language=language,
|
language=language,
|
||||||
suppress_numerals=suppress_numerals,
|
suppress_numerals=suppress_numerals,
|
||||||
|
vad_params=default_vad_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
class WhisperModel(faster_whisper.WhisperModel):
|
class WhisperModel(faster_whisper.WhisperModel):
|
||||||
@ -186,6 +187,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
vad,
|
vad,
|
||||||
|
vad_params: dict,
|
||||||
options : NamedTuple,
|
options : NamedTuple,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
device: Union[int, str, "torch.device"] = -1,
|
device: Union[int, str, "torch.device"] = -1,
|
||||||
@ -218,6 +220,7 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
super(Pipeline, self).__init__()
|
super(Pipeline, self).__init__()
|
||||||
self.vad_model = vad
|
self.vad_model = vad
|
||||||
|
self._vad_params = vad_params
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, **kwargs):
|
||||||
preprocess_kwargs = {}
|
preprocess_kwargs = {}
|
||||||
@ -266,7 +269,12 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
yield {'inputs': audio[f1:f2]}
|
yield {'inputs': audio[f1:f2]}
|
||||||
|
|
||||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||||
vad_segments = merge_chunks(vad_segments, chunk_size)
|
vad_segments = merge_chunks(
|
||||||
|
vad_segments,
|
||||||
|
chunk_size,
|
||||||
|
onset=self._vad_params["vad_onset"],
|
||||||
|
offset=self._vad_params["vad_offset"],
|
||||||
|
)
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
language = language or self.detect_language(audio)
|
language = language or self.detect_language(audio)
|
||||||
task = task or "transcribe"
|
task = task or "transcribe"
|
||||||
|
@ -260,7 +260,12 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
|
|||||||
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
active_segs = pd.DataFrame([x['segment'] for x in active['content']])
|
||||||
return active_segs
|
return active_segs
|
||||||
|
|
||||||
def merge_chunks(segments, chunk_size):
|
def merge_chunks(
|
||||||
|
segments,
|
||||||
|
chunk_size,
|
||||||
|
onset: float = 0.5,
|
||||||
|
offset: Optional[float] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Merge operation described in paper
|
Merge operation described in paper
|
||||||
"""
|
"""
|
||||||
@ -270,7 +275,7 @@ def merge_chunks(segments, chunk_size):
|
|||||||
speaker_idxs = []
|
speaker_idxs = []
|
||||||
|
|
||||||
assert chunk_size > 0
|
assert chunk_size > 0
|
||||||
binarize = Binarize(max_duration=chunk_size)
|
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
||||||
segments = binarize(segments)
|
segments = binarize(segments)
|
||||||
segments_list = []
|
segments_list = []
|
||||||
for speech_turn in segments.get_timeline():
|
for speech_turn in segments.get_timeline():
|
||||||
|
Reference in New Issue
Block a user