mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
refactor: remove namespace for consistency
This commit is contained in:
@ -13,8 +13,8 @@ from transformers import Pipeline
|
|||||||
from transformers.pipelines.pt_utils import PipelineIterator
|
from transformers.pipelines.pt_utils import PipelineIterator
|
||||||
|
|
||||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||||
import whisperx.vads
|
|
||||||
from .types import SingleSegment, TranscriptionResult
|
from .types import SingleSegment, TranscriptionResult
|
||||||
|
from .vads import Vad, Silero, Pyannote
|
||||||
|
|
||||||
def find_numeral_symbol_tokens(tokenizer):
|
def find_numeral_symbol_tokens(tokenizer):
|
||||||
numeral_symbol_tokens = []
|
numeral_symbol_tokens = []
|
||||||
@ -209,12 +209,12 @@ class FasterWhisperPipeline(Pipeline):
|
|||||||
|
|
||||||
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
# Pre-process audio and merge chunks as defined by the respective VAD child class
|
||||||
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
|
||||||
if issubclass(type(self.vad_model), whisperx.vads.Vad):
|
if issubclass(type(self.vad_model), Vad):
|
||||||
waveform = self.vad_model.preprocess_audio(audio)
|
waveform = self.vad_model.preprocess_audio(audio)
|
||||||
merge_chunks = self.vad_model.merge_chunks
|
merge_chunks = self.vad_model.merge_chunks
|
||||||
else:
|
else:
|
||||||
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
|
waveform = Pyannote.preprocess_audio(audio)
|
||||||
merge_chunks = whisperx.vads.Pyannote.merge_chunks
|
merge_chunks = Pyannote.merge_chunks
|
||||||
|
|
||||||
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
|
||||||
vad_segments = merge_chunks(
|
vad_segments = merge_chunks(
|
||||||
@ -398,9 +398,9 @@ def load_model(
|
|||||||
vad_model = vad_model
|
vad_model = vad_model
|
||||||
else:
|
else:
|
||||||
if vad_method == "silero":
|
if vad_method == "silero":
|
||||||
vad_model = whisperx.vads.Silero(**default_vad_options)
|
vad_model = Silero(**default_vad_options)
|
||||||
elif vad_method == "pyannote":
|
elif vad_method == "pyannote":
|
||||||
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid vad_method: {vad_method}")
|
raise ValueError(f"Invalid vad_method: {vad_method}")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user