diff --git a/README.md b/README.md index 28345f1..b52401b 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ This repository provides fast automatic speech recognition (70x realtime with la

New🚨

+- _WhisperX_ accepted at INTERSPEECH 2023 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization - v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend! - v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper. @@ -74,7 +75,7 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst ### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7: -`conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia` +`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia` See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200) @@ -184,6 +185,11 @@ print(diarize_segments) print(result["segments"]) # segments are now assigned speaker IDs ``` +## Demos 🚀 + +[![Replicate](https://replicate.com/daanelson/whisperx/badge)](https://replicate.com/daanelson/whisperx) + +If you don't have access to your own GPUs, use the link above to try out WhisperX.

Technical Details 👷‍♂️

@@ -276,7 +282,7 @@ If you use this in your research, please cite the paper: @article{bain2022whisperx, title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio}, author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew}, - journal={arXiv preprint, arXiv:2303.00747}, + journal={INTERSPEECH 2023}, year={2023} } ``` diff --git a/requirements.txt b/requirements.txt index ec90a07..ddfa28a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -torch==2.0.0 -torchaudio==2.0.1 +torch>=2 +torchaudio>=2 faster-whisper transformers -ffmpeg-python==0.2.0 +ffmpeg-python>=0.2 pandas -setuptools==65.6.3 -nltk \ No newline at end of file +setuptools>=65 +nltk diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 13dfddc..2717bc4 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -15,6 +15,9 @@ from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment import nltk +from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters + +PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -33,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = { "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", + "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", @@ -42,6 +46,9 @@ DEFAULT_ALIGN_MODELS_HF = { "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "vi": 'nguyenvulebinh/wav2vec2-base-vi', + "ko": "kresnik/wav2vec2-large-xlsr-korean", + "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu", } @@ -141,7 +148,11 @@ def align( if any([c in model_dictionary.keys() for c in wrd]): clean_wdx.append(wdx) - sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text)) + + punkt_param = PunktParameters() + punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) + sentence_splitter = PunktSentenceTokenizer(punkt_param) + sentence_spans = list(sentence_splitter.span_tokenize(text)) segment["clean_char"] = clean_char segment["clean_cdx"] = clean_cdx @@ -300,6 +311,8 @@ def align( aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method) # concatenate sentences with same timestamps agg_dict = {"text": " ".join, "words": "sum"} + if model_lang in LANGUAGES_WITHOUT_SPACES: + agg_dict["text"] = "".join if return_char_alignments: agg_dict["chars"] = "sum" aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) diff --git a/whisperx/asr.py b/whisperx/asr.py index 88d5bf6..09454c9 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -13,8 +13,25 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks from .types import TranscriptionResult, SingleSegment -def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, - vad_options=None, model=None, task="transcribe"): +def find_numeral_symbol_tokens(tokenizer): + numeral_symbol_tokens = [] + for i in range(tokenizer.eot): + token = tokenizer.decode([i]).removeprefix(" ") + has_numeral_symbol = any(c in "0123456789%$£" for c in token) + if has_numeral_symbol: + numeral_symbol_tokens.append(i) + return numeral_symbol_tokens + +def load_model(whisper_arch, + device, + device_index=0, + compute_type="float16", + asr_options=None, + language=None, + vad_options=None, + model=None, + task="transcribe", + download_root=None): '''Load a Whisper model for inference. Args: whisper_arch: str - The name of the Whisper model to load. @@ -22,14 +39,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l compute_type: str - The compute type to use for the model. options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) + download_root: Optional[str] - The root directory to download the model to. Returns: A Whisper pipeline. - ''' + ''' if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, device=device, compute_type=compute_type) + model = WhisperModel(whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root) if language is not None: tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: @@ -54,11 +76,22 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l "max_initial_timestamp": 0.0, "word_timestamps": False, "prepend_punctuations": "\"'“¿([{-", - "append_punctuations": "\"'.。,,!!??::”)]}、" + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": False, } if asr_options is not None: default_asr_options.update(asr_options) + + if default_asr_options["suppress_numerals"]: + if tokenizer is None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language="en") + numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer) + print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}") + default_asr_options["suppress_tokens"] += numeral_symbol_tokens + default_asr_options["suppress_tokens"] = list(set(default_asr_options["suppress_tokens"])) + del default_asr_options["suppress_numerals"] + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) default_vad_options = { @@ -106,15 +139,12 @@ class WhisperModel(faster_whisper.WhisperModel): result = self.model.generate( encoder_output, [prompt] * batch_size, - # length_penalty=options.length_penalty, - # max_length=self.max_length, - # return_scores=True, - # return_no_speech_prob=True, - # suppress_blank=options.suppress_blank, - # suppress_tokens=options.suppress_tokens, - # max_initial_timestamp_index=max_initial_timestamp_index, + length_penalty=options.length_penalty, + max_length=self.max_length, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, ) - + tokens_batch = [x.sequences_ids[0] for x in result] def decode_batch(tokens: List[List[int]]) -> str: @@ -127,7 +157,7 @@ class WhisperModel(faster_whisper.WhisperModel): text = decode_batch(tokens_batch) return text - + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. @@ -136,9 +166,9 @@ class WhisperModel(faster_whisper.WhisperModel): if len(features.shape) == 2: features = np.expand_dims(features, 0) features = faster_whisper.transcribe.get_ctranslate2_storage(features) - + return self.model.encode(features, to_cpu=to_cpu) - + class FasterWhisperPipeline(Pipeline): """ Huggingface Pipeline wrapper for FasterWhisperModel. @@ -176,7 +206,7 @@ class FasterWhisperPipeline(Pipeline): self.device = torch.device(f"cuda:{device}") else: self.device = device - + super(Pipeline, self).__init__() self.vad_model = vad @@ -194,7 +224,7 @@ class FasterWhisperPipeline(Pipeline): def _forward(self, model_inputs): outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) return {'text': outputs} - + def postprocess(self, model_outputs): return model_outputs @@ -214,11 +244,11 @@ class FasterWhisperPipeline(Pipeline): return final_iterator def transcribe( - self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0 + self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) - + def data(audio, segments): for seg in segments: f1 = int(seg['start'] * SAMPLE_RATE) @@ -228,14 +258,19 @@ class FasterWhisperPipeline(Pipeline): vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks(vad_segments, 30) - - del_tokenizer = False if self.tokenizer is None: - language = self.detect_language(audio) - self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language) - del_tokenizer = True + language = language or self.detect_language(audio) + task = task or "transcribe" + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) else: - language = self.tokenizer.language_code + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task + if task != self.tokenizer.task or language != self.tokenizer.language_code: + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size @@ -250,9 +285,6 @@ class FasterWhisperPipeline(Pipeline): "end": round(vad_segments[idx]['end'], 3) } ) - - if del_tokenizer: - self.tokenizer = None return {"segments": segments, "language": language} diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3edc746..1cc144e 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -21,11 +21,12 @@ def cli(): parser.add_argument("--model", default="small", help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") - parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") + parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference") + parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference") parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") @@ -50,9 +51,11 @@ def cli(): parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") - parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") + parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") + parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") @@ -78,6 +81,7 @@ def cli(): output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") device: str = args.pop("device") + device_index: int = args.pop("device_index") compute_type: str = args.pop("compute_type") # model_flush: bool = args.pop("model_flush") @@ -128,6 +132,8 @@ def cli(): "no_speech_threshold": args.pop("no_speech_threshold"), "condition_on_previous_text": False, "initial_prompt": args.pop("initial_prompt"), + "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], + "suppress_numerals": args.pop("suppress_numerals"), } writer = get_writer(output_format, output_dir) @@ -144,7 +150,7 @@ def cli(): results = [] tmp_results = [] # model = load_model(model_name, device=device, download_root=model_dir) - model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) + model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) for audio_path in args.pop("audio"): audio = load_audio(audio_path) @@ -204,4 +210,4 @@ def cli(): writer(result, audio_path, writer_args) if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/whisperx/types.py b/whisperx/types.py index 75d4485..68f2d78 100644 --- a/whisperx/types.py +++ b/whisperx/types.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional +from typing import TypedDict, Optional, List class SingleWordSegment(TypedDict): @@ -38,15 +38,15 @@ class SingleAlignedSegment(TypedDict): start: float end: float text: str - words: list[SingleWordSegment] - chars: Optional[list[SingleCharSegment]] + words: List[SingleWordSegment] + chars: Optional[List[SingleCharSegment]] class TranscriptionResult(TypedDict): """ A list of segments and word segments of a speech. """ - segments: list[SingleSegment] + segments: List[SingleSegment] language: str @@ -54,5 +54,5 @@ class AlignedTranscriptionResult(TypedDict): """ A list of segments and word segments of a speech. """ - segments: list[SingleAlignedSegment] - word_segments: list[SingleWordSegment] + segments: List[SingleAlignedSegment] + word_segments: List[SingleWordSegment] diff --git a/whisperx/utils.py b/whisperx/utils.py index d042bb7..36c7543 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -365,6 +365,28 @@ class WriteTSV(ResultWriter): print(round(1000 * segment["end"]), file=file, end="\t") print(segment["text"].strip().replace("\t", " "), file=file, flush=True) +class WriteAudacity(ResultWriter): + """ + Write a transcript to a text file that audacity can import as labels. + The extension used is "aud" to distinguish it from the txt file produced by WriteTXT. + Yet this is not an audacity project but only a label file! + + Please note : Audacity uses seconds in timestamps not ms! + Also there is no header expected. + + If speaker is provided it is prepended to the text between double square brackets [[]]. + """ + + extension: str = "aud" + + def write_result(self, result: dict, file: TextIO, options: dict): + ARROW = " " + for segment in result["segments"]: + print(segment["start"], file=file, end=ARROW) + print(segment["end"], file=file, end=ARROW) + print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True) + + class WriteJSON(ResultWriter): extension: str = "json" @@ -383,6 +405,9 @@ def get_writer( "tsv": WriteTSV, "json": WriteJSON, } + optional_writers = { + "aud": WriteAudacity, + } if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] @@ -393,10 +418,12 @@ def get_writer( return write_all + if output_format in optional_writers: + return optional_writers[output_format](output_dir) return writers[output_format](output_dir) def interpolate_nans(x, method='nearest'): if x.notnull().sum() > 1: return x.interpolate(method=method).ffill().bfill() else: - return x.ffill().bfill() \ No newline at end of file + return x.ffill().bfill() diff --git a/whisperx/vad.py b/whisperx/vad.py index 42b0bfb..15a9e5e 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -147,8 +147,6 @@ class Binarize: if is_active: curr_duration = t - start if curr_duration > self.max_duration: - # if curr_duration > 15: - # import pdb; pdb.set_trace() search_after = len(curr_scores) // 2 # divide segment min_score_div_idx = search_after + np.argmin(curr_scores[search_after:]) @@ -166,14 +164,14 @@ class Binarize: is_active = False curr_scores = [] curr_timestamps = [] + curr_scores.append(y) + curr_timestamps.append(t) # currently inactive else: # switching from inactive to active if y > self.onset: start = t is_active = True - curr_scores.append(y) - curr_timestamps.append(t) # if active at the end, add final region if is_active: