refactor: consolidate segment data handling in alignment function

This commit is contained in:
Barabazs
2025-01-13 09:13:30 +01:00
parent f286e7f3de
commit 024bc8481b

View File

@ -130,6 +130,8 @@ def align(
# 1. Preprocess to keep only characters in dictionary # 1. Preprocess to keep only characters in dictionary
total_segments = len(transcript) total_segments = len(transcript)
# Store temporary processing values
segment_data = {}
for sdx, segment in enumerate(transcript): for sdx, segment in enumerate(transcript):
# strip spaces at beginning / end, but keep track of the amount. # strip spaces at beginning / end, but keep track of the amount.
if print_progress: if print_progress:
@ -174,10 +176,12 @@ def align(
sentence_splitter = PunktSentenceTokenizer(punkt_param) sentence_splitter = PunktSentenceTokenizer(punkt_param)
sentence_spans = list(sentence_splitter.span_tokenize(text)) sentence_spans = list(sentence_splitter.span_tokenize(text))
segment["clean_char"] = clean_char segment_data[sdx] = {
segment["clean_cdx"] = clean_cdx "clean_char": clean_char,
segment["clean_wdx"] = clean_wdx "clean_cdx": clean_cdx,
segment["sentence_spans"] = sentence_spans "clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}
aligned_segments: List[SingleAlignedSegment] = [] aligned_segments: List[SingleAlignedSegment] = []
@ -200,7 +204,7 @@ def align(
aligned_seg["chars"] = [] aligned_seg["chars"] = []
# check we can align # check we can align
if len(segment["clean_char"]) == 0: if len(segment_data[sdx]["clean_char"]) == 0:
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
aligned_segments.append(aligned_seg) aligned_segments.append(aligned_seg)
continue continue
@ -210,7 +214,7 @@ def align(
aligned_segments.append(aligned_seg) aligned_segments.append(aligned_seg)
continue continue
text_clean = "".join(segment["clean_char"]) text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary[c] for c in text_clean] tokens = [model_dictionary[c] for c in text_clean]
f1 = int(t1 * SAMPLE_RATE) f1 = int(t1 * SAMPLE_RATE)
@ -261,8 +265,8 @@ def align(
word_idx = 0 word_idx = 0
for cdx, char in enumerate(text): for cdx, char in enumerate(text):
start, end, score = None, None, None start, end, score = None, None, None
if cdx in segment["clean_cdx"]: if cdx in segment_data[sdx]["clean_cdx"]:
char_seg = char_segments[segment["clean_cdx"].index(cdx)] char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
start = round(char_seg.start * ratio + t1, 3) start = round(char_seg.start * ratio + t1, 3)
end = round(char_seg.end * ratio + t1, 3) end = round(char_seg.end * ratio + t1, 3)
score = round(char_seg.score, 3) score = round(char_seg.score, 3)
@ -288,9 +292,9 @@ def align(
aligned_subsegments = [] aligned_subsegments = []
# assign sentence_idx to each character index # assign sentence_idx to each character index
char_segments_arr["sentence-idx"] = None char_segments_arr["sentence-idx"] = None
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
sentence_text = text[sstart:send] sentence_text = text[sstart:send]
sentence_start = curr_chars["start"].min() sentence_start = curr_chars["start"].min()