support batch processing

Tengda Han
2023-02-01 19:41:20 +00:00
parent fd2a093754
commit 039af89a86
3 changed files with 234 additions and 3 deletions


@@ -346,6 +346,230 @@ def transcribe_with_vad(
    return output


def transcribe_with_vad_parallel(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    vad_pipeline,
    mel=None,
    verbose: Optional[bool] = None,
    batch_size=-1,
    **kwargs
):
    """
    Transcribe per VAD segment
    """
    if mel is None:
        mel = log_mel_spectrogram(audio)
    output = {"segments": []}

    vad_segments = vad_pipeline(audio)
    # merge segments into ~30-second inputs, the window length Whisper expects
    vad_segments = merge_chunks(vad_segments)

    ################################
    ### START of parallelization ###
    ################################
    # pad every mel chunk to the same length
    start_seconds = [i['start'] for i in vad_segments]
    end_seconds = [i['end'] for i in vad_segments]
    duration_list = np.array(end_seconds) - np.array(start_seconds)
    max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
    offset_list = np.array(start_seconds)
    chunks = []

    for start_ts, end_ts in zip(start_seconds, end_seconds):
        start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
        end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
        chunk = mel[:, start_ts:end_ts]
        chunk = torch.nn.functional.pad(chunk, (0, max_length - chunk.shape[-1]))
        chunks.append(chunk)

    mel_chunk = torch.stack(chunks, dim=0).to(model.device)
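    # Note: with Whisper's defaults HOP_LENGTH=160 and SAMPLE_RATE=16000, one mel frame covers 10 ms,
    # so max_length = round(30 / 0.01) = 3000 frames; each VAD chunk above is right-padded to a full
    # 30-second window so the chunks can be stacked into a single batch tensor.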
    # using 'decode_options1': only support single temperature decoding (no fallbacks)
    # result_list2 = model.decode(mel_chunk, decode_options1)

    # prepare DecodingOptions
    temperatures = kwargs.pop("temperature", None)
    compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
    logprob_threshold = kwargs.pop("logprob_threshold", None)
    no_speech_threshold = kwargs.pop("no_speech_threshold", None)
    condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
    initial_prompt = kwargs.pop("initial_prompt", None)

    t = 0  # TODO: does not support temperature sweeping
    if t > 0:
        # disable beam_size and patience when t > 0
        kwargs.pop("beam_size", None)
        kwargs.pop("patience", None)
    else:
        # disable best_of when t == 0
        kwargs.pop("best_of", None)

    options = DecodingOptions(**kwargs, temperature=t)
    mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
    decode_result = []
    for mel_chunk_batch in mel_chunk_batches:
        decode_result.extend(model.decode(mel_chunk_batch, options))
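    # torch.split cuts the stacked chunks into groups of `batch_size` along dim 0 (the last group may
    # be smaller), so each model.decode call is one batched forward pass rather than one pass per
    # VAD segment.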
    ##############################
    ### END of parallelization ###
    ##############################

    # post processing: get segments from batch-decoded results
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    language = kwargs["language"]
    task = kwargs["task"]
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    result_segments = post_process_results(
        decode_result,
        duration_list,
        offset_list,
        input_stride,
        language,
        tokenizer,
        no_speech_threshold=no_speech_threshold,
        logprob_threshold=logprob_threshold,
        verbose=verbose)

    # post processing: collect outputs
    assert len(result_segments) == len(vad_segments)
    for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output["language"] = output["segments"][0]["language"]
    return output
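# A minimal usage sketch (hedged; names and values below are placeholders, not part of the API).
# `language` and `task` must be supplied as keyword arguments because the function reads them back
# out of **kwargs, and batch_size must be > 0 for torch.split to work:
#
#   model = load_model("large-v2")   # a whisper.Whisper instance
#   vad_pipeline = ...               # pyannote-style VAD pipeline, as built in cli()
#   out = transcribe_with_vad_parallel(
#       model, "audio.wav", vad_pipeline,
#       language="en", task="transcribe",
#       batch_size=8,
#   )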

def post_process_results(
    result_list,
    duration_list,
    offset_list,
    input_stride,
    language,
    tokenizer,
    no_speech_threshold=None,
    logprob_threshold=None,
    verbose: Optional[bool] = None,
):
    seek = 0
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
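    # With Whisper's defaults (N_FRAMES=3000 mel frames per 30 s window, n_audio_ctx=1500),
    # input_stride = 2, so time_precision = 2 * 160 / 16000 = 0.02 s per timestamp position.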
    all_tokens = []
    all_segments = []
    outputs = []

    def add_segment(
        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
    ):
        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": text_tokens.tolist(),
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
    # process the output
    for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
        all_tokens = []
        all_segments = []
        # segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
        segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
        tokens = torch.tensor(result.tokens)

        if no_speech_threshold is not None:
            # no voice activity check
            should_skip = result.no_speech_prob > no_speech_threshold
            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                # don't skip if the logprob is high enough, despite the no_speech_prob
                should_skip = False

            if should_skip:
                seek += segment_shape  # fast-forward to the next segment boundary
                continue

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
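            # e.g. (illustrative values) a decoded window like "<|0.00|> foo <|2.40|><|2.40|> bar <|4.80|>"
            # gives two slices, yielding segments (offset + 0.00, offset + 2.40, "foo") and
            # (offset + 2.40, offset + 4.80, "bar").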
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )
                add_segment(
                    start=timestamp_offset + start_timestamp_position * time_precision,
                    end=timestamp_offset + end_timestamp_position * time_precision,
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                )
                last_slice = current_slice
            last_timestamp_position = (
                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
            )
            seek += last_timestamp_position * input_stride
            all_tokens.extend(tokens[: last_slice + 1].tolist())
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
                duration = last_timestamp_position * time_precision

            add_segment(
                start=timestamp_offset,
                end=timestamp_offset + duration,
                text_tokens=tokens,
                result=result,
            )
            seek += segment_shape
            all_tokens.extend(tokens.tolist())

        outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))

    return outputs


def cli():
    from . import available_models
@@ -362,6 +586,7 @@ def cli():
    # vad params
    parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering and only transcribe within detected speech regions. Produces more accurate alignment + timestamps, but requires more GPU memory & compute.")
    parser.add_argument("--vad_input", default=None, type=str)
    parser.add_argument("--parallel_bs", default=-1, type=int, help="Batch size for parallel (batched) transcription; enabled when > 1")
    # diarization params
    parser.add_argument("--diarize", action='store_true')
    parser.add_argument("--min_speakers", default=None, type=int)
@@ -408,6 +633,7 @@ def cli():
    hf_token: str = args.pop("hf_token")
    vad_filter: bool = args.pop("vad_filter")
    vad_input: bool = args.pop("vad_input")
    parallel_bs: int = args.pop("parallel_bs")
    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
@@ -454,8 +680,12 @@ def cli():
    for audio_path in args.pop("audio"):
        if vad_filter:
            if parallel_bs > 1:
                print("Performing VAD and parallel transcribing ...")
                result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
            else:
                print("Performing VAD...")
                result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
        else:
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)
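
A possible command-line invocation of the new batched path (a hedged sketch: the `whisperx` entry point name, audio path, model name, and batch size are placeholder assumptions, not confirmed by this commit):

    whisperx audio.wav --model large-v2 --vad_filter --parallel_bs 16 --hf_token YOUR_HF_TOKEN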