Merge pull request #303 from m-bain/v3

Suppress numerals
This commit is contained in:
Max Bain
2023-06-05 15:46:26 +01:00
committed by GitHub
2 changed files with 30 additions and 9 deletions

View File

@ -13,6 +13,15 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from .vad import load_vad_model, merge_chunks
from .types import TranscriptionResult, SingleSegment
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = []
for i in range(tokenizer.eot):
token = tokenizer.decode([i]).removeprefix(" ")
has_numeral_symbol = any(c in "0123456789%" for c in token)
if has_numeral_symbol:
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens
def load_model(whisper_arch,
device,
device_index=0,
@ -67,11 +76,22 @@ def load_model(whisper_arch,
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,!?::”)]}、"
"append_punctuations": "\"'.。,!?::”)]}、",
"suppress_numerals": False,
}
if asr_options is not None:
default_asr_options.update(asr_options)
if default_asr_options["suppress_numerals"]:
if tokenizer is None:
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
default_asr_options["suppress_tokens"] += numeral_symbol_tokens
default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
del default_asr_options["suppress_numerals"]
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
default_vad_options = {
@ -119,13 +139,10 @@ class WhisperModel(faster_whisper.WhisperModel):
result = self.model.generate(
encoder_output,
[prompt] * batch_size,
# length_penalty=options.length_penalty,
# max_length=self.max_length,
# return_scores=True,
# return_no_speech_prob=True,
# suppress_blank=options.suppress_blank,
# suppress_tokens=options.suppress_tokens,
# max_initial_timestamp_index=max_initial_timestamp_index,
length_penalty=options.length_penalty,
max_length=self.max_length,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
)
tokens_batch = [x.sequences_ids[0] for x in result]

View File

@ -51,9 +51,11 @@ def cli():
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
@ -130,6 +132,8 @@ def cli():
"no_speech_threshold": args.pop("no_speech_threshold"),
"condition_on_previous_text": False,
"initial_prompt": args.pop("initial_prompt"),
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
"suppress_numerals": args.pop("suppress_numerals"),
}
writer = get_writer(output_format, output_dir)