mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Compare commits
1 Commits
65b981c025
...
improve-co
Author | SHA1 | Date | |
---|---|---|---|
88939b9e8a |
136
whisperx/asr.py
136
whisperx/asr.py
@ -16,6 +16,9 @@ from .types import SingleSegment, TranscriptionResult
|
||||
from .vads import Vad, Silero, Pyannote
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
|
||||
"""
|
||||
Finds tokens that represent numeral and symbols.
|
||||
"""
|
||||
numeral_symbol_tokens = []
|
||||
for i in range(tokenizer.eot):
|
||||
token = tokenizer.decode([i]).removeprefix(" ")
|
||||
@ -25,10 +28,10 @@ def find_numeral_symbol_tokens(tokenizer):
|
||||
return numeral_symbol_tokens
|
||||
|
||||
class WhisperModel(faster_whisper.WhisperModel):
|
||||
'''
|
||||
FasterWhisperModel provides batched inference for faster-whisper.
|
||||
Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
|
||||
'''
|
||||
"""
|
||||
Wrapper around faster-whisper's WhisperModel to enable batched inference.
|
||||
Currently, it only supports non-timestamp mode and a fixed prompt for all samples in a batch.
|
||||
"""
|
||||
|
||||
def generate_segment_batched(
|
||||
self,
|
||||
@ -37,13 +40,28 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
options: TranscriptionOptions,
|
||||
encoder_output=None,
|
||||
):
|
||||
"""
|
||||
Generates transcription for a batch of audio segments.
|
||||
|
||||
Args:
|
||||
features: The input audio features.
|
||||
tokenizer: The tokenizer used to decode the generated tokens.
|
||||
options: Transcription options.
|
||||
encoder_output: Output from the encoder model.
|
||||
|
||||
Returns:
|
||||
The decoded transcription text.
|
||||
"""
|
||||
batch_size = features.shape[0]
|
||||
# Initialize tokens and prompt for the generation process.
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
# Check if an initial prompt is provided and handle it.
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
# Prepare the prompt for the current batch.
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
tokenizer,
|
||||
@ -51,118 +69,58 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
|
||||
# Encode the features to obtain the encoder output.
|
||||
encoder_output = self.encode(features)
|
||||
|
||||
# Determine the maximum initial timestamp index based on the options.
|
||||
max_initial_timestamp_index = int(
|
||||
round(options.max_initial_timestamp / self.time_precision)
|
||||
)
|
||||
|
||||
# Generate the transcription result for the batch.
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
beam_size=options.beam_size,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
# Extract the token sequences from the result.
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
||||
# Define an inner function to decode the tokens for each batch.
|
||||
def decode_batch(tokens: List[List[int]]) -> str:
|
||||
res = []
|
||||
for tk in tokens:
|
||||
res.append([token for token in tk if token < tokenizer.eot])
|
||||
# text_tokens = [token for token in tokens if token < self.eot]
|
||||
return tokenizer.tokenizer.decode_batch(res)
|
||||
|
||||
# Decode the tokens to get the transcription text.
|
||||
text = decode_batch(tokens_batch)
|
||||
|
||||
return text
|
||||
|
||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved
|
||||
# to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
Encodes the audio features using the CTranslate2 storage.
|
||||
|
||||
When the model is running on multiple GPUs, the encoder output should be moved
|
||||
to the CPU since we don't know which GPU will handle the next job.
|
||||
"""
|
||||
# When the model is running on multiple GPUs, the encoder output should be moved to the CPU.
|
||||
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
||||
# unsqueeze if batch size = 1
|
||||
# If the batch size is 1, unsqueeze the features to ensure it is a 3D array.
|
||||
if len(features.shape) == 2:
|
||||
features = np.expand_dims(features, 0)
|
||||
features = get_ctranslate2_storage(features)
|
||||
|
||||
# call the model
|
||||
return self.model.encode(features, to_cpu=to_cpu)
|
||||
|
||||
class FasterWhisperPipeline(Pipeline):
|
||||
"""
|
||||
Huggingface Pipeline wrapper for FasterWhisperModel.
|
||||
"""
|
||||
# TODO:
|
||||
# - add support for timestamp mode
|
||||
# - add support for custom inference kwargs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: WhisperModel,
|
||||
vad,
|
||||
vad_params: dict,
|
||||
options: TranscriptionOptions,
|
||||
tokenizer: Optional[Tokenizer] = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
framework="pt",
|
||||
language: Optional[str] = None,
|
||||
suppress_numerals: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.options = options
|
||||
self.preset_language = language
|
||||
self.suppress_numerals = suppress_numerals
|
||||
self._batch_size = kwargs.pop("batch_size", None)
|
||||
self._num_workers = 1
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
self.call_count = 0
|
||||
self.framework = framework
|
||||
if self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
elif isinstance(device, str):
|
||||
self.device = torch.device(device)
|
||||
elif device < 0:
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
super(Pipeline, self).__init__()
|
||||
self.vad_model = vad
|
||||
self._vad_params = vad_params
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "tokenizer" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, audio):
|
||||
audio = audio['inputs']
|
||||
model_n_mels = self.model.feat_kwargs.get("feature_size")
|
||||
features = log_mel_spectrogram(
|
||||
audio,
|
||||
n_mels=model_n_mels if model_n_mels is not None else 80,
|
||||
padding=N_SAMPLES - audio.shape[0],
|
||||
)
|
||||
return {'inputs': features}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
|
||||
return {'text': outputs}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
inputs,
|
||||
|
Reference in New Issue
Block a user