Merge branch 'main' into main

This commit is contained in:
Max Bain
2025-01-13 10:09:20 +00:00
committed by GitHub
6 changed files with 52 additions and 31 deletions

View File

@ -1,6 +1,5 @@
import os
import warnings
from typing import List, NamedTuple, Optional, Union
from typing import List, Optional, Union
from dataclasses import replace
import ctranslate2
@ -13,8 +12,8 @@ from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
import whisperx.vads
from .types import SingleSegment, TranscriptionResult
from .vads import Vad, Silero, Pyannote
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@ -209,12 +208,12 @@ class FasterWhisperPipeline(Pipeline):
# Pre-process audio and merge chunks as defined by the respective VAD child class
# In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
if issubclass(type(self.vad_model), whisperx.vads.Vad):
if issubclass(type(self.vad_model), Vad):
waveform = self.vad_model.preprocess_audio(audio)
merge_chunks = self.vad_model.merge_chunks
else:
waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
merge_chunks = whisperx.vads.Pyannote.merge_chunks
waveform = Pyannote.preprocess_audio(audio)
merge_chunks = Pyannote.merge_chunks
vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
vad_segments = merge_chunks(
@ -304,8 +303,8 @@ def load_model(
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model = None,
vad_method: str = "pyannote",
vad_model: Optional[Vad]= None,
vad_method: Optional[str] = "pyannote",
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
@ -318,7 +317,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
vad_method: str - The vad method to use. vad_model has higher priority if is not None.
vad_method - The vad method to use. vad_model has higher priority if is not None.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
@ -398,9 +397,9 @@ def load_model(
vad_model = vad_model
else:
if vad_method == "silero":
vad_model = whisperx.vads.Silero(**default_vad_options)
vad_model = Silero(**default_vad_options)
elif vad_method == "pyannote":
vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")