From a323cff654cf39e190043a8642c35205d1af02e8 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Mon, 5 Jun 2023 15:27:42 +0100 Subject: [PATCH 1/3] --suppress_numerals option, ensures non-numerical words, for wav2vec2 alignment --- whisperx/asr.py | 35 +++++++++++++++++++++++++++-------- whisperx/transcribe.py | 6 +++++- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 66b58ad..fbb5331 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -12,6 +12,14 @@ from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks +def find_numeral_symbol_tokens(tokenizer): + numeral_symbol_tokens = [] + for i in range(tokenizer.eot): + token = tokenizer.decode([i]).removeprefix(" ") + if all(c in "0123456789@#%&*+=_$:-.,?!" for c in token): + numeral_symbol_tokens.append(i) + return numeral_symbol_tokens + def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, vad_options=None, model=None, task="transcribe"): @@ -54,13 +62,27 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l "max_initial_timestamp": 0.0, "word_timestamps": False, "prepend_punctuations": "\"'“¿([{-", - "append_punctuations": "\"'.。,,!!??::”)]}、" + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": False, } if asr_options is not None: default_asr_options.update(asr_options) + + if default_asr_options["suppress_numerals"]: + if tokenizer is None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en") + numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer) + print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}") + default_asr_options["suppress_tokens"] += numeral_symbol_tokens + default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"])) + del default_asr_options["suppress_numerals"] + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) + + + default_vad_options = { "vad_onset": 0.500, "vad_offset": 0.363 @@ -106,13 +128,10 @@ class WhisperModel(faster_whisper.WhisperModel): result = self.model.generate( encoder_output, [prompt] * batch_size, - # length_penalty=options.length_penalty, - # max_length=self.max_length, - # return_scores=True, - # return_no_speech_prob=True, - # suppress_blank=options.suppress_blank, - # suppress_tokens=options.suppress_tokens, - # max_initial_timestamp_index=max_initial_timestamp_index, + length_penalty=options.length_penalty, + max_length=self.max_length, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, ) tokens_batch = [x.sequences_ids[0] for x in result] diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3edc746..e9b6fc6 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -50,9 +50,11 @@ def cli(): parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") - parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") + parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") + parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") @@ -128,6 +130,8 @@ def cli(): "no_speech_threshold": args.pop("no_speech_threshold"), "condition_on_previous_text": False, "initial_prompt": args.pop("initial_prompt"), + "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], + "suppress_numerals": args.pop("suppress_numerals"), } writer = get_writer(output_format, output_dir) From 74a00eecd7e0f90766f9fcd709f3e287ad6d97b3 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Mon, 5 Jun 2023 15:33:04 +0100 Subject: [PATCH 2/3] suppress numerals fix --- whisperx/asr.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index f2a7203..501b21d 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks -<<<<<<< HEAD +from .types import TranscriptionResult, SingleSegment def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] @@ -21,12 +21,6 @@ def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens.append(i) return numeral_symbol_tokens - -def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, - vad_options=None, model=None, task="transcribe"): -======= -from .types import TranscriptionResult, SingleSegment - def load_model(whisper_arch, device, device_index=0, @@ -37,7 +31,6 @@ def load_model(whisper_arch, model=None, task="transcribe", download_root=None): ->>>>>>> ec6a110cdf2616919cfd0a616f9ae2fbdd44903f '''Load a Whisper model for inference. Args: whisper_arch: str - The name of the Whisper model to load. @@ -100,9 +93,6 @@ def load_model(whisper_arch, default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) - - - default_vad_options = { "vad_onset": 0.500, "vad_offset": 0.363 From d7f1d16f1927bfaf6b2b62de1d68daae470f2721 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Mon, 5 Jun 2023 15:44:17 +0100 Subject: [PATCH 3/3] suppress numerals change logic --- whisperx/asr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 501b21d..09454c9 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -17,7 +17,8 @@ def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] for i in range(tokenizer.eot): token = tokenizer.decode([i]).removeprefix(" ") - if all(c in "0123456789@#%&*+=_$:-.,?!" for c in token): + has_numeral_symbol = any(c in "0123456789%$£" for c in token) + if has_numeral_symbol: numeral_symbol_tokens.append(i) return numeral_symbol_tokens