From 27fe5023442cace6f7adb36b3edf85575598ba30 Mon Sep 17 00:00:00 2001 From: smly Date: Wed, 22 Feb 2023 02:45:13 +0900 Subject: [PATCH] Fix assertion error in batch processing --- whisperx/transcribe.py | 47 +++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3e06dc1..6a31405 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -365,8 +365,6 @@ def transcribe_with_vad_parallel( if mel is None: mel = log_mel_spectrogram(audio) - output = {"segments": []} - vad_segments = vad_pipeline(audio) # merge segments to approx 30s inputs to make whisper most appropraite vad_segments = merge_chunks(vad_segments) @@ -428,7 +426,9 @@ def transcribe_with_vad_parallel( language = kwargs["language"] task = kwargs["task"] tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) - result_segments = post_process_results( + + output = post_process_results( + vad_segments, decode_result, duration_list, offset_list, @@ -438,29 +438,11 @@ def transcribe_with_vad_parallel( no_speech_threshold=no_speech_threshold, logprob_threshold=logprob_threshold, verbose=verbose) - - # post processing: collect outputs - assert len(result_segments) == len(vad_segments) - for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)): - seg_t["text"] = result["text"] - output["segments"].append( - { - "start": seg_t["start"], - "end": seg_t["end"], - "language": result["language"], - "text": result["text"], - "seg-text": [x["text"] for x in result["segments"]], - "seg-start": [x["start"] for x in result["segments"]], - "seg-end": [x["end"] for x in result["segments"]], - } - ) - - output["language"] = output["segments"][0]["language"] - return output def post_process_results( + vad_segments, result_list, duration_list, offset_list, @@ -478,7 +460,7 @@ def post_process_results( ) # time per output token: 0.02 (seconds) all_tokens = [] all_segments = [] - outputs = [] + output = {"segments": []} def add_segment( *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult @@ -505,7 +487,7 @@ def post_process_results( print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}") # process the output - for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list): + for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list): all_tokens = [] all_segments = [] @@ -568,9 +550,22 @@ def post_process_results( seek += segment_shape all_tokens.extend(tokens.tolist()) - outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)) + result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language) + output["segments"].append( + { + "start": seg_t["start"], + "end": seg_t["end"], + "language": result["language"], + "text": result["text"], + "seg-text": [x["text"] for x in result["segments"]], + "seg-start": [x["start"] for x in result["segments"]], + "seg-end": [x["end"] for x in result["segments"]], + } + ) - return outputs + output["language"] = output["segments"][0]["language"] + + return output def cli():