refactor: simplify imports for better type inference

This commit is contained in:
Barabazs
2025-01-05 10:58:59 +01:00
parent c60594fa3b
commit 0f7f9f9f83

View File

@ -6,6 +6,9 @@ import ctranslate2
import faster_whisper
import numpy as np
import torch
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import (TranscriptionOptions,
get_ctranslate2_storage)
from transformers import Pipeline
from transformers.pipelines.pt_utils import PipelineIterator
@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
'''
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None): def generate_segment_batched(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output=None,
):
batch_size = features.shape[0]
all_tokens = []
prompt_reset_since = 0
@ -81,7 +90,7 @@ class WhisperModel(faster_whisper.WhisperModel):
# unsqueeze if batch size = 1
if len(features.shape) == 2:
features = np.expand_dims(features, 0)
features = faster_whisper.transcribe.get_ctranslate2_storage(features) features = get_ctranslate2_storage(features)
return self.model.encode(features, to_cpu=to_cpu)
@ -193,17 +202,23 @@ class FasterWhisperPipeline(Pipeline):
if self.tokenizer is None:
language = language or self.detect_language(audio)
task = task or "transcribe"
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.tokenizer = Tokenizer(
self.model.model.is_multilingual, task=task, self.model.hf_tokenizer,
language=language) self.model.model.is_multilingual,
task=task,
language=language,
)
else:
language = language or self.tokenizer.language_code
task = task or self.tokenizer.task
if task != self.tokenizer.task or language != self.tokenizer.language_code:
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.tokenizer = Tokenizer(
self.model.model.is_multilingual, task=task, self.model.hf_tokenizer,
language=language) self.model.model.is_multilingual,
task=task,
language=language,
)
if self.suppress_numerals:
previous_suppress_tokens = self.options.suppress_tokens
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@ -297,7 +312,7 @@ def load_model(whisper_arch,
local_files_only=local_files_only,
cpu_threads=threads)
if language is not None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
print("No language specified, language will be first be detected for each audio file (increases inference time).")
tokenizer = None
@ -338,7 +353,7 @@ def load_model(whisper_arch,
suppress_numerals = default_asr_options["suppress_numerals"]
del default_asr_options["suppress_numerals"]
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) default_asr_options = TranscriptionOptions(**default_asr_options)
default_vad_options = {
"vad_onset": 0.500,