From c8404d9805d50d89f96a6c07580e44157c21553a Mon Sep 17 00:00:00 2001 From: Marcus Brandt Date: Sat, 4 Mar 2023 13:20:40 +0100 Subject: [PATCH 01/17] added a danish alignment model --- whisperx/alignment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 4188da5..2088ecd 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -37,6 +37,7 @@ DEFAULT_ALIGN_MODELS_HF = { "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", + "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech" } From 0b839f3f019d6b71773c07fe2f1d671091f4420f Mon Sep 17 00:00:00 2001 From: Max Bain <36994049+m-bain@users.noreply.github.com> Date: Sun, 7 May 2023 20:36:08 +0100 Subject: [PATCH 02/17] Update README.md --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index ae3d5bd..a6d400c 100644 --- a/README.md +++ b/README.md @@ -52,13 +52,6 @@ This repository provides fast automatic speaker recognition (70x realtime with l **Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker. -- v3 pre-release [this branch](https://github.com/m-bain/whisperX/tree/v3) *70x speed-up open-sourced. Using batched whisper with faster-whisper backend*! -- v2 released, code cleanup, imports whisper library. VAD filtering is now turned on by default, as in the paper. -- Paper drop🎓👨‍🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed (not provided in this repo). -- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter True` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2) -- Character level timestamps (see `*.char.ass` file output) -- Diarization (still in beta, add `--diarize`) -

New🚨

- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization From 2efa136114a3ad677a76c8e7b9e75008b3bb4f60 Mon Sep 17 00:00:00 2001 From: Max Bain <36994049+m-bain@users.noreply.github.com> Date: Mon, 8 May 2023 17:20:38 +0100 Subject: [PATCH 03/17] update python usage example --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a6d400c..9d61d9b 100644 --- a/README.md +++ b/README.md @@ -176,10 +176,10 @@ print(result["segments"]) # after alignment diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device) # add min/max number of speakers if known -diarize_segments = diarize_model(input_audio_path) -# diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) +diarize_segments = diarize_model(audio_file) +# diarize_model(audio_file, min_speakers=min_speakers, max_speakers=max_speakers) -result = assign_word_speakers(diarize_segments, result) +result = whisperx.assign_word_speakers(diarize_segments, result) print(diarize_segments) print(result["segments"]) # segments are now assigned speaker IDs ``` From b50aafb17b286a162e5bf08bd71eb820d0df396d Mon Sep 17 00:00:00 2001 From: Simon Date: Mon, 8 May 2023 20:03:42 +0200 Subject: [PATCH 04/17] Fix tuple unpacking --- whisperx/asr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index f2c54f6..21357ec 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -245,7 +245,7 @@ class FasterWhisperPipeline(Pipeline): text = text[0] segments.append( { - "text": out['text'], + "text": text, "start": round(vad_segments[idx]['start'], 3), "end": round(vad_segments[idx]['end'], 3) } From eabf35dff0d80ff3cabc946b65d2faf42797e671 Mon Sep 17 00:00:00 2001 From: Simon Date: Mon, 8 May 2023 20:45:34 +0200 Subject: [PATCH 05/17] Custom result types --- whisperx/alignment.py | 13 +++++----- whisperx/asr.py | 6 ++--- whisperx/types.py | 58 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 whisperx/types.py diff --git a/whisperx/alignment.py b/whisperx/alignment.py index b873475..eb8d4b6 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -3,7 +3,7 @@ Forced Alignment with Whisper C. Max Bain """ from dataclasses import dataclass -from typing import Iterator, Union +from typing import Iterator, Union, List import numpy as np import pandas as pd @@ -13,6 +13,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans +from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment import nltk LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -80,14 +81,14 @@ def load_align_model(language_code, device, model_name=None, model_dir=None): def align( - transcript: Iterator[dict], + transcript: Iterator[SingleSegment], model: torch.nn.Module, align_model_metadata: dict, audio: Union[str, np.ndarray, torch.Tensor], device: str, interpolate_method: str = "nearest", return_char_alignments: bool = False, -): +) -> AlignedTranscriptionResult: """ Align phoneme recognition predictions to known transcription. """ @@ -146,7 +147,7 @@ def align( segment["clean_wdx"] = clean_wdx segment["sentence_spans"] = sentence_spans - aligned_segments = [] + aligned_segments: List[SingleAlignedSegment] = [] # 2. Get prediction matrix from alignment model & align for sdx, segment in enumerate(transcript): @@ -154,7 +155,7 @@ def align( t2 = segment["end"] text = segment["text"] - aligned_seg = { + aligned_seg: SingleAlignedSegment = { "start": t1, "end": t2, "text": text, @@ -301,7 +302,7 @@ def align( aligned_segments += aligned_subsegments # create word_segments list - word_segments = [] + word_segments: List[SingleWordSegment] = [] for segment in aligned_segments: word_segments += segment["words"] diff --git a/whisperx/asr.py b/whisperx/asr.py index 21357ec..e131ae1 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -11,7 +11,7 @@ from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks - +from .types import TranscriptionResult, SingleSegment def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, vad_options=None, model=None): @@ -215,7 +215,7 @@ class FasterWhisperPipeline(Pipeline): def transcribe( self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0 - ): + ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) @@ -237,7 +237,7 @@ class FasterWhisperPipeline(Pipeline): else: language = self.tokenizer.language_code - segments = [] + segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): text = out['text'] diff --git a/whisperx/types.py b/whisperx/types.py new file mode 100644 index 0000000..75d4485 --- /dev/null +++ b/whisperx/types.py @@ -0,0 +1,58 @@ +from typing import TypedDict, Optional + + +class SingleWordSegment(TypedDict): + """ + A single word of a speech. + """ + word: str + start: float + end: float + score: float + +class SingleCharSegment(TypedDict): + """ + A single char of a speech. + """ + char: str + start: float + end: float + score: float + + +class SingleSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech. + """ + + start: float + end: float + text: str + + +class SingleAlignedSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech with word alignment. + """ + + start: float + end: float + text: str + words: list[SingleWordSegment] + chars: Optional[list[SingleCharSegment]] + + +class TranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleSegment] + language: str + + +class AlignedTranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: list[SingleAlignedSegment] + word_segments: list[SingleWordSegment] From 5421f1d7ca5aef03fcb9a0a7c5bc415feffec6f8 Mon Sep 17 00:00:00 2001 From: Max Bain <36994049+m-bain@users.noreply.github.com> Date: Tue, 9 May 2023 13:42:50 +0100 Subject: [PATCH 06/17] remove v3 tag on pip install --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9d61d9b..943c482 100644 --- a/README.md +++ b/README.md @@ -80,11 +80,11 @@ See other methods [here.](https://pytorch.org/get-started/locally/) ### 3. Install this repo -`pip install git+https://github.com/m-bain/whisperx.git@v3` +`pip install git+https://github.com/m-bain/whisperx.git` If already installed, update package to most recent commit -`pip install git+https://github.com/m-bain/whisperx.git@v3 --upgrade` +`pip install git+https://github.com/m-bain/whisperx.git --upgrade` If wishing to modify this package, clone and install in editable mode: ``` From 53396adb210d1db07f4400bb29e8aa8c0ae88af5 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 20 May 2023 13:02:46 +0200 Subject: [PATCH 07/17] add device_index --- whisperx/asr.py | 4 ++-- whisperx/transcribe.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 88d5bf6..470e701 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -13,7 +13,7 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks from .types import TranscriptionResult, SingleSegment -def load_model(whisper_arch, device, compute_type="float16", asr_options=None, language=None, +def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr_options=None, language=None, vad_options=None, model=None, task="transcribe"): '''Load a Whisper model for inference. Args: @@ -29,7 +29,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, device=device, compute_type=compute_type) + model = WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type) if language is not None: tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3edc746..4432abe 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -21,6 +21,7 @@ def cli(): parser.add_argument("--model", default="small", help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device_index", default=None, type=int, help="device index to use for FasterWhisper inference") parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") @@ -78,6 +79,7 @@ def cli(): output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") device: str = args.pop("device") + device_index: int = args.pop("device_index") compute_type: str = args.pop("compute_type") # model_flush: bool = args.pop("model_flush") @@ -144,7 +146,7 @@ def cli(): results = [] tmp_results = [] # model = load_model(model_name, device=device, download_root=model_dir) - model = load_model(model_name, device=device, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) + model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task) for audio_path in args.pop("audio"): audio = load_audio(audio_path) From 74b98ebfaab771f4078c7ffe973117257667dda2 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 20 May 2023 13:11:30 +0200 Subject: [PATCH 08/17] ensure device_index not None --- whisperx/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 4432abe..691e3f9 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -21,7 +21,7 @@ def cli(): parser.add_argument("--model", default="small", help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") - parser.add_argument("--device_index", default=None, type=int, help="device index to use for FasterWhisper inference") + parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference") parser.add_argument("--batch_size", default=8, type=int, help="device to use for PyTorch inference") parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") From 1fc965bc1a78b5ae9a9bb155ab02a426dce6ea58 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 20 May 2023 15:30:25 +0200 Subject: [PATCH 09/17] add task, language keyword to transcribe --- whisperx/asr.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 88d5bf6..2fab8bc 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -214,7 +214,7 @@ class FasterWhisperPipeline(Pipeline): return final_iterator def transcribe( - self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0 + self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) @@ -229,13 +229,12 @@ class FasterWhisperPipeline(Pipeline): vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks(vad_segments, 30) - del_tokenizer = False - if self.tokenizer is None: - language = self.detect_language(audio) - self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task="transcribe", language=language) - del_tokenizer = True - else: - language = self.tokenizer.language_code + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task + if task != self.tokenizer.task or language != self.tokenizer.language_code: + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size @@ -250,9 +249,6 @@ class FasterWhisperPipeline(Pipeline): "end": round(vad_segments[idx]['end'], 3) } ) - - if del_tokenizer: - self.tokenizer = None return {"segments": segments, "language": language} From 715435db4284c1e73caf284662c131542a938bb9 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 20 May 2023 15:42:21 +0200 Subject: [PATCH 10/17] add tokenizer is None case --- whisperx/asr.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 2fab8bc..b4035e5 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -228,9 +228,12 @@ class FasterWhisperPipeline(Pipeline): vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks(vad_segments, 30) - - language = language or self.tokenizer.language_code - task = task or self.tokenizer.task + if self.tokenizer is None: + language = language or self.detect_language(audio) + task = task or "transcribe" + else: + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task if task != self.tokenizer.task or language != self.tokenizer.language_code: self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task=task, From a1c705b3a75a0582733109136b6013e652e14464 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 20 May 2023 15:52:45 +0200 Subject: [PATCH 11/17] fix tokenizer is None --- whisperx/asr.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index b4035e5..9b1e450 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -231,13 +231,16 @@ class FasterWhisperPipeline(Pipeline): if self.tokenizer is None: language = language or self.detect_language(audio) task = task or "transcribe" - else: - language = language or self.tokenizer.language_code - task = task or self.tokenizer.task - if task != self.tokenizer.task or language != self.tokenizer.language_code: self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, self.model.model.is_multilingual, task=task, language=language) + else: + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task + if task != self.tokenizer.task or language != self.tokenizer.language_code: + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size From 9c042c2d28eab5e048b047790e0d6c5c2f547c86 Mon Sep 17 00:00:00 2001 From: iambestfeeddddd Date: Fri, 26 May 2023 16:46:55 +0700 Subject: [PATCH 12/17] Add war2vec model for Vietnamese --- whisperx/alignment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 13dfddc..efd75de 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -42,6 +42,7 @@ DEFAULT_ALIGN_MODELS_HF = { "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "vi": 'nguyenvulebinh/wav2vec2-base-vi' } From 1d9d630fb9c4db809c29449066310dc5d1c282e0 Mon Sep 17 00:00:00 2001 From: Youssef Boulaoaune <43298428+Boulaouaney@users.noreply.github.com> Date: Fri, 26 May 2023 20:33:16 +0900 Subject: [PATCH 13/17] added Korean wav2vec2 model --- whisperx/alignment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 13dfddc..8f84ee5 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -42,6 +42,7 @@ DEFAULT_ALIGN_MODELS_HF = { "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "ko": "kresnik/wav2vec2-large-xlsr-korean", } From bb15d6b68edd9493d0c221d87ca7c72570b890e7 Mon Sep 17 00:00:00 2001 From: Thebys Date: Fri, 26 May 2023 21:17:01 +0200 Subject: [PATCH 14/17] Add Czech alignment model This PR adds the following Czech alignment model: https://huggingface.co/comodoro/wav2vec2-xls-r-300m-cs-250. I have successfully tested this with several Czech audio recordings with length of up to 3 hours, and the results are satisfactory. However, I have received the following warnings and I am not sure how relevant it is: ``` Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file C:\Users\Thebys\.cache\torch\whisperx-vad-segmentation.bin` Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x. Model was trained with torch 1.10.0+cu102, yours is 2.0.0. Bad things might happen unless you revert torch to 1.x. ``` --- whisperx/alignment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 8f84ee5..34153ec 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -33,6 +33,7 @@ DEFAULT_ALIGN_MODELS_HF = { "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", + "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", From f1032bb40a4d39fa7a92cc6c2a1b071b282d8aa4 Mon Sep 17 00:00:00 2001 From: Max Bain <36994049+m-bain@users.noreply.github.com> Date: Fri, 26 May 2023 20:39:19 +0100 Subject: [PATCH 15/17] VAD unequal stack size, remove debug change --- whisperx/vad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisperx/vad.py b/whisperx/vad.py index a7a2451..15a9e5e 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -157,7 +157,7 @@ class Binarize: curr_scores = curr_scores[min_score_div_idx+1:] curr_timestamps = curr_timestamps[min_score_div_idx+1:] # switching from active to inactive - elif y <= self.offset: + elif y < self.offset: region = Segment(start - self.pad_onset, t + self.pad_offset) active[region, k] = label start = t @@ -169,7 +169,7 @@ class Binarize: # currently inactive else: # switching from inactive to active - if y >= self.onset: + if y > self.onset: start = t is_active = True From 5a47f458ac56f1a0d8549d850371ec380a7ec5dd Mon Sep 17 00:00:00 2001 From: prameshbajra Date: Sat, 27 May 2023 11:38:54 +0200 Subject: [PATCH 16/17] Added download path parameter. --- whisperx/asr.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 713531c..d0e6962 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -13,8 +13,16 @@ from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram from .vad import load_vad_model, merge_chunks from .types import TranscriptionResult, SingleSegment -def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr_options=None, language=None, - vad_options=None, model=None, task="transcribe"): +def load_model(whisper_arch, + device, + device_index=0, + compute_type="float16", + asr_options=None, + language=None, + vad_options=None, + model=None, + task="transcribe", + download_root=None): '''Load a Whisper model for inference. Args: whisper_arch: str - The name of the Whisper model to load. @@ -22,14 +30,19 @@ def load_model(whisper_arch, device, device_index=0, compute_type="float16", asr compute_type: str - The compute type to use for the model. options: dict - A dictionary of options to use for the model. language: str - The language of the model. (use English for now) + download_root: Optional[str] - The root directory to download the model to. Returns: A Whisper pipeline. - ''' + ''' if whisper_arch.endswith(".en"): language = "en" - model = WhisperModel(whisper_arch, device=device, device_index=device_index, compute_type=compute_type) + model = WhisperModel(whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root) if language is not None: tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: @@ -114,7 +127,7 @@ class WhisperModel(faster_whisper.WhisperModel): # suppress_tokens=options.suppress_tokens, # max_initial_timestamp_index=max_initial_timestamp_index, ) - + tokens_batch = [x.sequences_ids[0] for x in result] def decode_batch(tokens: List[List[int]]) -> str: @@ -127,7 +140,7 @@ class WhisperModel(faster_whisper.WhisperModel): text = decode_batch(tokens_batch) return text - + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. @@ -136,9 +149,9 @@ class WhisperModel(faster_whisper.WhisperModel): if len(features.shape) == 2: features = np.expand_dims(features, 0) features = faster_whisper.transcribe.get_ctranslate2_storage(features) - + return self.model.encode(features, to_cpu=to_cpu) - + class FasterWhisperPipeline(Pipeline): """ Huggingface Pipeline wrapper for FasterWhisperModel. @@ -176,7 +189,7 @@ class FasterWhisperPipeline(Pipeline): self.device = torch.device(f"cuda:{device}") else: self.device = device - + super(Pipeline, self).__init__() self.vad_model = vad @@ -194,7 +207,7 @@ class FasterWhisperPipeline(Pipeline): def _forward(self, model_inputs): outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) return {'text': outputs} - + def postprocess(self, model_outputs): return model_outputs @@ -218,7 +231,7 @@ class FasterWhisperPipeline(Pipeline): ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) - + def data(audio, segments): for seg in segments: f1 = int(seg['start'] * SAMPLE_RATE) From 4cbd3030cc0011fa8e20b93d03078dc353ef6fa7 Mon Sep 17 00:00:00 2001 From: Max Bain <36994049+m-bain@users.noreply.github.com> Date: Mon, 29 May 2023 12:48:14 +0100 Subject: [PATCH 17/17] no sentence split on mr. mrs. dr... --- whisperx/alignment.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 7ac3a04..17e96f4 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -15,6 +15,9 @@ from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment import nltk +from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters + +PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -143,7 +146,11 @@ def align( if any([c in model_dictionary.keys() for c in wrd]): clean_wdx.append(wdx) - sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text)) + + punkt_param = PunktParameters() + punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) + sentence_splitter = PunktSentenceTokenizer(punkt_param) + sentence_spans = list(sentence_splitter.span_tokenize(text)) segment["clean_char"] = clean_char segment["clean_cdx"] = clean_cdx