mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Accept alternative VAD methods. Extend to use Silero VAD.
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, NamedTuple, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import ctranslate2
|
||||
import faster_whisper
|
||||
@ -12,9 +11,8 @@ from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
import whisperx.vads
|
||||
from .types import SingleSegment, TranscriptionResult
|
||||
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
|
||||
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
numeral_symbol_tokens = []
|
||||
@ -105,7 +103,7 @@ class FasterWhisperPipeline(Pipeline):
|
||||
def __init__(
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad: VoiceActivitySegmentation,
|
||||
vad,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
@ -207,7 +205,16 @@ class FasterWhisperPipeline(Pipeline):
|
||||
# print(f2-f1)
|
||||
yield {'inputs': audio[f1:f2]}
|
||||
|
||||
vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
|
||||
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||
if issubclass(type(self.vad_model), whisperx.vads.Vad):
|
||||
waveform = self.vad_model.preprocess_audio(audio)
|
||||
merge_chunks = self.vad_model.merge_chunks
|
||||
else:
|
||||
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
|
||||
merge_chunks = whisperx.vads.Pyannote.merge_chunks
|
||||
|
||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||
vad_segments = merge_chunks(
|
||||
vad_segments,
|
||||
chunk_size,
|
||||
@ -295,7 +302,8 @@ def load_model(
|
||||
compute_type="float16",
|
||||
asr_options: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
vad_model: Optional[VoiceActivitySegmentation] = None,
|
||||
vad_model = None,
|
||||
vad_method = None,
|
||||
vad_options: Optional[dict] = None,
|
||||
model: Optional[WhisperModel] = None,
|
||||
task="transcribe",
|
||||
@ -308,6 +316,7 @@ def load_model(
|
||||
whisper_arch - The name of the Whisper model to load.
|
||||
device - The device to load the model on.
|
||||
compute_type - The compute type to use for the model.
|
||||
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
|
||||
options - A dictionary of options to use for the model.
|
||||
language - The language of the model. (use English for now)
|
||||
model - The WhisperModel instance to use.
|
||||
@ -373,6 +382,7 @@ def load_model(
|
||||
default_asr_options = TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
}
|
||||
@ -380,10 +390,16 @@ def load_model(
|
||||
if vad_options is not None:
|
||||
default_vad_options.update(vad_options)
|
||||
|
||||
# Note: manually assigned vad_model has higher priority than vad_method!
|
||||
if vad_model is not None:
|
||||
print("Use manually assigned vad_model. vad_method is ignored.")
|
||||
vad_model = vad_model
|
||||
else:
|
||||
vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
match vad_method:
|
||||
case "silero":
|
||||
vad_model = whisperx.vads.Silero(**default_vad_options)
|
||||
case "pyannote" | _:
|
||||
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||
|
||||
return FasterWhisperPipeline(
|
||||
model=model,
|
||||
@ -393,4 +409,4 @@ def load_model(
|
||||
language=language,
|
||||
suppress_numerals=suppress_numerals,
|
||||
vad_params=default_vad_options,
|
||||
)
|
||||
)
|
||||
|
Reference in New Issue
Block a user