support batch processing
@@ -346,6 +346,230 @@ def transcribe_with_vad(
    return output

def transcribe_with_vad_parallel(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    vad_pipeline,
    mel=None,
    verbose: Optional[bool] = None,
    batch_size=-1,
    **kwargs
):
    """
    Transcribe per VAD segment, batching the merged segments so they can be
    decoded in parallel.
    """

    if mel is None:
        mel = log_mel_spectrogram(audio)

    output = {"segments": []}

    vad_segments = vad_pipeline(audio)
    # merge segments to approx 30s inputs to make whisper most appropriate
    vad_segments = merge_chunks(vad_segments)

    ################################
    ### START of parallelization ###
    ################################

    # pad mel chunks to the same length
    start_seconds = [i['start'] for i in vad_segments]
    end_seconds = [i['end'] for i in vad_segments]
    duration_list = np.array(end_seconds) - np.array(start_seconds)
    max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
    offset_list = np.array(start_seconds)
    chunks = []
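
    # Each merged VAD chunk is at most ~30 s, so its mel slice is right-padded
    # with zeros up to max_length (30 s worth of mel frames) and the padded
    # slices are stacked into a single batch tensor below.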

    for start_ts, end_ts in zip(start_seconds, end_seconds):
        start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
        end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
        chunk = mel[:, start_ts:end_ts]
        chunk = torch.nn.functional.pad(chunk, (0, max_length - chunk.shape[-1]))
        chunks.append(chunk)

    mel_chunk = torch.stack(chunks, dim=0).to(model.device)
    # using 'decode_options1': only supports single-temperature decoding (no fallbacks)
    # result_list2 = model.decode(mel_chunk, decode_options1)

    # prepare DecodingOptions
    temperatures = kwargs.pop("temperature", None)
    compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
    logprob_threshold = kwargs.pop("logprob_threshold", None)
    no_speech_threshold = kwargs.pop("no_speech_threshold", None)
    condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
    initial_prompt = kwargs.pop("initial_prompt", None)
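    # These transcribe-level options are popped off kwargs because DecodingOptions
    # does not accept them; whatever remains in kwargs is forwarded directly to
    # DecodingOptions below.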

    t = 0  # TODO: does not support temperature sweeping
    if t > 0:
        # disable beam_size and patience when t > 0
        kwargs.pop("beam_size", None)
        kwargs.pop("patience", None)
    else:
        # disable best_of when t == 0
        kwargs.pop("best_of", None)

    options = DecodingOptions(**kwargs, temperature=t)
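    # torch.split yields views of at most batch_size segments each, so the
    # stacked mel is decoded mini-batch by mini-batch rather than all at once.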
    mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
    decode_result = []
    for mel_chunk_batch in mel_chunk_batches:
        decode_result.extend(model.decode(mel_chunk_batch, options))

    ##############################
    ### END of parallelization ###
    ##############################

    # post processing: get segments from batch-decoded results
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    language = kwargs["language"]
    task = kwargs["task"]
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
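    # Note that `language` and `task` are read from kwargs directly here, so
    # callers of this function are expected to supply both.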
    result_segments = post_process_results(
        decode_result,
        duration_list,
        offset_list,
        input_stride,
        language,
        tokenizer,
        no_speech_threshold=no_speech_threshold,
        logprob_threshold=logprob_threshold,
        verbose=verbose)

    # post processing: collect outputs
    assert len(result_segments) == len(vad_segments)
    for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output["language"] = output["segments"][0]["language"]

    return output
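
# Illustrative call of the new function (a sketch, not part of the change; the
# file name, batch size and language are made-up example values, and `language`
# / `task` must be supplied because they are read from kwargs above):
#
#   result = transcribe_with_vad_parallel(
#       model, "audio.wav", vad_pipeline,
#       temperature=0, batch_size=8, language="en", task="transcribe",
#   )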


def post_process_results(
    result_list,
    duration_list,
    offset_list,
    input_stride,
    language,
    tokenizer,
    no_speech_threshold=None,
    logprob_threshold=None,
    verbose: Optional[bool] = None,
):
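    # result_list holds one DecodingResult per merged VAD chunk; duration_list
    # and offset_list give each chunk's length and absolute start time (in
    # seconds), used to map token timestamps back onto the original audio
    # timeline.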

    seek = 0
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    outputs = []

    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": text_tokens.tolist(),
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
    # process the output
    for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
        all_tokens = []
        all_segments = []

        # segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
        segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
        tokens = torch.tensor(result.tokens)

        if no_speech_threshold is not None:
            # no voice activity check
            should_skip = result.no_speech_prob > no_speech_threshold
            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                # don't skip if the logprob is high enough, despite the no_speech_prob
                should_skip = False

            if should_skip:
                seek += segment_shape  # fast-forward to the next segment boundary
                continue

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
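        # `consecutive` holds the index of the second token of every back-to-back
        # timestamp pair; such pairs mark boundaries between complete sub-segments
        # inside the decoded window.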

        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                add_segment(
                    start=timestamp_offset + start_timestamp_position * time_precision,
                    end=timestamp_offset + end_timestamp_position * time_precision,
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                )
                last_slice = current_slice
            last_timestamp_position = (
                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
            )
            seek += last_timestamp_position * input_stride
            all_tokens.extend(tokens[: last_slice + 1].tolist())
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
                duration = last_timestamp_position * time_precision

            add_segment(
                start=timestamp_offset,
                end=timestamp_offset + duration,
                text_tokens=tokens,
                result=result,
            )

            seek += segment_shape
            all_tokens.extend(tokens.tolist())

        outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))

    return outputs


def cli():
    from . import available_models

@@ -362,6 +586,7 @@ def cli():
    # vad params
    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering and only transcribe within detected speech regions. Produces more accurate alignment + timestamps, requires more GPU memory & compute.")
    parser.add_argument("--vad_input", default=None, type=str)
    parser.add_argument("--parallel_bs", default=-1, type=int, help="Batch size used for batched transcription; enables parallel transcribing if > 1")
    # diarization params
    parser.add_argument("--diarize", action='store_true')
    parser.add_argument("--min_speakers", default=None, type=int)
@@ -408,6 +633,7 @@ def cli():
    hf_token: str = args.pop("hf_token")
    vad_filter: bool = args.pop("vad_filter")
    vad_input: bool = args.pop("vad_input")
    parallel_bs: int = args.pop("parallel_bs")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
@@ -454,8 +680,12 @@ def cli():

    for audio_path in args.pop("audio"):
        if vad_filter:
            if parallel_bs > 1:
                print("Performing VAD and parallel transcribing ...")
                result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
            else:
                print("Performing VAD...")
                result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
        else:
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)
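
With the above wiring, batched transcription is reached from the command line by combining the VAD filter with a parallel batch size, e.g. (an illustrative invocation: the entry-point name and audio file name are assumptions, and 8 is an arbitrary batch size):

    whisperx audio.wav --vad_filter --parallel_bs 8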