--suppress_numerals option, ensures non-numerical words, for wav2vec2 alignment

This commit is contained in:
Max Bain
2023-06-05 15:27:42 +01:00
parent 42b4909bc0
commit a323cff654
2 changed files with 32 additions and 9 deletions

View File

@ -12,6 +12,14 @@ from transformers.pipelines.pt_utils import PipelineIterator
from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
for i in range(tokenizer.eot):
token = tokenizer.decode([i]).removeprefix(" ")
if all(c in "0123456789@#%&*+=_$:-.,?!" for c in token):
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens
def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None,
vad_options=None, model=None, task="transcribe"):
@ -54,13 +62,27 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、"
"append_punctuations": "\"'.。,!?::”)]}、",
"suppress_numerals": False,
}
if asr_options is not None:
default_asr_options.update(asr_options)
if default_asr_options["suppress_numerals"]:
if tokenizer is None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
default_asr_options["suppress_tokens"] += numeral_symbol_tokens
default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
del default_asr_options["suppress_numerals"]
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_vad_options = {
"vad_onset": 0.500,
"vad_offset": 0.363
@ -106,13 +128,10 @@ class WhisperModel(faster_whisper.WhisperModel):
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
# length_penalty=options.length_penalty,
# max_length=self.max_length,
# return_scores=True,
# return_no_speech_prob=True,
# suppress_blank=options.suppress_blank,
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
tokens_batch = [x.sequences_ids[0] for x in result]