Mirror of https://github.com/m-bain/whisperX.git, synced 2025-07-01 18:17:27 -04:00
Merge pull request #96 from smly/fix-batch-processing
FIX: Assertion error in batch processing
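The batch path previously had post_process_results return a bare list of per-chunk results and then asserted len(result_segments) == len(vad_segments) back in the caller, which raised an AssertionError when the two went out of sync. This patch moves the output assembly into post_process_results itself, zipping over vad_segments directly so the lengths can never disagree and the assertion is no longer needed. Below is a minimal sketch of the reshaped flow, not the repo's code: decoding details and the extra keyword arguments of the real function are omitted, and the function name is hypothetical; the dict keys are the ones written in this diff.

# Hypothetical simplification of the patched post_process_results: the
# per-chunk output dicts are built while zipping over vad_segments, so
# the chunk count always matches the VAD segmentation.
def post_process_results_sketch(vad_segments, result_list):
    output = {"segments": []}
    for seg_t, result in zip(vad_segments, result_list):
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )
    output["language"] = output["segments"][0]["language"]
    return output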
@@ -365,8 +365,6 @@ def transcribe_with_vad_parallel(
     if mel is None:
         mel = log_mel_spectrogram(audio)
 
-    output = {"segments": []}
-
     vad_segments = vad_pipeline(audio)
     # merge segments to approx 30s inputs to make whisper most appropraite
     vad_segments = merge_chunks(vad_segments)
@@ -428,7 +426,9 @@ def transcribe_with_vad_parallel(
     language = kwargs["language"]
     task = kwargs["task"]
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
-    result_segments = post_process_results(
+    output = post_process_results(
+        vad_segments,
         decode_result,
         duration_list,
         offset_list,
@@ -438,29 +438,11 @@ def transcribe_with_vad_parallel(
         no_speech_threshold=no_speech_threshold,
         logprob_threshold=logprob_threshold,
         verbose=verbose)
 
-    # post processing: collect outputs
-    assert len(result_segments) == len(vad_segments)
-    for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
-        seg_t["text"] = result["text"]
-        output["segments"].append(
-            {
-                "start": seg_t["start"],
-                "end": seg_t["end"],
-                "language": result["language"],
-                "text": result["text"],
-                "seg-text": [x["text"] for x in result["segments"]],
-                "seg-start": [x["start"] for x in result["segments"]],
-                "seg-end": [x["end"] for x in result["segments"]],
-            }
-        )
-
-    output["language"] = output["segments"][0]["language"]
-
     return output
 
 
 def post_process_results(
+    vad_segments,
     result_list,
     duration_list,
     offset_list,
@@ -478,7 +460,7 @@ def post_process_results(
     ) # time per output token: 0.02 (seconds)
     all_tokens = []
     all_segments = []
-    outputs = []
+    output = {"segments": []}
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
@@ -505,7 +487,7 @@ def post_process_results(
             print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
 
     # process the output
-    for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
+    for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
         all_tokens = []
         all_segments = []
 
@@ -568,9 +550,22 @@ def post_process_results(
            seek += segment_shape
            all_tokens.extend(tokens.tolist())
 
-        outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))
+        result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
+        output["segments"].append(
+            {
+                "start": seg_t["start"],
+                "end": seg_t["end"],
+                "language": result["language"],
+                "text": result["text"],
+                "seg-text": [x["text"] for x in result["segments"]],
+                "seg-start": [x["start"] for x in result["segments"]],
+                "seg-end": [x["end"] for x in result["segments"]],
+            }
+        )
 
-    return outputs
+    output["language"] = output["segments"][0]["language"]
+
+    return output
 
 
 def cli():
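After this patch, transcribe_with_vad_parallel hands back the dict assembled by post_process_results unchanged, one entry per merged VAD chunk. A hedged usage sketch follows; the field names are the ones written in this diff, while the toy values stand in for a real transcription result, since the full call signature is not shown here.

# A toy stand-in for the dict the patched code returns; values are made
# up for illustration, keys match the diff above.
output = {
    "language": "en",
    "segments": [
        {"start": 0.0, "end": 29.5, "language": "en", "text": "hello world",
         "seg-text": ["hello world"], "seg-start": [0.0], "seg-end": [29.5]},
    ],
}
print(output["language"])
for seg in output["segments"]:
    print(f"[{seg['start']:.2f} --> {seg['end']:.2f}] {seg['text']}")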