mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
--suppress_numerals option, ensures non-numerical words, for wav2vec2 alignment
This commit is contained in:
@ -12,6 +12,14 @@ from transformers.pipelines.pt_utils import PipelineIterator
|
||||
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
||||
from .vad import load_vad_model, merge_chunks
|
||||
|
||||
def find_numeral_symbol_tokens(tokenizer):
    """Return the ids of tokens made up solely of digits and symbols.

    Every id in ``range(tokenizer.eot)`` is decoded on its own, a single
    leading space is dropped from the decoded text, and the id is kept when
    the remaining text is non-empty and every character is a digit or one of
    the symbols ``@#%&*+=_$:-.,?!``.

    Args:
        tokenizer: Object exposing an integer ``eot`` attribute and a
            ``decode(list_of_ids)`` method that returns a ``str``.

    Returns:
        List of matching token ids, in ascending order.
    """
    numeral_symbol_chars = frozenset("0123456789@#%&*+=_$:-.,?!")
    numeral_symbol_tokens = []
    for token_id in range(tokenizer.eot):
        token = tokenizer.decode([token_id]).removeprefix(" ")
        # all() is vacuously true for "", so require at least one character;
        # otherwise a bare-space or empty token would be reported as a
        # numeral/symbol token.
        if token and all(c in numeral_symbol_chars for c in token):
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens
|
||||
|
||||
|
||||
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
|
||||
vad_options=None, model=None, task="transcribe"):
|
||||
@ -54,13 +62,27 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
|
||||
"max_initial_timestamp": 0.0,
|
||||
"word_timestamps": False,
|
||||
"prepend_punctuations": "\"'“¿([{-",
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
||||
"append_punctuations": "\"'.。,,!!??::”)]}、",
|
||||
"suppress_numerals": False,
|
||||
}
|
||||
|
||||
if asr_options is not None:
|
||||
default_asr_options.update(asr_options)
|
||||
|
||||
if default_asr_options["suppress_numerals"]:
|
||||
if tokenizer is None:
|
||||
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
|
||||
numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
|
||||
print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
|
||||
default_asr_options["suppress_tokens"] += numeral_symbol_tokens
|
||||
default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
|
||||
del default_asr_options["suppress_numerals"]
|
||||
|
||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||
|
||||
|
||||
|
||||
|
||||
default_vad_options = {
|
||||
"vad_onset": 0.500,
|
||||
"vad_offset": 0.363
|
||||
@ -106,13 +128,10 @@ class WhisperModel(faster_whisper.WhisperModel):
|
||||
result = self.model.generate(
|
||||
encoder_output,
|
||||
[prompt] * batch_size,
|
||||
# length_penalty=options.length_penalty,
|
||||
# max_length=self.max_length,
|
||||
# return_scores=True,
|
||||
# return_no_speech_prob=True,
|
||||
# suppress_blank=options.suppress_blank,
|
||||
# suppress_tokens=options.suppress_tokens,
|
||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
||||
length_penalty=options.length_penalty,
|
||||
max_length=self.max_length,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
)
|
||||
|
||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||
|
Reference in New Issue
Block a user