handle non-alignable whole segments

This commit is contained in:
Max Bain
2023-01-28 13:53:03 +00:00
parent 8081ef2dcd
commit c19cf407d8
2 changed files with 23 additions and 6 deletions

View File

@ -369,7 +369,7 @@ def cli():
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle"], help="File type for desired output save")
parser.add_argument("--output_type", default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="File type for desired output save")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@ -524,6 +524,10 @@ def cli():
exp_fp = os.path.join(output_dir, audio_basename + ".pkl")
pd.DataFrame(result_aligned["segments"]).to_pickle(exp_fp)
# save word tsv
if output_type in ["vad"]:
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
if __name__ == "__main__":
cli()