Mirror of https://github.com/m-bain/whisperX.git (synced 2025-07-01 18:17:27 -04:00)

Commit: pandas fix
@@ -182,7 +182,7 @@ def align(
         # if no characters are in the dictionary, then we skip this segment...
         if len(clean_char) == 0:
-            print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
+            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             break
 
         transcription_cleaned = "".join(clean_char)
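For context, clean_char is built just before this hunk by filtering the segment text against the alignment model's character dictionary (that is what the comment and the error message refer to). A minimal sketch of the idea; the dictionary contents and normalization below are illustrative assumptions, not the repo's verbatim code:

# Hypothetical sketch: keep only characters the CTC alignment model can emit.
model_dictionary = {"a": 0, "b": 1, "c": 2, "|": 3}  # char -> token id (assumed shape)

segment = {"text": "123"}  # digits are absent from the dictionary above
clean_char = [c for c in segment["text"].lower() if c in model_dictionary]

if len(clean_char) == 0:
    # Nothing alignable: this is exactly the fallback branch whose message
    # the commit improves by quoting the offending segment text.
    print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')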
@@ -225,7 +225,7 @@ def align(
         trellis = get_trellis(emission, tokens)
         path = backtrack(trellis, emission, tokens)
         if path is None:
-            print("Failed to align segment: backtrack failed, resorting to original...")
+            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
         char_segments = merge_repeats(path, transcription_cleaned)
         # word_segments = merge_words(char_segments)
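get_trellis and backtrack implement standard CTC forced alignment (whisperX's versions follow the torchaudio forced-alignment tutorial). backtrack can fail when no monotonic path through the emission matrix consumes every token, which is why the code above falls back to the original timestamps. A compressed, illustrative sketch of the trellis recursion, not the repo's exact code:

import torch

def get_trellis_sketch(emission, tokens, blank_id=0):
    # emission: (num_frames, num_labels) frame-wise log-probs from the CTC model
    # tokens: target token ids for the cleaned transcript
    num_frames, num_tokens = emission.size(0), len(tokens)
    trellis = torch.full((num_frames, num_tokens), -float("inf"))
    trellis[0, 0] = emission[0, tokens[0]]
    token_ids = torch.tensor(tokens)
    for t in range(1, num_frames):
        # either emit blank and stay on token j, or move from token j-1 to j
        stay = trellis[t - 1] + emission[t, blank_id]
        advance = torch.cat(
            [torch.tensor([-float("inf")]), trellis[t - 1, :-1]]
        ) + emission[t, token_ids]
        trellis[t] = torch.maximum(stay, advance)
    # trellis[t, j]: best score for emitting tokens[:j+1] within t+1 frames;
    # backtracking walks from the final cell toward (0, 0) and dead-ends
    # (returns None in whisperX) if -inf blocks every step
    return trellis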
@@ -295,24 +295,23 @@ def align(
     per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
     per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
     per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
+    char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
+    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])  # regroup
 
     word_segments_arr = {}
 
     # start of word is first char with a timestamp
-    word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
+    word_segments_arr["start"] = per_word_grp["start"].min().values
     # end of word is last char with a timestamp
-    word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
+    word_segments_arr["end"] = per_word_grp["end"].max().values
     # score of word is mean (excluding nan)
-    word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]
+    word_segments_arr["score"] = per_word_grp["score"].mean().values
 
-    word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
-    word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
-    word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
+    word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
+    word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values + 1
 
     word_segments_arr = pd.DataFrame(word_segments_arr)
-    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
+    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
 
     segments_arr = {}
     segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
     segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
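The moves from reset_index()["..."] to .values, and from the implicit level_1 index column to an explicit local-char-idx, are the substance of this hunk: the old code leaned on a level_1 column produced by an earlier reset_index(), which newer pandas no longer yields in that shape, so the commit computes each character's position within its subsegment directly with cumcount(). A toy illustration of both idioms (the frame below is invented and matches only the column names):

import pandas as pd

chars = pd.DataFrame({
    "segment-idx":    [0, 0, 0, 0, 1, 1],
    "subsegment-idx": [0, 0, 0, 0, 0, 0],
    "word-idx":       [0, 0, 1, 1, 0, 0],
    "char":           list("hiyook"),
    "start":          [0.0, 0.1, 0.3, 0.4, 1.0, 1.1],
})

# cumcount() numbers rows 0, 1, 2, ... within each (segment, subsegment) group,
# i.e. each character's offset inside its own segment text
chars["local-char-idx"] = chars.groupby(["segment-idx", "subsegment-idx"]).cumcount()

per_word = chars.groupby(["segment-idx", "subsegment-idx", "word-idx"])

# .values drops the groupby MultiIndex entirely, so the result can be stored
# in a plain dict-of-columns regardless of how reset_index() names things
starts = per_word["start"].min().values                            # array([0. , 0.3, 1. ])
text_start = per_word["local-char-idx"].min().astype(int).values   # array([0, 2, 0])
print(starts, text_start)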
@@ -322,8 +321,8 @@ def align(
 
     # interpolate missing words / sub-segments
     if interpolate_method != "ignore":
-        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
-        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
+        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
+        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
         # we still know which word timestamps are interpolated because their score == nan
         word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
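The group_keys=False additions look like the headline pandas fix: since pandas 2.0, groupby(...).apply(...) respects group_keys=True (the default) and prepends the group labels to the result's index, so writing the result back into the frame no longer aligns by index. A toy demonstration of the breakage and the fix, using ffill in place of the repo's interpolate_nans helper:

import pandas as pd

df = pd.DataFrame({"g": [0, 0, 1, 1], "x": [1.0, None, 3.0, None]})

# Default (group_keys=True on pandas >= 2.0): group labels become an extra
# index level, so the result no longer lines up with df.index.
out = df.groupby("g")["x"].apply(lambda s: s.ffill())
print(out.index)   # MultiIndex [(0, 0), (0, 1), (1, 2), (1, 3)]

# group_keys=False keeps the original index, so assignment back just works.
df["x"] = df.groupby("g", group_keys=False)["x"].apply(lambda s: s.ffill())
print(df)          # x == [1.0, 1.0, 3.0, 3.0]

The keyword has been accepted by groupby for many releases, so passing it explicitly stays compatible with older pandas as well.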
@@ -331,11 +330,19 @@ def align(
         word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
 
-        sub_seg_grp = segments_arr.groupby(["segment-idx"])
+        sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
         segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
 
-        # merge subsegments which are missing times
-        # group by sub seg and time.
+        # merge words & subsegments which are missing times
+        word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
+        word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"])
+
+        word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
+        word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
+        word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
+
         seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
         segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
         segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
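The new word-merging step is the classic transform-then-drop_duplicates pattern: spread each group's extreme values across all of its rows, then keep one row per group. A self-contained example with invented word rows (the real frames carry more columns):

import pandas as pd

# two words collapsed onto the same end timestamp, plus one distinct word
words = pd.DataFrame({
    "segment-idx":        [0, 0, 0],
    "subsegment-idx":     [0, 0, 0],
    "end":                [1.2, 1.2, 2.0],
    "segment-text-start": [0, 4, 8],
    "segment-text-end":   [3, 7, 11],
})

grp = words.groupby(["segment-idx", "subsegment-idx", "end"])
# spread the widest text span over every row of each group...
words["segment-text-start"] = grp["segment-text-start"].transform("min")
words["segment-text-end"] = grp["segment-text-end"].transform("max")
# ...then keep one row per (segment, subsegment, end): the two words sharing
# end == 1.2 collapse into a single row covering chars 0..7
words = words.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"])
print(words)

Passing the builtins min/max, as the commit does, behaves the same here as the string aggregation names "min"/"max".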
@@ -207,6 +207,8 @@ def write_ass(transcript: Iterator[dict],
     ass_arr = []
 
     for segment in transcript:
+        # if "12" in segment['text']:
+        #     import pdb; pdb.set_trace()
         if resolution_key in segment:
             res_segs = pd.DataFrame(segment[resolution_key])
             prev = segment['start']