diff --git a/whisperx/asr.py b/whisperx/asr.py index 43976f2..52e7972 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -13,8 +13,8 @@ from transformers import Pipeline from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram -import whisperx.vads from .types import SingleSegment, TranscriptionResult +from .vads import Vad, Silero, Pyannote def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] @@ -209,12 +209,12 @@ class FasterWhisperPipeline(Pipeline): # Pre-process audio and merge chunks as defined by the respective VAD child class # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit - if issubclass(type(self.vad_model), whisperx.vads.Vad): + if issubclass(type(self.vad_model), Vad): waveform = self.vad_model.preprocess_audio(audio) merge_chunks = self.vad_model.merge_chunks else: - waveform = whisperx.vads.Pyannote.preprocess_audio(audio) - merge_chunks = whisperx.vads.Pyannote.merge_chunks + waveform = Pyannote.preprocess_audio(audio) + merge_chunks = Pyannote.merge_chunks vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks( @@ -398,9 +398,9 @@ def load_model( vad_model = vad_model else: if vad_method == "silero": - vad_model = whisperx.vads.Silero(**default_vad_options) + vad_model = Silero(**default_vad_options) elif vad_method == "pyannote": - vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) + vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) else: raise ValueError(f"Invalid vad_method: {vad_method}")