diff --git a/figures/pipeline.png b/figures/pipeline.png
new file mode 100644
index 0000000..232ea78
Binary files /dev/null and b/figures/pipeline.png differ
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 27213ef..ac16459 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -395,8 +395,10 @@ def transcribe_with_vad(
     # merge segments to approx 30s inputs to make whisper most appropriate
     vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
+    if len(vad_segments) == 0:
+        return output

-    print("Performing transcription...")
+    print(">>Performing transcription...")
     for sdx, seg_t in enumerate(vad_segments):
         if verbose:
             print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")

diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 04eb201..0548bfc 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -16,7 +16,7 @@ from whisper.utils import (
 )

 from .alignment import load_align_model, align
 from .asr import transcribe, transcribe_with_vad
-from .diarize import DiarizationPipeline
+from .diarize import DiarizationPipeline, assign_word_speakers
 from .utils import get_writer
 from .vad import load_vad_model
@@ -44,7 +44,7 @@ def cli():
     parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")

     # vad params
-    parser.add_argument("--vad_filter", action="store_true", help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamps; see the WhisperX paper https://arxiv.org/abs/2303.00747")
+    parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamps; see the WhisperX paper https://arxiv.org/abs/2303.00747")
     parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
     parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")

@@ -61,7 +61,7 @@ def cli():
     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")

-    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
@@ -74,7 +74,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")

     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
+    # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
     parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input is not .wav format (only for VAD).")
     # fmt: on

@@ -84,9 +84,13 @@ def cli():
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
     device: str = args.pop("device")
-    model_flush: bool = args.pop("model_flush")
+    # model_flush: bool = args.pop("model_flush")
     os.makedirs(output_dir, exist_ok=True)

+    tmp_dir: str = args.pop("tmp_dir")
+    if tmp_dir is not None:
+        os.makedirs(tmp_dir, exist_ok=True)
+
     align_model: str = args.pop("align_model")
     align_extend: float = args.pop("align_extend")
     align_from_prev: bool = args.pop("align_from_prev")
@@ -124,6 +128,11 @@ def cli():

         align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
         align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)

+    # if model_flush:
+    #     print(">>Model flushing activated... Only loading model after ASR stage")
+    #     del align_model
+    #     align_model = ""
+
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
@@ -148,34 +157,36 @@ def cli():
     writer = get_writer(output_format, output_dir)

     for audio_path in args.pop("audio"):
         input_audio_path = audio_path
+        tfile = None

         if vad_model is not None:
             if not audio_path.endswith(".wav"):
-                print("VAD requires .wav format, converting to wav as a tempfile...")
+                print(">>VAD requires .wav format, converting to wav as a tempfile...")
                 tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                 ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
                 input_audio_path = tfile.name
-            print("Performing VAD...")
+            print(">>Performing VAD...")
             result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
-
-            if tfile is not None:
-                tfile.close()
         else:
-            print("Performing transcription...")
+            print(">>Performing transcription...")
             result = transcribe(model, input_audio_path, temperature=temperature, **args)

-        if align_model is not None:
-            if result["language"] != align_metadata["language"]:
+        if align_model is not None and len(result["segments"]) > 0:
+            if result.get("language", "en") != align_metadata["language"]:
                 # load new language
                 print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                 align_model, align_metadata = load_align_model(result["language"], device)
-
+            print(">>Performing alignment...")
             result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                            extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
-        # if diarize_model is not None:
-        #     diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
-        #     results_segments, word_segments = assign_word_speakers(diarize_segments, )
+
+        if diarize_model is not None:
+            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
+            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
+
+        if tfile is not None:
+            tfile.close()

         writer(result, audio_path)

diff --git a/whisperx/utils.py b/whisperx/utils.py
index 84f03a9..992960b 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -270,6 +270,8 @@ class WriteSRTWord(ResultWriter):
             yield segment_start, segment_end, segment_text

     def write_result(self, result: dict, file: TextIO):
+        if "word_segments" not in result:
+            return
         for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
             print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)

@@ -286,11 +288,16 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
         "vtt": WriteVTT,
         "srt": WriteSRT,
         "tsv": WriteTSV,
-        # "json": WriteJSON,
         "ass": WriteASS,
+        "srt-word": WriteSRTWord,
         # "ass-char": WriteASSchar,
         # "pickle": WritePickle,
-        "srt-word": WriteSRTWord,
+        # "json": WriteJSON,
+    }
+
+    writers_other = {
+        "pkl": WritePickle,
+        "ass-char": WriteASSchar
     }

     if output_format == "all":
@@ -302,4 +309,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],

         return write_all

-    return writers[output_format](output_dir)
\ No newline at end of file
+    if output_format in writers:
+        return writers[output_format](output_dir)
+    elif output_format in writers_other:
+        return writers_other[output_format](output_dir)
+    else:
+        raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
diff --git a/whisperx/vad.py b/whisperx/vad.py
index fc291dd..aa86060 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -275,7 +275,10 @@ def merge_chunks(segments, chunk_size):

     for speech_turn in segments.get_timeline():
         segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

-    assert segments_list, "segments_list is empty."
+    if len(segments_list) == 0:
+        print("No active speech found in audio")
+        return []
+    # assert segments_list, "segments_list is empty."
     # Make sure the starting point is the start of the segment.
     curr_start = segments_list[0].start
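
Not part of the patch: a minimal sketch of how the reworked get_writer() dispatch in whisperx/utils.py behaves after this change. It assumes only what the diff shows -- the get_writer(output_format, output_dir) signature, the writer(result, audio_path) call convention, and the new ValueError fallback. The `result` dict below is a hypothetical stand-in for the output of transcribe()/align().

# Illustrative sketch only -- not part of the patch.
import os
from whisperx.utils import get_writer

os.makedirs("out", exist_ok=True)

# Hypothetical minimal result; real dicts come from transcribe()/align().
result = {"segments": [], "word_segments": []}

writer = get_writer("srt-word", "out")  # resolved from the `writers` table
writer(result, "audio.wav")             # WriteSRTWord: empty word list, so nothing is written

try:
    get_writer("bogus", "out")          # present in neither table
except ValueError as err:
    print(err)                          # message lists the formats from both tables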