From 73e644559d9bb6b933fe8773e07147dfdbcbb461 Mon Sep 17 00:00:00 2001
From: Barabazs <31799121+Barabazs@users.noreply.github.com>
Date: Mon, 13 Jan 2025 08:26:49 +0100
Subject: [PATCH 1/4] refactor: remove namespace for consistency

---
 whisperx/asr.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/whisperx/asr.py b/whisperx/asr.py
index 43976f2..52e7972 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -13,8 +13,8 @@ from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 
 from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-import whisperx.vads
 from .types import SingleSegment, TranscriptionResult
+from .vads import Vad, Silero, Pyannote
 
 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
@@ -209,12 +209,12 @@ class FasterWhisperPipeline(Pipeline):
         # Pre-process audio and merge chunks as defined by the respective VAD child class
         # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
-        if issubclass(type(self.vad_model), whisperx.vads.Vad):
+        if issubclass(type(self.vad_model), Vad):
             waveform = self.vad_model.preprocess_audio(audio)
             merge_chunks = self.vad_model.merge_chunks
         else:
-            waveform = whisperx.vads.Pyannote.preprocess_audio(audio)
-            merge_chunks = whisperx.vads.Pyannote.merge_chunks
+            waveform = Pyannote.preprocess_audio(audio)
+            merge_chunks = Pyannote.merge_chunks
 
         vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
@@ -398,9 +398,9 @@ def load_model(
         vad_model = vad_model
     else:
         if vad_method == "silero":
-            vad_model = whisperx.vads.Silero(**default_vad_options)
+            vad_model = Silero(**default_vad_options)
         elif vad_method == "pyannote":
-            vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
+            vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
         else:
             raise ValueError(f"Invalid vad_method: {vad_method}")

From f286e7f3de9b08838e7f83e971d841717c44aaa7 Mon Sep 17 00:00:00 2001
From: Barabazs <31799121+Barabazs@users.noreply.github.com>
Date: Mon, 13 Jan 2025 08:28:27 +0100
Subject: [PATCH 2/4] refactor: improve type hints and clean up imports

---
 whisperx/alignment.py | 4 ++--
 whisperx/asr.py       | 9 ++++-----
 whisperx/diarize.py   | 2 +-
 whisperx/utils.py     | 2 +-
 whisperx/vads/vad.py  | 4 ++--
 5 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index d6241bb..ae91828 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -1,4 +1,4 @@
-""""
+"""
 Forced Alignment with Whisper
 C. Max Bain
 """
@@ -14,7 +14,6 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
 from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
-import nltk
 from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
 
 PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -194,6 +193,7 @@ def align(
             "end": t2,
             "text": text,
             "words": [],
+            "chars": None,
         }
 
         if return_char_alignments:
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 52e7972..c71c90a 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,6 +1,5 @@
 import os
-import warnings
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 from dataclasses import replace
 
 import ctranslate2
@@ -304,8 +303,8 @@ def load_model(
     compute_type="float16",
     asr_options: Optional[dict] = None,
     language: Optional[str] = None,
-    vad_model = None,
-    vad_method = None,
+    vad_model: Optional[Vad]= None,
+    vad_method: Optional[str] = None,
     vad_options: Optional[dict] = None,
     model: Optional[WhisperModel] = None,
     task="transcribe",
@@ -318,7 +317,7 @@
         whisper_arch - The name of the Whisper model to load.
         device - The device to load the model on.
         compute_type - The compute type to use for the model.
-        vad_method: str - The vad method to use. vad_model has higher priority if is not None.
+        vad_method - The vad method to use. vad_model has higher priority if is not None.
         options - A dictionary of options to use for the model.
         language - The language of the model. (use English for now)
         model - The WhisperModel instance to use.
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 2a636e6..5b685b3 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -79,7 +79,7 @@ def assign_word_speakers(
 
 
 class Segment:
-    def __init__(self, start, end, speaker=None):
+    def __init__(self, start:int, end:int, speaker:Optional[str]=None):
         self.start = start
         self.end = end
         self.speaker = speaker
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 0b440b7..dfe3cf2 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter):
         line_count = 1
         # the next subtitle to yield (a list of word timings with whitespace)
         subtitle: list[dict] = []
-        times = []
+        times: list[tuple] = []
         last = result["segments"][0]["start"]
         for segment in result["segments"]:
             for i, original_timing in enumerate(segment["words"]):
diff --git a/whisperx/vads/vad.py b/whisperx/vads/vad.py
index d96184c..d3ffbb1 100644
--- a/whisperx/vads/vad.py
+++ b/whisperx/vads/vad.py
@@ -26,8 +26,8 @@ class Vad:
         """
         curr_end = 0
         merged_segments = []
-        seg_idxs = []
-        speaker_idxs = []
+        seg_idxs: list[tuple]= []
+        speaker_idxs: list[Optional[str]] = []
 
         curr_start = segments[0].start
         for seg in segments:

From 024bc8481b566d19aef08b7de8564aca31c260b8 Mon Sep 17 00:00:00 2001
From: Barabazs <31799121+Barabazs@users.noreply.github.com>
Date: Mon, 13 Jan 2025 09:13:30 +0100
Subject: [PATCH 3/4] refactor: consolidate segment data handling in alignment function

---
 whisperx/alignment.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index ae91828..e5d92cb 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -130,6 +130,8 @@ def align(
     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
+    # Store temporary processing values
+    segment_data = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
         if print_progress:
@@ -174,11 +176,13 @@ def align(
         sentence_splitter = PunktSentenceTokenizer(punkt_param)
         sentence_spans = list(sentence_splitter.span_tokenize(text))
 
-        segment["clean_char"] = clean_char
-        segment["clean_cdx"] = clean_cdx
-        segment["clean_wdx"] = clean_wdx
-        segment["sentence_spans"] = sentence_spans
-
+        segment_data[sdx] = {
+            "clean_char": clean_char,
+            "clean_cdx": clean_cdx,
+            "clean_wdx": clean_wdx,
+            "sentence_spans": sentence_spans
+        }
+
     aligned_segments: List[SingleAlignedSegment] = []
 
     # 2. Get prediction matrix from alignment model & align
@@ -200,7 +204,7 @@ def align(
             aligned_seg["chars"] = []
 
         # check we can align
-        if len(segment["clean_char"]) == 0:
+        if len(segment_data[sdx]["clean_char"]) == 0:
             print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             aligned_segments.append(aligned_seg)
             continue
@@ -210,7 +214,7 @@ def align(
             aligned_segments.append(aligned_seg)
             continue
 
-        text_clean = "".join(segment["clean_char"])
+        text_clean = "".join(segment_data[sdx]["clean_char"])
         tokens = [model_dictionary[c] for c in text_clean]
 
         f1 = int(t1 * SAMPLE_RATE)
@@ -261,8 +265,8 @@ def align(
         word_idx = 0
         for cdx, char in enumerate(text):
             start, end, score = None, None, None
-            if cdx in segment["clean_cdx"]:
-                char_seg = char_segments[segment["clean_cdx"].index(cdx)]
+            if cdx in segment_data[sdx]["clean_cdx"]:
+                char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
                 start = round(char_seg.start * ratio + t1, 3)
                 end = round(char_seg.end * ratio + t1, 3)
                 score = round(char_seg.score, 3)
@@ -288,10 +292,10 @@ def align(
         aligned_subsegments = []
         # assign sentence_idx to each character index
         char_segments_arr["sentence-idx"] = None
-        for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
+        for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
             curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
-            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
-
+            char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
+
             sentence_text = text[sstart:send]
             sentence_start = curr_chars["start"].min()
             end_chars = curr_chars[curr_chars["char"] != ' ']

From 2f93e029c772f15a26f4b953cbdcc1c39f41ee6a Mon Sep 17 00:00:00 2001
From: Barabazs <31799121+Barabazs@users.noreply.github.com>
Date: Mon, 13 Jan 2025 09:27:33 +0100
Subject: [PATCH 4/4] feat: add SegmentData type for temporary processing during alignment

---
 whisperx/alignment.py | 11 +++++++++--
 whisperx/types.py     | 13 ++++++++++++-
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index e5d92cb..c4750ca 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -2,6 +2,7 @@
 Forced Alignment with Whisper
 C. Max Bain
 """
+
 from dataclasses import dataclass
 from typing import Iterable, Optional, Union, List
@@ -13,7 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
 from .audio import SAMPLE_RATE, load_audio
 from .utils import interpolate_nans
-from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
+from .types import (
+    AlignedTranscriptionResult,
+    SingleSegment,
+    SingleAlignedSegment,
+    SingleWordSegment,
+    SegmentData,
+)
 from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
 
 PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
@@ -131,7 +138,7 @@ def align(
     # 1. Preprocess to keep only characters in dictionary
     total_segments = len(transcript)
     # Store temporary processing values
-    segment_data = {}
+    segment_data: dict[int, SegmentData] = {}
     for sdx, segment in enumerate(transcript):
         # strip spaces at beginning / end, but keep track of the amount.
diff --git a/whisperx/types.py b/whisperx/types.py
index 68f2d78..70b10a7 100644
--- a/whisperx/types.py
+++ b/whisperx/types.py
@@ -1,4 +1,4 @@
-from typing import TypedDict, Optional, List
+from typing import TypedDict, Optional, List, Tuple
 
 
 class SingleWordSegment(TypedDict):
@@ -30,6 +30,17 @@ class SingleSegment(TypedDict):
     text: str
 
 
+class SegmentData(TypedDict):
+    """
+    Temporary processing data used during alignment.
+    Contains cleaned and preprocessed data for each segment.
+    """
+    clean_char: List[str]  # Cleaned characters that exist in model dictionary
+    clean_cdx: List[int]  # Original indices of cleaned characters
+    clean_wdx: List[int]  # Indices of words containing valid characters
+    sentence_spans: List[Tuple[int, int]]  # Start and end indices of sentences
+
+
 class SingleAlignedSegment(TypedDict):
     """
     A single segment (up to multiple sentences) of a speech with word alignment.
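
Usage note (not part of the patch series): a minimal sketch of how the refactored entry points are exercised once these four patches are applied. The model size, audio path, device, and batch size below are illustrative values only; the functions themselves are the existing whisperx package API.

import whisperx

device = "cuda"  # illustrative; "cpu" also works
audio = whisperx.load_audio("example.wav")  # any audio file path

# After PATCH 1, load_model resolves the VAD backend from `vad_method`
# ("pyannote" or "silero") using the plain `from .vads import ...` names.
model = whisperx.load_model("large-v2", device, compute_type="float16", vad_method="silero")
result = model.transcribe(audio, batch_size=16)

# After PATCH 3/4, align() keeps its per-segment preprocessing in a separate
# `segment_data: dict[int, SegmentData]` mapping instead of writing those keys
# back into the transcript segments, so the returned segments stay SingleSegment-shaped.
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], align_model, metadata, audio, device)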