diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3bb1a36..1855178 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -26,7 +26,7 @@ def cli(): parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") @@ -210,4 +210,4 @@ def cli(): writer(result, audio_path, writer_args) if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/whisperx/utils.py b/whisperx/utils.py index d042bb7..36c7543 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -365,6 +365,28 @@ class WriteTSV(ResultWriter): print(round(1000 * segment["end"]), file=file, end="\t") print(segment["text"].strip().replace("\t", " "), file=file, flush=True) +class WriteAudacity(ResultWriter): + """ + Write a transcript to a text file that audacity can import as labels. + The extension used is "aud" to distinguish it from the txt file produced by WriteTXT. + Yet this is not an audacity project but only a label file! + + Please note : Audacity uses seconds in timestamps not ms! + Also there is no header expected. + + If speaker is provided it is prepended to the text between double square brackets [[]]. + """ + + extension: str = "aud" + + def write_result(self, result: dict, file: TextIO, options: dict): + ARROW = " " + for segment in result["segments"]: + print(segment["start"], file=file, end=ARROW) + print(segment["end"], file=file, end=ARROW) + print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True) + + class WriteJSON(ResultWriter): extension: str = "json" @@ -383,6 +405,9 @@ def get_writer( "tsv": WriteTSV, "json": WriteJSON, } + optional_writers = { + "aud": WriteAudacity, + } if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] @@ -393,10 +418,12 @@ def get_writer( return write_all + if output_format in optional_writers: + return optional_writers[output_format](output_dir) return writers[output_format](output_dir) def interpolate_nans(x, method='nearest'): if x.notnull().sum() > 1: return x.interpolate(method=method).ffill().bfill() else: - return x.ffill().bfill() \ No newline at end of file + return x.ffill().bfill()