Mirror of https://github.com/m-bain/whisperX.git, synced 2025-07-01 18:17:27 -04:00
refactor: simplify imports for better type inference
@@ -6,6 +6,9 @@ import ctranslate2
 import faster_whisper
 import numpy as np
 import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import (TranscriptionOptions,
+                                       get_ctranslate2_storage)
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 
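Note on the change above: importing Tokenizer, TranscriptionOptions, and get_ctranslate2_storage directly binds each name to a concrete class or function, which static analyzers can resolve without chasing attribute lookups through the faster_whisper module object; this is the "better type inference" the commit message refers to. A minimal sketch of the effect (the helper function is illustrative, not part of the commit):

    from faster_whisper.transcribe import TranscriptionOptions

    def make_options(raw: dict) -> TranscriptionOptions:
        # A checker such as mypy now sees the concrete return type;
        # faster_whisper.transcribe.TranscriptionOptions(...) reaches the
        # same object, but only via a dynamic module attribute lookup.
        return TranscriptionOptions(**raw)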
@@ -28,7 +31,13 @@ class WhisperModel(faster_whisper.WhisperModel):
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
     '''
 
-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
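The exploded signature is purely cosmetic: one parameter per line with a trailing comma, so adding a parameter later touches a single diff line, and the annotations now use the directly imported Tokenizer and TranscriptionOptions names. A quick way to confirm such a reflow changes nothing, using a stand-in function rather than the real method:

    import inspect
    import numpy as np

    def sketch(
        features: np.ndarray,
        encoder_output=None,
    ):
        ...

    # The multi-line layout does not alter the callable's signature:
    print(inspect.signature(sketch))  # (features: numpy.ndarray, encoder_output=None)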
@@ -81,7 +90,7 @@ class WhisperModel(faster_whisper.WhisperModel):
         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
-        features = faster_whisper.transcribe.get_ctranslate2_storage(features)
+        features = get_ctranslate2_storage(features)
 
         return self.model.encode(features, to_cpu=to_cpu)
 
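For readers following the encode path: the guard keeps single-sample calls working by adding a batch axis, so get_ctranslate2_storage always receives a 3-D array. The shape handling in isolation (the 80 x 3000 log-mel layout is the usual Whisper feature shape, assumed here for illustration):

    import numpy as np

    features = np.zeros((80, 3000), dtype=np.float32)  # one log-mel spectrogram
    if len(features.shape) == 2:
        features = np.expand_dims(features, 0)  # unsqueeze to (1, 80, 3000)
    print(features.shape)  # (1, 80, 3000)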
@@ -193,16 +202,22 @@ class FasterWhisperPipeline(Pipeline):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )
 
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
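The branch above amounts to a one-slot cache: the pipeline keeps a single Tokenizer and rebuilds it only when the requested (task, language) pair differs from the cached one. The same pattern in miniature, with a hypothetical Spec class standing in for the tokenizer so the sketch stays self-contained:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Spec:  # stand-in for the cached tokenizer state
        task: str
        language: str

    cached: Optional[Spec] = None

    def get_spec(task: str, language: str) -> Spec:
        global cached
        # Rebuild only when the requested pair no longer matches the cache.
        if cached is None or (task, language) != (cached.task, cached.language):
            cached = Spec(task=task, language=language)
        return cached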
@@ -297,7 +312,7 @@ def load_model(whisper_arch,
                          local_files_only=local_files_only,
                          cpu_threads=threads)
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
     else:
         print("No language specified, language will be first be detected for each audio file (increases inference time).")
         tokenizer = None
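Unchanged here, but worth noting: the printed message means that when language is omitted, load_model builds the pipeline with tokenizer=None and detection runs once per audio file. A hedged usage sketch, assuming the whisperx package layout at the time of this commit:

    import whisperx

    # Passing the language up front builds the Tokenizer once inside
    # load_model and skips per-file language detection.
    model = whisperx.load_model("large-v2", device="cuda", language="en")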
@@ -338,7 +353,7 @@ def load_model(whisper_arch,
         suppress_numerals = default_asr_options["suppress_numerals"]
         del default_asr_options["suppress_numerals"]
 
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = TranscriptionOptions(**default_asr_options)
 
     default_vad_options = {
         "vad_onset": 0.500,
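The surrounding (unchanged) lines show why the del is needed: suppress_numerals is a whisperX-level flag, and TranscriptionOptions would reject an unknown keyword, so the extension key is stripped before the dict is expanded. The pattern in miniature, with a hypothetical Options dataclass standing in:

    from dataclasses import dataclass

    @dataclass
    class Options:  # stand-in for faster_whisper's TranscriptionOptions
        beam_size: int

    raw = {"beam_size": 5, "suppress_numerals": True}
    suppress_numerals = raw.pop("suppress_numerals")  # pop = read + del in one step
    options = Options(**raw)  # remaining keys must match the dataclass fields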