mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
Merge pull request #96 from smly/fix-batch-processing
FIX: Assertion error in batch processing
This commit is contained in:
@ -365,8 +365,6 @@ def transcribe_with_vad_parallel(
|
||||
if mel is None:
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
output = {"segments": []}
|
||||
|
||||
vad_segments = vad_pipeline(audio)
|
||||
# merge segments to approx 30s inputs to make whisper most appropraite
|
||||
vad_segments = merge_chunks(vad_segments)
|
||||
@ -428,7 +426,9 @@ def transcribe_with_vad_parallel(
|
||||
language = kwargs["language"]
|
||||
task = kwargs["task"]
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
result_segments = post_process_results(
|
||||
|
||||
output = post_process_results(
|
||||
vad_segments,
|
||||
decode_result,
|
||||
duration_list,
|
||||
offset_list,
|
||||
@ -438,29 +438,11 @@ def transcribe_with_vad_parallel(
|
||||
no_speech_threshold=no_speech_threshold,
|
||||
logprob_threshold=logprob_threshold,
|
||||
verbose=verbose)
|
||||
|
||||
# post processing: collect outputs
|
||||
assert len(result_segments) == len(vad_segments)
|
||||
for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
|
||||
seg_t["text"] = result["text"]
|
||||
output["segments"].append(
|
||||
{
|
||||
"start": seg_t["start"],
|
||||
"end": seg_t["end"],
|
||||
"language": result["language"],
|
||||
"text": result["text"],
|
||||
"seg-text": [x["text"] for x in result["segments"]],
|
||||
"seg-start": [x["start"] for x in result["segments"]],
|
||||
"seg-end": [x["end"] for x in result["segments"]],
|
||||
}
|
||||
)
|
||||
|
||||
output["language"] = output["segments"][0]["language"]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def post_process_results(
|
||||
vad_segments,
|
||||
result_list,
|
||||
duration_list,
|
||||
offset_list,
|
||||
@ -478,7 +460,7 @@ def post_process_results(
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
outputs = []
|
||||
output = {"segments": []}
|
||||
|
||||
def add_segment(
|
||||
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
|
||||
@ -505,7 +487,7 @@ def post_process_results(
|
||||
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
|
||||
|
||||
# process the output
|
||||
for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
|
||||
for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
|
||||
@ -568,9 +550,22 @@ def post_process_results(
|
||||
seek += segment_shape
|
||||
all_tokens.extend(tokens.tolist())
|
||||
|
||||
outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))
|
||||
result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
|
||||
output["segments"].append(
|
||||
{
|
||||
"start": seg_t["start"],
|
||||
"end": seg_t["end"],
|
||||
"language": result["language"],
|
||||
"text": result["text"],
|
||||
"seg-text": [x["text"] for x in result["segments"]],
|
||||
"seg-start": [x["start"] for x in result["segments"]],
|
||||
"seg-end": [x["end"] for x in result["segments"]],
|
||||
}
|
||||
)
|
||||
|
||||
return outputs
|
||||
output["language"] = output["segments"][0]["language"]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def cli():
|
||||
|
Reference in New Issue
Block a user