refactor: simplify imports for better type inference

Author: Barabazs
Date:   2025-01-05 10:58:59 +01:00
Parent: c60594fa3b
Commit: 0f7f9f9f83


@@ -6,6 +6,9 @@ import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import (TranscriptionOptions,
+                                       get_ctranslate2_storage)
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
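This import change is what enables the shorter annotations in the hunks below: binding Tokenizer and TranscriptionOptions directly at import time lets type checkers and IDE completion resolve the concrete classes instead of chasing an attribute chain through the module. A minimal sketch of the two annotation styles (the function names here are illustrative, not from this commit):

import faster_whisper.tokenizer
from faster_whisper.tokenizer import Tokenizer

# Before: attribute-chain annotation; tools must resolve the full
# module path before they can see the class.
def transcribe_before(tokenizer: faster_whisper.tokenizer.Tokenizer) -> None:
    ...

# After: the class is bound directly in the local namespace, so type
# inference and completion work on the name itself.
def transcribe_after(tokenizer: Tokenizer) -> None:
    ...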
@@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''
-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
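For context, this batched path derives batch_size from the leading axis of features. A hedged sketch of assembling such a batch from per-segment mel spectrograms (the shapes are illustrative of Whisper-style 80-mel, 30-second inputs, not taken from this commit):

import numpy as np

# Hypothetical per-segment mel features, each (n_mels, n_frames).
segment_a = np.zeros((80, 3000), dtype=np.float32)
segment_b = np.zeros((80, 3000), dtype=np.float32)

# Stack into (batch_size, n_mels, n_frames) so that
# features.shape[0] yields the batch size, as the method expects.
features = np.stack([segment_a, segment_b])
assert features.shape[0] == 2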
@@ -81,7 +90,7 @@ class WhisperModel(faster_whisper.WhisperModel):
         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
-        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+        features = get_ctranslate2_storage(features)
         return self.model.encode(features, to_cpu=to_cpu)
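The unsqueeze guard above normalizes a single 2-D segment to a batch of one before handing it to CTranslate2. A minimal numpy illustration of what that branch does:

import numpy as np

features = np.zeros((80, 3000), dtype=np.float32)  # single segment, 2-D
if len(features.shape) == 2:
    features = np.expand_dims(features, 0)          # -> (1, 80, 3000)
assert features.shape == (1, 80, 3000)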
@@ -193,17 +202,23 @@ class FasterWhisperPipeline(Pipeline):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
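The branch above lazily builds a Tokenizer and only rebuilds it when the requested task or language differs from the cached one. A standalone sketch of that cache-or-rebuild pattern (get_tokenizer and its model parameter are illustrative, not part of this commit; the task and language_code attributes are the ones the diff itself reads):

from typing import Optional

from faster_whisper.tokenizer import Tokenizer

def get_tokenizer(model, cached: Optional[Tokenizer], task: str, language: str) -> Tokenizer:
    # Reuse the cached tokenizer when task and language both match;
    # otherwise rebuild, mirroring the pipeline branch above.
    if cached is not None and cached.task == task and cached.language_code == language:
        return cached
    return Tokenizer(
        model.hf_tokenizer,
        model.model.is_multilingual,
        task=task,
        language=language,
    )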
@@ -297,7 +312,7 @@ def load_model(whisper_arch,
                                            local_files_only=local_files_only,
                                            cpu_threads=threads)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will first be detected for each audio file (increases inference time).")
         tokenizer = None
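As that print statement warns, omitting language defers detection to every audio file, which adds inference time. A hedged usage sketch, assuming the top-level whisperx.load_model wrapper with typical arguments (the keyword names beyond language are assumptions, not shown in this diff):

import whisperx

# Passing a language up front builds the tokenizer once at load time.
model = whisperx.load_model("large-v2", device="cuda", language="en")

# Omitting it leaves tokenizer=None, so the language is detected per
# audio file during transcription (slower, as the message above notes).
model_autodetect = whisperx.load_model("large-v2", device="cuda")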
@@ -338,7 +353,7 @@ def load_model(whisper_arch,
         suppress_numerals = default_asr_options["suppress_numerals"]
         del default_asr_options["suppress_numerals"]
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = TranscriptionOptions(**default_asr_options)
     default_vad_options = {
         "vad_onset": 0.500,