pandas fix

This commit is contained in:
Max Bain
2023-01-27 15:05:08 +00:00
parent 7f2159a953
commit 5b8c8a7bd3
2 changed files with 25 additions and 16 deletions

View File

@@ -182,7 +182,7 @@ def align(
# if no characters are in the dictionary, then we skip this segment... # if no characters are in the dictionary, then we skip this segment...
if len(clean_char) == 0: if len(clean_char) == 0:
print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...") print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
break break
transcription_cleaned = "".join(clean_char) transcription_cleaned = "".join(clean_char)
@@ -225,7 +225,7 @@ def align(
trellis = get_trellis(emission, tokens) trellis = get_trellis(emission, tokens)
path = backtrack(trellis, emission, tokens) path = backtrack(trellis, emission, tokens)
if path is None: if path is None:
print("Failed to align segment: backtrack failed, resorting to original...") print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
break break
char_segments = merge_repeats(path, transcription_cleaned) char_segments = merge_repeats(path, transcription_cleaned)
# word_segments = merge_words(char_segments) # word_segments = merge_words(char_segments)
@@ -295,24 +295,23 @@ def align(
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"]) per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"]) per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
word_segments_arr = {} word_segments_arr = {}
# start of word is first char with a timestamp # start of word is first char with a timestamp
word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"] word_segments_arr["start"] = per_word_grp["start"].min().values
# end of word is last char with a timestamp # end of word is last char with a timestamp
word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"] word_segments_arr["end"] = per_word_grp["end"].max().values
# score of word is mean (excluding nan) # score of word is mean (excluding nan)
word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"] word_segments_arr["score"] = per_word_grp["score"].mean().values
word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
word_segments_arr = pd.DataFrame(word_segments_arr) word_segments_arr = pd.DataFrame(word_segments_arr)
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
segments_arr = {} segments_arr = {}
segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"] segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"] segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
@@ -322,8 +321,8 @@ def align(
# interpolate missing words / sub-segments # interpolate missing words / sub-segments
if interpolate_method != "ignore": if interpolate_method != "ignore":
wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"]) wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
wrd_seg_grp = word_segments_arr.groupby(["segment-idx"]) wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
# we still know which word timestamps are interpolated because their score == nan # we still know which word timestamps are interpolated because their score == nan
word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
@@ -331,11 +330,19 @@ def align(
word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
sub_seg_grp = segments_arr.groupby(["segment-idx"]) sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
# merge subsegments which are missing times
# group by sub seg and time. # merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"]) seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min) segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max) segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)

View File

@@ -207,6 +207,8 @@ def write_ass(transcript: Iterator[dict],
ass_arr = [] ass_arr = []
for segment in transcript: for segment in transcript:
# if "12" in segment['text']:
# import pdb; pdb.set_trace()
if resolution_key in segment: if resolution_key in segment:
res_segs = pd.DataFrame(segment[resolution_key]) res_segs = pd.DataFrame(segment[resolution_key])
prev = segment['start'] prev = segment['start']