mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
refactor: consolidate segment data handling in alignment function
This commit is contained in:
@ -130,6 +130,8 @@ def align(
|
|||||||
|
|
||||||
# 1. Preprocess to keep only characters in dictionary
|
# 1. Preprocess to keep only characters in dictionary
|
||||||
total_segments = len(transcript)
|
total_segments = len(transcript)
|
||||||
|
# Store temporary processing values
|
||||||
|
segment_data = {}
|
||||||
for sdx, segment in enumerate(transcript):
|
for sdx, segment in enumerate(transcript):
|
||||||
# strip spaces at beginning / end, but keep track of the amount.
|
# strip spaces at beginning / end, but keep track of the amount.
|
||||||
if print_progress:
|
if print_progress:
|
||||||
@ -174,10 +176,12 @@ def align(
|
|||||||
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
sentence_splitter = PunktSentenceTokenizer(punkt_param)
|
||||||
sentence_spans = list(sentence_splitter.span_tokenize(text))
|
sentence_spans = list(sentence_splitter.span_tokenize(text))
|
||||||
|
|
||||||
segment["clean_char"] = clean_char
|
segment_data[sdx] = {
|
||||||
segment["clean_cdx"] = clean_cdx
|
"clean_char": clean_char,
|
||||||
segment["clean_wdx"] = clean_wdx
|
"clean_cdx": clean_cdx,
|
||||||
segment["sentence_spans"] = sentence_spans
|
"clean_wdx": clean_wdx,
|
||||||
|
"sentence_spans": sentence_spans
|
||||||
|
}
|
||||||
|
|
||||||
aligned_segments: List[SingleAlignedSegment] = []
|
aligned_segments: List[SingleAlignedSegment] = []
|
||||||
|
|
||||||
@ -200,7 +204,7 @@ def align(
|
|||||||
aligned_seg["chars"] = []
|
aligned_seg["chars"] = []
|
||||||
|
|
||||||
# check we can align
|
# check we can align
|
||||||
if len(segment["clean_char"]) == 0:
|
if len(segment_data[sdx]["clean_char"]) == 0:
|
||||||
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
|
||||||
aligned_segments.append(aligned_seg)
|
aligned_segments.append(aligned_seg)
|
||||||
continue
|
continue
|
||||||
@ -210,7 +214,7 @@ def align(
|
|||||||
aligned_segments.append(aligned_seg)
|
aligned_segments.append(aligned_seg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text_clean = "".join(segment["clean_char"])
|
text_clean = "".join(segment_data[sdx]["clean_char"])
|
||||||
tokens = [model_dictionary[c] for c in text_clean]
|
tokens = [model_dictionary[c] for c in text_clean]
|
||||||
|
|
||||||
f1 = int(t1 * SAMPLE_RATE)
|
f1 = int(t1 * SAMPLE_RATE)
|
||||||
@ -261,8 +265,8 @@ def align(
|
|||||||
word_idx = 0
|
word_idx = 0
|
||||||
for cdx, char in enumerate(text):
|
for cdx, char in enumerate(text):
|
||||||
start, end, score = None, None, None
|
start, end, score = None, None, None
|
||||||
if cdx in segment["clean_cdx"]:
|
if cdx in segment_data[sdx]["clean_cdx"]:
|
||||||
char_seg = char_segments[segment["clean_cdx"].index(cdx)]
|
char_seg = char_segments[segment_data[sdx]["clean_cdx"].index(cdx)]
|
||||||
start = round(char_seg.start * ratio + t1, 3)
|
start = round(char_seg.start * ratio + t1, 3)
|
||||||
end = round(char_seg.end * ratio + t1, 3)
|
end = round(char_seg.end * ratio + t1, 3)
|
||||||
score = round(char_seg.score, 3)
|
score = round(char_seg.score, 3)
|
||||||
@ -288,9 +292,9 @@ def align(
|
|||||||
aligned_subsegments = []
|
aligned_subsegments = []
|
||||||
# assign sentence_idx to each character index
|
# assign sentence_idx to each character index
|
||||||
char_segments_arr["sentence-idx"] = None
|
char_segments_arr["sentence-idx"] = None
|
||||||
for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
|
for sdx2, (sstart, send) in enumerate(segment_data[sdx]["sentence_spans"]):
|
||||||
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
|
||||||
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
|
char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx2
|
||||||
|
|
||||||
sentence_text = text[sstart:send]
|
sentence_text = text[sstart:send]
|
||||||
sentence_start = curr_chars["start"].min()
|
sentence_start = curr_chars["start"].min()
|
||||||
|
Reference in New Issue
Block a user