refactor: add type hints

This commit is contained in:
Barabazs
2025-01-05 11:26:18 +01:00
parent 0f7f9f9f83
commit 9a8967f27e
6 changed files with 111 additions and 57 deletions

View File

@ -1,20 +1,20 @@
import os
import warnings
from typing import List, Union, Optional, NamedTuple
from typing import List, NamedTuple, Optional, Union
import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import (TranscriptionOptions,
get_ctranslate2_storage)
from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
from .types import SingleSegment, TranscriptionResult
from .vad import VoiceActivitySegmentation, load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
@ -103,17 +103,17 @@ class FasterWhisperPipeline(Pipeline):
# - add support for custom inference kwargs
def __init__(
self,
model,
vad,
vad_params: dict,
options : NamedTuple,
tokenizer=None,
device: Union[int, str, "torch.device"] = -1,
framework = "pt",
language : Optional[str] = None,
suppress_numerals: bool = False,
**kwargs
self,
model: WhisperModel,
vad: VoiceActivitySegmentation,
vad_params: dict,
options: NamedTuple,
tokenizer: Optional[Tokenizer] = None,
device: Union[int, str, "torch.device"] = -1,
framework="pt",
language: Optional[str] = None,
suppress_numerals: bool = False,
**kwargs,
):
self.model = model
self.tokenizer = tokenizer
@ -165,7 +165,13 @@ class FasterWhisperPipeline(Pipeline):
return model_outputs
def get_iterator(
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
self,
inputs,
num_workers: int,
batch_size: int,
preprocess_params: dict,
forward_params: dict,
postprocess_params: dict,
):
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
if "TOKENIZERS_PARALLELISM" not in os.environ:
@ -180,7 +186,16 @@ class FasterWhisperPipeline(Pipeline):
return final_iterator
def transcribe(
self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False, verbose=False
self,
audio: Union[str, np.ndarray],
batch_size: Optional[int] = None,
num_workers=0,
language: Optional[str] = None,
task: Optional[str] = None,
chunk_size=30,
print_progress=False,
combined_progress=False,
verbose=False,
) -> TranscriptionResult:
if isinstance(audio, str):
audio = load_audio(audio)
@ -258,8 +273,7 @@ class FasterWhisperPipeline(Pipeline):
return {"segments": segments, "language": language}
def detect_language(self, audio: np.ndarray):
def detect_language(self, audio: np.ndarray) -> str:
if audio.shape[0] < N_SAMPLES:
print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
model_n_mels = self.model.feat_kwargs.get("feature_size")
@ -273,33 +287,36 @@ class FasterWhisperPipeline(Pipeline):
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language
def load_model(whisper_arch,
device,
device_index=0,
compute_type="float16",
asr_options=None,
language : Optional[str] = None,
vad_model=None,
vad_options=None,
model : Optional[WhisperModel] = None,
task="transcribe",
download_root=None,
local_files_only=False,
threads=4):
'''Load a Whisper model for inference.
def load_model(
whisper_arch: str,
device: str,
device_index=0,
compute_type="float16",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[VoiceActivitySegmentation] = None,
vad_options: Optional[dict] = None,
model: Optional[WhisperModel] = None,
task="transcribe",
download_root: Optional[str] = None,
local_files_only=False,
threads=4,
) -> FasterWhisperPipeline:
"""Load a Whisper model for inference.
Args:
whisper_arch: str - The name of the Whisper model to load.
device: str - The device to load the model on.
compute_type: str - The compute type to use for the model.
options: dict - A dictionary of options to use for the model.
language: str - The language of the model. (use English for now)
model: Optional[WhisperModel] - The WhisperModel instance to use.
download_root: Optional[str] - The root directory to download the model to.
local_files_only: bool - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
options - A dictionary of options to use for the model.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
Returns:
A Whisper pipeline.
'''
"""
if whisper_arch.endswith(".en"):
language = "en"