diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 7b06711..6d2ad6f 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -336,13 +336,11 @@ def align( # merge words & subsegments which are missing times word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"]) - word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"]) word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min) word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max) word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True) - seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"]) segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min) segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max) @@ -351,6 +349,13 @@ def align( word_segments_arr.dropna(inplace=True) segments_arr.dropna(inplace=True) + # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps... + segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True) + segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True) + segments_arr['subsegment-idx-start'].fillna(0, inplace=True) + segments_arr['subsegment-idx-end'].fillna(1, inplace=True) + + aligned_segments = [] aligned_segments_word = [] @@ -360,13 +365,21 @@ def align( for sdx, srow in segments_arr.iterrows(): seg_idx = int(srow["segment-idx"]) - sub_start = int(srow["subsegment-idx-start"]) + try: + sub_start = int(srow["subsegment-idx-start"]) + except: + import pdb; pdb.set_trace() sub_end = int(srow["subsegment-idx-end"]) seg = transcript[seg_idx] text = "".join(seg["seg-text"][sub_start:sub_end]) - wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] + wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] + wseg["start"].fillna(srow["start"], inplace=True) + wseg["end"].fillna(srow["end"], inplace=True) + wseg["segment-text-start"].fillna(0, inplace=True) + wseg["segment-text-end"].fillna(len(text)-1, inplace=True) + cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] cseg['segment-text-start'] = cseg['level_1'] cseg['segment-text-end'] = cseg['level_1'] + 1 diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index b3cf16b..ff92a38 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -369,7 +369,7 @@ def cli(): parser.add_argument("--max_speakers", default=None, type=int) # output save params parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save") + parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") @@ -524,6 +524,10 @@ def cli(): exp_fp = os.path.join(output_dir, audio_basename + ".pkl") pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp) - + # save word tsv + if output_type in ["vad"]: + exp_fp = os.path.join(output_dir, audio_basename + ".sad") + wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]]) + wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False) if __name__ == "__main__": cli()