handle non-alignable whole segments

This commit is contained in:
Max Bain
2023-01-28 13:53:03 +00:00
parent 8081ef2dcd
commit c19cf407d8
2 changed files with 23 additions and 6 deletions

View File

@ -336,13 +336,11 @@ def align(
# merge words & subsegments which are missing times
word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
word_seg_grp = word_segments_arr.groupby(["segment-idx", "end"])
word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
@ -351,6 +349,13 @@ def align(
word_segments_arr.dropna(inplace=True)
segments_arr.dropna(inplace=True)
# if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
aligned_segments = []
aligned_segments_word = []
@ -360,13 +365,21 @@ def align(
for sdx, srow in segments_arr.iterrows():
seg_idx = int(srow["segment-idx"])
sub_start = int(srow["subsegment-idx-start"])
try:
sub_start = int(srow["subsegment-idx-start"])
except:
import pdb; pdb.set_trace()
sub_end = int(srow["subsegment-idx-end"])
seg = transcript[seg_idx]
text = "".join(seg["seg-text"][sub_start:sub_end])
wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
wseg["start"].fillna(srow["start"], inplace=True)
wseg["end"].fillna(srow["end"], inplace=True)
wseg["segment-text-start"].fillna(0, inplace=True)
wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
cseg['segment-text-start'] = cseg['level_1']
cseg['segment-text-end'] = cseg['level_1'] + 1

View File

@ -369,7 +369,7 @@ def cli():
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -524,6 +524,10 @@ def cli():
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
# save word tsv
if output_type in ["vad"]:
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
if __name__ == "__main__":
cli()