diff --git a/whisperx/asr.py b/whisperx/asr.py index d0e6962..09454c9 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -13,6 +13,15 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks from .types import TranscriptionResult, SingleSegment +def find_numeral_symbol_tokens(tokenizer): + numeral_symbol_tokens = [] + for i in range(tokenizer.eot): + token = tokenizer.decode([i]).removeprefix(" ") + has_numeral_symbol = any(c in "0123456789%$£" for c in token) + if has_numeral_symbol: + numeral_symbol_tokens.append(i) + return numeral_symbol_tokens + def load_model(whisper_arch, device, device_index=0, @@ -67,11 +76,22 @@ def load_model(whisper_arch, "max_initial_timestamp": 0.0, "word_timestamps": False, "prepend_punctuations": "\"'“¿([{-", - "append_punctuations": "\"'.。,,!!??::”)]}、" + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": False, } if asr_options is not None: default_asr_options.update(asr_options) + + if default_asr_options["suppress_numerals"]: + if tokenizer is None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en") + numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer) + print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}") + default_asr_options["suppress_tokens"] += numeral_symbol_tokens + default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"])) + del default_asr_options["suppress_numerals"] + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) default_vad_options = { @@ -119,13 +139,10 @@ class WhisperModel(faster_whisper.WhisperModel): result = self.model.generate( encoder_output, [prompt] * batch_size, - # length_penalty=options.length_penalty, - # max_length=self.max_length, - # return_scores=True, - # return_no_speech_prob=True, - # suppress_blank=options.suppress_blank, - # suppress_tokens=options.suppress_tokens, - # max_initial_timestamp_index=max_initial_timestamp_index, + length_penalty=options.length_penalty, + max_length=self.max_length, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, ) tokens_batch = [x.sequences_ids[0] for x in result] diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 691e3f9..3bb1a36 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -51,9 +51,11 @@ def cli(): parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") - parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") + parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") + parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") @@ -130,6 +132,8 @@ def cli(): "no_speech_threshold": args.pop("no_speech_threshold"), "condition_on_previous_text": False, "initial_prompt": args.pop("initial_prompt"), + "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], + "suppress_numerals": args.pop("suppress_numerals"), } writer = get_writer(output_format, output_dir)