diff --git a/whisperx/alignment.py b/whisperx/alignment.py index d6241bb..c4750ca 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -1,7 +1,8 @@ -"""" +""" Forced Alignment with Whisper C. Max Bain """ + from dataclasses import dataclass from typing import Iterable, Optional, Union, List @@ -13,8 +14,13 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans -from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment -import nltk +from .types import ( + AlignedTranscriptionResult, + SingleSegment, + SingleAlignedSegment, + SingleWordSegment, + SegmentData, +) from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] @@ -131,6 +137,8 @@ def align( # 1. Preprocess to keep only characters in dictionary total_segments = len(transcript) + # Store temporary processing values + segment_data: dict[int, SegmentData] = {} for sdx, segment in enumerate(transcript): # strip spaces at beginning / end, but keep track of the amount. if print_progress: @@ -175,11 +183,13 @@ def align( sentence_splitter = PunktSentenceTokenizer(punkt_param) sentence_spans = list(sentence_splitter.span_tokenize(text)) - segment["clean_char"] = clean_char - segment["clean_cdx"] = clean_cdx - segment["clean_wdx"] = clean_wdx - segment["sentence_spans"] = sentence_spans - + segment_data[sdx] = { + "clean_char": clean_char, + "clean_cdx": clean_cdx, + "clean_wdx": clean_wdx, + "sentence_spans": sentence_spans + } + aligned_segments: List[SingleAlignedSegment] = [] # 2. Get prediction matrix from alignment model & align @@ -194,13 +204,14 @@ def align( "end": t2, "text": text, "words": [], + "chars": None, } if return_char_alignments: aligned_seg["chars"] = [] # check we can align - if len(segment["clean_char"]) == 0: + if len(segment_data[sdx]["clean_char"]) == 0: print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') aligned_segments.append(aligned_seg) continue @@ -210,7 +221,7 @@ def align( aligned_segments.append(aligned_seg) continue - text_clean = "".join(segment["clean_char"]) + text_clean = "".join(segment_data[sdx]["clean_char"]) tokens = [model_dictionary[c] for c in text_clean] f1 = int(t1 * SAMPLE_RATE) @@ -261,8 +272,8 @@ def align( word_idx = 0 for cdx, char in enumerate(text): start, end, score = None, None, None - if cdx in segment["clean_cdx"]: - char_seg = char_segments[segment["clean_cdx"].index(cdx)] + if cdx in segment_data[sdx]["clean_cdx"]: + char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)] start = round(char_seg.start * ratio + t1, 3) end = round(char_seg.end * ratio + t1, 3) score = round(char_seg.score, 3) @@ -288,10 +299,10 @@ def align( aligned_subsegments = [] # assign sentence_idx to each character index char_segments_arr["sentence-idx"] = None - for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]): curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] - char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx - + char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2 + sentence_text = text[sstart:send] sentence_start = curr_chars["start"].min() end_chars = curr_chars[curr_chars["char"] != ' '] diff --git a/whisperx/asr.py b/whisperx/asr.py index b0aa7b4..6de9490 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -1,6 +1,5 @@ import os -import warnings -from typing import List, NamedTuple, Optional, Union +from typing import List, Optional, Union from dataclasses import replace import ctranslate2 @@ -13,8 +12,8 @@ from transformers import Pipeline from transformers.pipelines.pt_utils import PipelineIterator from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram -import whisperx.vads from .types import SingleSegment, TranscriptionResult +from .vads import Vad, Silero, Pyannote def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] @@ -209,12 +208,12 @@ class FasterWhisperPipeline(Pipeline): # Pre-process audio and merge chunks as defined by the respective VAD child class # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit - if issubclass(type(self.vad_model), whisperx.vads.Vad): + if issubclass(type(self.vad_model), Vad): waveform = self.vad_model.preprocess_audio(audio) merge_chunks = self.vad_model.merge_chunks else: - waveform = whisperx.vads.Pyannote.preprocess_audio(audio) - merge_chunks = whisperx.vads.Pyannote.merge_chunks + waveform = Pyannote.preprocess_audio(audio) + merge_chunks = Pyannote.merge_chunks vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE}) vad_segments = merge_chunks( @@ -304,8 +303,8 @@ def load_model( compute_type="float16", asr_options: Optional[dict] = None, language: Optional[str] = None, - vad_model = None, - vad_method: str = "pyannote", + vad_model: Optional[Vad]= None, + vad_method: Optional[str] = "pyannote", vad_options: Optional[dict] = None, model: Optional[WhisperModel] = None, task="transcribe", @@ -318,7 +317,7 @@ def load_model( whisper_arch - The name of the Whisper model to load. device - The device to load the model on. compute_type - The compute type to use for the model. - vad_method: str - The vad method to use. vad_model has higher priority if is not None. + vad_method - The vad method to use. vad_model has higher priority if is not None. options - A dictionary of options to use for the model. language - The language of the model. (use English for now) model - The WhisperModel instance to use. @@ -398,9 +397,9 @@ def load_model( vad_model = vad_model else: if vad_method == "silero": - vad_model = whisperx.vads.Silero(**default_vad_options) + vad_model = Silero(**default_vad_options) elif vad_method == "pyannote": - vad_model = whisperx.vads.Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) + vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options) else: raise ValueError(f"Invalid vad_method: {vad_method}") diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 2a636e6..5b685b3 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -79,7 +79,7 @@ def assign_word_speakers( class Segment: - def __init__(self, start, end, speaker=None): + def __init__(self, start:int, end:int, speaker:Optional[str]=None): self.start = start self.end = end self.speaker = speaker diff --git a/whisperx/types.py b/whisperx/types.py index 68f2d78..70b10a7 100644 --- a/whisperx/types.py +++ b/whisperx/types.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional, List +from typing import TypedDict, Optional, List, Tuple class SingleWordSegment(TypedDict): @@ -30,6 +30,17 @@ class SingleSegment(TypedDict): text: str +class SegmentData(TypedDict): + """ + Temporary processing data used during alignment. + Contains cleaned and preprocessed data for each segment. + """ + clean_char: List[str] # Cleaned characters that exist in model dictionary + clean_cdx: List[int] # Original indices of cleaned characters + clean_wdx: List[int] # Indices of words containing valid characters + sentence_spans: List[Tuple[int, int]] # Start and end indices of sentences + + class SingleAlignedSegment(TypedDict): """ A single segment (up to multiple sentences) of a speech with word alignment. diff --git a/whisperx/utils.py b/whisperx/utils.py index 0b440b7..dfe3cf2 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -241,7 +241,7 @@ class SubtitlesWriter(ResultWriter): line_count = 1 # the next subtitle to yield (a list of word timings with whitespace) subtitle: list[dict] = [] - times = [] + times: list[tuple] = [] last = result["segments"][0]["start"] for segment in result["segments"]: for i, original_timing in enumerate(segment["words"]): diff --git a/whisperx/vads/vad.py b/whisperx/vads/vad.py index d96184c..d3ffbb1 100644 --- a/whisperx/vads/vad.py +++ b/whisperx/vads/vad.py @@ -26,8 +26,8 @@ class Vad: """ curr_end = 0 merged_segments = [] - seg_idxs = [] - speaker_idxs = [] + seg_idxs: list[tuple]= [] + speaker_idxs: list[Optional[str]] = [] curr_start = segments[0].start for seg in segments: