mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
handle non-alignable whole segments
This commit is contained in:
@ -336,13 +336,11 @@ def align(
|
||||
|
||||
# merge words & subsegments which are missing times
|
||||
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
|
||||
word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"])
|
||||
|
||||
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
|
||||
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
|
||||
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
|
||||
|
||||
|
||||
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
|
||||
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
|
||||
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
|
||||
@ -351,6 +349,13 @@ def align(
|
||||
word_segments_arr.dropna(inplace=True)
|
||||
segments_arr.dropna(inplace=True)
|
||||
|
||||
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
|
||||
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
|
||||
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
|
||||
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
|
||||
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
|
||||
|
||||
|
||||
aligned_segments = []
|
||||
aligned_segments_word = []
|
||||
|
||||
@ -360,13 +365,21 @@ def align(
|
||||
for sdx, srow in segments_arr.iterrows():
|
||||
|
||||
seg_idx = int(srow["segment-idx"])
|
||||
sub_start = int(srow["subsegment-idx-start"])
|
||||
try:
|
||||
sub_start = int(srow["subsegment-idx-start"])
|
||||
except:
|
||||
import pdb; pdb.set_trace()
|
||||
sub_end = int(srow["subsegment-idx-end"])
|
||||
|
||||
seg = transcript[seg_idx]
|
||||
text = "".join(seg["seg-text"][sub_start:sub_end])
|
||||
|
||||
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
wseg["start"].fillna(srow["start"], inplace=True)
|
||||
wseg["end"].fillna(srow["end"], inplace=True)
|
||||
wseg["segment-text-start"].fillna(0, inplace=True)
|
||||
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
|
||||
|
||||
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
|
||||
cseg['segment-text-start'] = cseg['level_1']
|
||||
cseg['segment-text-end'] = cseg['level_1'] + 1
|
||||
|
@ -369,7 +369,7 @@ def cli():
|
||||
parser.add_argument("--max_speakers", default=None, type=int)
|
||||
# output save params
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save")
|
||||
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
|
||||
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
@ -524,6 +524,10 @@ def cli():
|
||||
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
|
||||
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
|
||||
|
||||
|
||||
# save word tsv
|
||||
if output_type in ["vad"]:
|
||||
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
|
||||
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])
|
||||
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
Reference in New Issue
Block a user