diff --git a/README.md b/README.md
index 8043e02..a660d2d 100644
--- a/README.md
+++ b/README.md
@@ -52,13 +52,6 @@ This repository provides fast automatic speech recognition (70x realtime with la
 **Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
 
-- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*!
-- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨‍🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo).
-- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
-- Character level timestamps (see `*.char.ass` file output)
-- Diarization (still in beta, add `--diarize`)
-

 New🚨

 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
@@ -87,11 +80,11 @@ See other methods [here.](https://pytorch.org/get-started/previous-versions/#v20
 
 ### 3. Install this repo
 
-`pip install git+https://github.com/m-bain/whisperx.git@v3`
+`pip install git+https://github.com/m-bain/whisperx.git`
 
 If already installed, update package to most recent commit
 
-`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade`
+`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
 
 If wishing to modify this package, clone and install in editable mode:
 ```
@@ -183,10 +176,10 @@ print(result["segments"]) # after alignment
 diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
 
 # add min/max number of speakers if known
-diarize_segments = diarize_model(input_audio_path)
-# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+diarize_segments = diarize_model(audio_file)
+# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers)
 
-result = assign_word_speakers(diarize_segments, result)
+result = whisperx.assign_word_speakers(diarize_segments, result)
 print(diarize_segments)
 print(result["segments"]) # segments are now assigned speaker IDs
 ```
diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index aade4b4..8d088be 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -3,7 +3,7 @@ Forced Alignment with Whisper
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterator, Union
+from typing import Iterator, Union, List
 
 import numpy as np
 import pandas as pd
@@ -13,7 +13,11 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
+from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
 import nltk
+from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
+
+PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
 
 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
 
@@ -32,6 +36,7 @@ DEFAULT_ALIGN_MODELS_HF = {
     "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
     "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
     "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
     "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
     "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
     "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
@@ -39,7 +44,10 @@ DEFAULT_ALIGN_MODELS_HF = {
     "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
     "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
     "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
+    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
     "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
+    "vi": 'nguyenvulebinh/wav2vec2-base-vi',
+    "ko": "kresnik/wav2vec2-large-xlsr-korean",
 }
 
 
@@ -80,14 +88,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None):
 
 
 def align(
-    transcript: Iterator[dict],
+    transcript: Iterator[SingleSegment],
     model: torch.nn.Module,
     align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
     device: str,
     interpolate_method: str = "nearest",
     return_char_alignments: bool = False,
-):
+) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
""" @@ -139,14 +147,18 @@ def align( if any([c in model_dictionary.keys() for c in wrd]): clean_wdx.append(wdx) - sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text)) + + punkt_param = PunktParameters() + punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) + sentence_splitter = PunktSentenceTokenizer(punkt_param) + sentence_spans = list(sentence_splitter.span_tokenize(text)) segment["clean_char"] = clean_char segment["clean_cdx"] = clean_cdx segment["clean_wdx"] = clean_wdx segment["sentence_spans"] = sentence_spans - aligned_segments = [] + aligned_segments: List[SingleAlignedSegment] = [] # 2. Get prediction matrix from alignment model & align for sdx, segment in enumerate(transcript): @@ -154,7 +166,7 @@ def align( t2 = segment["end"] text = segment["text"] - aligned_seg = { + aligned_seg: SingleAlignedSegment = { "start": t1, "end": t2, "text": text, @@ -307,7 +319,7 @@ def align( aligned_segments += aligned_subsegments # create word_segments list - word_segments = [] + word_segments: List[SingleWordSegment] = [] for segment in aligned_segments: word_segments += segment["words"] diff --git a/whisperx/asr.py b/whisperx/asr.py index 66b58ad..d0e6962 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -11,10 +11,18 @@ from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks +from .types import TranscriptionResult, SingleSegment - -def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, - vad_options=None, model=None, task="transcribe"): +def load_model(whisper_arch, + device, + device_index=0, + compute_type="float16", + asr_options=None, + language=None, + vad_options=None, + model=None, + task="transcribe", + download_root=None): '''Load a Whisper model for inference. Args: whisper_arch: str - The name of the Whisper model to load. @@ -22,14 +30,19 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l compute_type: str - The compute type to use for the model. options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) + download_root: Optional[str] - The root directory to download the model to. Returns: A Whisper pipeline. - ''' + ''' if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, device=device, compute_type=compute_type) + model = WhisperModel(whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root) if language is not None: tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: @@ -114,7 +127,7 @@ class WhisperModel(faster_whisper.WhisperModel): # suppress_tokens=options.suppress_tokens, # max_initial_timestamp_index=max_initial_timestamp_index, ) - + tokens_batch = [x.sequences_ids[0] for x in result] def decode_batch(tokens: List[List[int]]) -> str: @@ -127,7 +140,7 @@ class WhisperModel(faster_whisper.WhisperModel): text = decode_batch(tokens_batch) return text - + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. 
@@ -136,9 +149,9 @@ class WhisperModel(faster_whisper.WhisperModel):
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
-
+
         return self.model.encode(features, to_cpu=to_cpu)
-
+
 class FasterWhisperPipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -176,7 +189,7 @@ class FasterWhisperPipeline(Pipeline):
             self.device = torch.device(f"cuda:{device}")
         else:
             self.device = device
-
+
         super(Pipeline, self).__init__()
         self.vad_model = vad
 
@@ -194,7 +207,7 @@ class FasterWhisperPipeline(Pipeline):
     def _forward(self, model_inputs):
         outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
         return {'text': outputs}
-
+
     def postprocess(self, model_outputs):
         return model_outputs
 
@@ -214,11 +227,11 @@ class FasterWhisperPipeline(Pipeline):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0
-    ):
+        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None
+    ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
-
+
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
@@ -228,16 +241,21 @@ class FasterWhisperPipeline(Pipeline):
 
         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(vad_segments, 30)
-
-        del_tokenizer = False
         if self.tokenizer is None:
-            language = self.detect_language(audio)
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language)
-            del_tokenizer = True
+            language = language or self.detect_language(audio)
+            task = task or "transcribe"
+            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+                                                                self.model.model.is_multilingual, task=task,
+                                                                language=language)
         else:
-            language = self.tokenizer.language_code
+            language = language or self.tokenizer.language_code
+            task = task or self.tokenizer.task
+            if task != self.tokenizer.task or language != self.tokenizer.language_code:
+                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
+                                                                    self.model.model.is_multilingual, task=task,
+                                                                    language=language)
 
-        segments = []
+        segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             text = out['text']
@@ -245,14 +263,11 @@ class FasterWhisperPipeline(Pipeline):
                 text = text[0]
             segments.append(
                 {
-                    "text": out['text'],
+                    "text": text,
                    "start": round(vad_segments[idx]['start'], 3),
                    "end": round(vad_segments[idx]['end'], 3)
                 }
            )
-
-        if del_tokenizer:
-            self.tokenizer = None
 
         return {"segments": segments, "language": language}
 
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 3edc746..691e3f9 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -21,6 +21,7 @@ def cli():
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") @@ -78,6 +79,7 @@ def cli(): output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") device: str = args.pop("device") + device_index: int = args.pop("device_index") compute_type: str = args.pop("compute_type") # model_flush: bool = args.pop("model_flush") @@ -144,7 +146,7 @@ def cli(): results = [] tmp_results = [] # model = load_model(model_name, device=device, download_root=model_dir) - model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) + model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) for audio_path in args.pop("audio"): audio = load_audio(audio_path) diff --git a/whisperx/types.py b/whisperx/types.py new file mode 100644 index 0000000..75d4485 --- /dev/null +++ b/whisperx/types.py @@ -0,0 +1,58 @@ +from typing import TypedDict, Optional + + +class SingleWordSegment(TypedDict): + """ + A single word of a speech. + """ + word: str + start: float + end: float + score: float + +class SingleCharSegment(TypedDict): + """ + A single char of a speech. + """ + char: str + start: float + end: float + score: float + + +class SingleSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech. + """ + + start: float + end: float + text: str + + +class SingleAlignedSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech with word alignment. + """ + + start: float + end: float + text: str + words: list[SingleWordSegment] + chars: Optional[list[SingleCharSegment]] + + +class TranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleSegment] + language: str + + +class AlignedTranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleAlignedSegment] + word_segments: list[SingleWordSegment] diff --git a/whisperx/vad.py b/whisperx/vad.py index a7a2451..15a9e5e 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -157,7 +157,7 @@ class Binarize: curr_scores = curr_scores[min_score_div_idx+1:] curr_timestamps = curr_timestamps[min_score_div_idx+1:] # switching from active to inactive - elif y <= self.offset: + elif y < self.offset: region = Segment(start - self.pad_onset, t + self.pad_offset) active[region, k] = label start = t @@ -169,7 +169,7 @@ class Binarize: # currently inactive else: # switching from inactive to active - if y >= self.onset: + if y > self.onset: start = t is_active = True