mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
@ -13,6 +13,15 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
|
|||||||
from .vad import load_vad_model, merge_chunks
|
from .vad import load_vad_model, merge_chunks
|
||||||
from .types import TranscriptionResult, SingleSegment
|
from .types import TranscriptionResult, SingleSegment
|
||||||
|
|
||||||
|
def find_numeral_symbol_tokens(tokenizer):
    """Return the ids of vocabulary tokens whose text contains a numeral or symbol.

    Every id in ``range(tokenizer.eot)`` is decoded on its own with
    ``tokenizer.decode([token_id])``. An id is kept when the decoded text
    contains at least one of the characters ``0123456789%$£``.

    Args:
        tokenizer: Object exposing an integer ``eot`` attribute and a
            ``decode(list_of_ids)`` method that returns a string.

    Returns:
        A list of the matching token ids, in ascending order.
    """
    numeral_symbols = frozenset("0123456789%$£")
    # No leading-space stripping is needed before the check: a space is not
    # one of the searched characters, so it can never affect the outcome.
    # Skipping it also avoids str.removeprefix, which needs Python 3.9+.
    return [
        token_id
        for token_id in range(tokenizer.eot)
        if not numeral_symbols.isdisjoint(tokenizer.decode([token_id]))
    ]
|
||||||
|
|
||||||
def load_model(whisper_arch,
|
def load_model(whisper_arch,
|
||||||
device,
|
device,
|
||||||
device_index=0,
|
device_index=0,
|
||||||
@ -67,11 +76,22 @@ def load_model(whisper_arch,
|
|||||||
"max_initial_timestamp": 0.0,
|
"max_initial_timestamp": 0.0,
|
||||||
"word_timestamps": False,
|
"word_timestamps": False,
|
||||||
"prepend_punctuations": "\"'“¿([{-",
|
"prepend_punctuations": "\"'“¿([{-",
|
||||||
"append_punctuations": "\"'.。,,!!??::”)]}、"
|
"append_punctuations": "\"'.。,,!!??::”)]}、",
|
||||||
|
"suppress_numerals": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if asr_options is not None:
|
if asr_options is not None:
|
||||||
default_asr_options.update(asr_options)
|
default_asr_options.update(asr_options)
|
||||||
|
|
||||||
|
if default_asr_options["suppress_numerals"]:
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en")
|
||||||
|
numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer)
|
||||||
|
print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}")
|
||||||
|
default_asr_options["suppress_tokens"] += numeral_symbol_tokens
|
||||||
|
default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"]))
|
||||||
|
del default_asr_options["suppress_numerals"]
|
||||||
|
|
||||||
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
|
||||||
|
|
||||||
default_vad_options = {
|
default_vad_options = {
|
||||||
@ -119,13 +139,10 @@ class WhisperModel(faster_whisper.WhisperModel):
|
|||||||
result = self.model.generate(
|
result = self.model.generate(
|
||||||
encoder_output,
|
encoder_output,
|
||||||
[prompt] * batch_size,
|
[prompt] * batch_size,
|
||||||
# length_penalty=options.length_penalty,
|
length_penalty=options.length_penalty,
|
||||||
# max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
# return_scores=True,
|
suppress_blank=options.suppress_blank,
|
||||||
# return_no_speech_prob=True,
|
suppress_tokens=options.suppress_tokens,
|
||||||
# suppress_blank=options.suppress_blank,
|
|
||||||
# suppress_tokens=options.suppress_tokens,
|
|
||||||
# max_initial_timestamp_index=max_initial_timestamp_index,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens_batch = [x.sequences_ids[0] for x in result]
|
tokens_batch = [x.sequences_ids[0] for x in result]
|
||||||
|
@ -51,9 +51,11 @@ def cli():
|
|||||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||||
|
|
||||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
|
parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
|
||||||
|
|
||||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
@ -130,6 +132,8 @@ def cli():
|
|||||||
"no_speech_threshold": args.pop("no_speech_threshold"),
|
"no_speech_threshold": args.pop("no_speech_threshold"),
|
||||||
"condition_on_previous_text": False,
|
"condition_on_previous_text": False,
|
||||||
"initial_prompt": args.pop("initial_prompt"),
|
"initial_prompt": args.pop("initial_prompt"),
|
||||||
|
"suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
|
||||||
|
"suppress_numerals": args.pop("suppress_numerals"),
|
||||||
}
|
}
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
|
Reference in New Issue
Block a user