diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index d9b1c3e..7b06711 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -182,7 +182,7 @@ def align(
 
         # if no characters are in the dictionary, then we skip this segment...
         if len(clean_char) == 0:
-            print("Failed to align segment: no characters in this segment found in model dictionary, resorting to original...")
+            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
             break
 
         transcription_cleaned = "".join(clean_char)
@@ -225,7 +225,7 @@ def align(
         trellis = get_trellis(emission, tokens)
         path = backtrack(trellis, emission, tokens)
         if path is None:
-            print("Failed to align segment: backtrack failed, resorting to original...")
+            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
             break
         char_segments = merge_repeats(path, transcription_cleaned)
         # word_segments = merge_words(char_segments)
@@ -295,24 +295,23 @@ def align(
     per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
     per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
     per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
+    char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
+    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
 
     # regroup
     word_segments_arr = {}
     # start of word is first char with a timestamp
-    word_segments_arr["start"] = per_word_grp["start"].min().reset_index()["start"]
+    word_segments_arr["start"] = per_word_grp["start"].min().values
     # end of word is last char with a timestamp
-    word_segments_arr["end"] = per_word_grp["end"].max().reset_index()["end"]
+    word_segments_arr["end"] = per_word_grp["end"].max().values
     # score of word is mean (excluding nan)
-    word_segments_arr["score"] = per_word_grp["score"].mean().reset_index()["score"]
-
-
-    word_segments_arr["segment-text-start"] = per_word_grp["level_1"].min().reset_index()["level_1"]
-    word_segments_arr["segment-text-end"] = per_word_grp["level_1"].max().reset_index()["level_1"] + 1
-    word_segments_arr["segment-idx"] = per_word_grp["level_1"].min().reset_index()["segment-idx"]
+    word_segments_arr["score"] = per_word_grp["score"].mean().values
+    word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
+    word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
 
     word_segments_arr = pd.DataFrame(word_segments_arr)
-    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["level_1"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]]
+    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
 
     segments_arr = {}
     segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
     segments_arr["end"] = per_subseg_grp["end"].min().reset_index()["end"]
@@ -322,8 +321,8 @@ def align(
 
     # interpolate missing words / sub-segments
     if interpolate_method != "ignore":
-        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"])
-        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"])
+        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
+        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
         # we still know which word timestamps are interpolated because their score == nan
         word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
@@ -331,11 +330,19 @@ def align(
 
         word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
-        sub_seg_grp = segments_arr.groupby(["segment-idx"])
+        sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
         segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
         segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
 
-        # merge subsegments which are missing times
-        # group by sub seg and time.
+
+        # merge words & subsegments which are missing times
+        word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
+        word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"])
+
+        word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
+        word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
+        word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
+
+        seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
         segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
         segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 6f46514..86d4063 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -207,6 +207,8 @@ def write_ass(transcript: Iterator[dict],
 
     ass_arr = []
    for segment in transcript:
+        # if "12" in segment['text']:
+        #     import pdb; pdb.set_trace()
         if resolution_key in segment:
             res_segs = pd.DataFrame(segment[resolution_key])
             prev = segment['start']