diff --git a/whisperx/alignment.py b/whisperx/alignment.py index ae91828..e5d92cb 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -130,6 +130,8 @@ def align( # 1. Preprocess to keep only characters in dictionary total_segments = len(transcript) + # Store temporary processing values + segment_data = {} for sdx, segment in enumerate(transcript): # strip spaces at beginning / end, but keep track of the amount. if print_progress: @@ -174,11 +176,13 @@ def align( sentence_splitter = PunktSentenceTokenizer(punkt_param) sentence_spans = list(sentence_splitter.span_tokenize(text)) - segment["clean_char"] = clean_char - segment["clean_cdx"] = clean_cdx - segment["clean_wdx"] = clean_wdx - segment["sentence_spans"] = sentence_spans - + segment_data[sdx] = { + "clean_char": clean_char, + "clean_cdx": clean_cdx, + "clean_wdx": clean_wdx, + "sentence_spans": sentence_spans + } + aligned_segments: List[SingleAlignedSegment] = [] # 2. Get prediction matrix from alignment model & align @@ -200,7 +204,7 @@ def align( aligned_seg["chars"] = [] # check we can align - if len(segment["clean_char"]) == 0: + if len(segment_data[sdx]["clean_char"]) == 0: print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') aligned_segments.append(aligned_seg) continue @@ -210,7 +214,7 @@ def align( aligned_segments.append(aligned_seg) continue - text_clean = "".join(segment["clean_char"]) + text_clean = "".join(segment_data[sdx]["clean_char"]) tokens = [model_dictionary[c] for c in text_clean] f1 = int(t1 * SAMPLE_RATE) @@ -261,8 +265,8 @@ def align( word_idx = 0 for cdx, char in enumerate(text): start, end, score = None, None, None - if cdx in segment["clean_cdx"]: - char_seg = char_segments[segment["clean_cdx"].index(cdx)] + if cdx in segment_data[sdx]["clean_cdx"]: + char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)] start = round(char_seg.start * ratio + t1, 3) end = round(char_seg.end * ratio + t1, 3) score = round(char_seg.score, 3) @@ -288,10 +292,10 @@ def align( aligned_subsegments = [] # assign sentence_idx to each character index char_segments_arr["sentence-idx"] = None - for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]): curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] - char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx - + char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2 + sentence_text = text[sstart:send] sentence_start = curr_chars["start"].min() end_chars = curr_chars[curr_chars["char"] != ' ']