mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
refactor: simplify imports for better type inference
This commit is contained in:
@ -6,6 +6,9 @@ import ctranslate2
|
||||
import faster_whisper
|
||||
import numpy as np
|
||||
import torch
|
||||
from faster_whisper.tokenizer import Tokenizer
|
||||
from faster_whisper.transcribe import (TranscriptionOptions,
|
||||
get_ctranslate2_storage)
|
||||
from transformers import Pipeline
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
|
||||
def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
features: np.ndarray,
|
||||
tokenizer: Tokenizer,
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
batch_size = features.shape[0]
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
@ -81,7 +90,7 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
# unsqueeze if batch size = 1
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = faster_whisper.transcribe.get_ctranslate2_storage(features)
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
@ -193,17 +202,23 @@ class FasterWhisperPipeline(Pipeline):
|
||||
if self.tokenizer is None:
|
||||
language = language or self.detect_language(audio)
|
||||
task = task or "transcribe"
|
||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual, task=task,
|
||||
language=language)
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
language = language or self.tokenizer.language_code
|
||||
task = task or self.tokenizer.task
|
||||
if task != self.tokenizer.task or language != self.tokenizer.language_code:
|
||||
self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual, task=task,
|
||||
language=language)
|
||||
|
||||
self.tokenizer = Tokenizer(
|
||||
self.model.hf_tokenizer,
|
||||
self.model.model.is_multilingual,
|
||||
task=task,
|
||||
language=language,
|
||||
)
|
||||
|
||||
if self.suppress_numerals:
|
||||
previous_suppress_tokens = self.options.suppress_tokens
|
||||
numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
|
||||
@ -297,7 +312,7 @@ def load_model(whisper_arch,
|
||||
local_files_only=local_files_only,
|
||||
cpu_threads=threads)
|
||||
if language is not None:
|
||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
|
||||
else:
|
||||
print("No language specified, language will be first be detected for each audio file (increases inference time).")
|
||||
tokenizer = None
|
||||
@ -338,7 +353,7 @@ def load_model(whisper_arch,
|
||||
suppress_numerals = default_asr_options["suppress_numerals"]
|
||||
del default_asr_options["suppress_numerals"]
|
||||
|
||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||
default_asr_options = TranscriptionOptions(**default_asr_options)
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
|
Reference in New Issue
Block a user