diff --git a/README.md b/README.md
index fc7254c..26f8db0 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
 
 New🚨
 
+- Batch processing: Add `--vad_filter --parallel_bs [int]` for transcribing long audio files in batches (only supported with VAD filtering). Replace `[int]` with a batch size that fits your GPU memory, e.g. `--parallel_bs 16`.
 - VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on whisper timestamps and only transcribe audio segments containing speech. add `--vad_filter` flag, increases timestamp accuracy and robustness (requires more GPU mem due to 30s inputs in wav2vec2)
 - Character level timestamps (see `*.char.ass` file output)
 - Diarization (still in beta, add `--diarize`)
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index b897f01..1d361e4 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, transcribe_with_vad
+from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
 from .alignment import load_align_model, align
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 25fbcad..44df48a 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -346,6 +346,230 @@ def transcribe_with_vad(
     return output
 
 
+def transcribe_with_vad_parallel(
+    model: "Whisper",
+    audio: Union[str, np.ndarray, torch.Tensor],
+    vad_pipeline,
+    mel = None,
+    verbose: Optional[bool] = None,
+    batch_size = -1,
+    **kwargs
+):
+    """
+    Transcribe per VAD segment
+    """
+
+    if mel is None:
+        mel = log_mel_spectrogram(audio)
+
+    output = {"segments": []}
+
+    vad_segments = vad_pipeline(audio)
+    # merge segments into approx 30s inputs to best match whisper's expected input length
+    vad_segments = merge_chunks(vad_segments)
+
+    ################################
+    ### START of parallelization ###
+    ################################
+
+    # pad mel chunks to the same length
+    start_seconds = [i['start'] for i in vad_segments]
+    end_seconds = [i['end'] for i in vad_segments]
+    duration_list = np.array(end_seconds) - np.array(start_seconds)
+    max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
+    offset_list = np.array(start_seconds)
+    chunks = []
+
+    for start_ts, end_ts in zip(start_seconds, end_seconds):
+        start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
+        end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
+        chunk = mel[:, start_ts:end_ts]
+        chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
+        chunks.append(chunk)
+
+    mel_chunk = torch.stack(chunks, dim=0).to(model.device)
+    # NOTE: only single-temperature decoding is supported here (no fallback decoding)
+    # result_list2 = model.decode(mel_chunk, decode_options1)
+
+    # prepare DecodingOptions
+    temperatures = kwargs.pop("temperature", None)
+    compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
+    logprob_threshold = kwargs.pop("logprob_threshold", None)
+    no_speech_threshold = kwargs.pop("no_speech_threshold", None)
+    condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
+    initial_prompt = kwargs.pop("initial_prompt", None)
+
+    t = 0  # TODO: does not support temperature sweeping
+    if t > 0:
+        # disable beam_size and patience when t > 0
+        kwargs.pop("beam_size", None)
+        kwargs.pop("patience", None)
+    else:
+        # disable best_of when t == 0
+        kwargs.pop("best_of", None)
+
+    options = DecodingOptions(**kwargs, temperature=t)
+    mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
+    decode_result = []
+    for mel_chunk_batch in mel_chunk_batches:
+        decode_result.extend(model.decode(mel_chunk_batch, options))
+
+    ##############################
+    ### END of parallelization ###
+    ##############################
+
+    # post processing: get segments from batch-decoded results
+    input_stride = exact_div(
+        N_FRAMES, model.dims.n_audio_ctx
+    )  # mel frames per output token: 2
+    language = kwargs["language"]
+    task = kwargs["task"]
+    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+    result_segments = post_process_results(
+        decode_result,
+        duration_list,
+        offset_list,
+        input_stride,
+        language,
+        tokenizer,
+        no_speech_threshold=no_speech_threshold,
+        logprob_threshold=logprob_threshold,
+        verbose=verbose)
+
+    # post processing: collect outputs
+    assert len(result_segments) == len(vad_segments)
+    for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
+        seg_t["text"] = result["text"]
+        output["segments"].append(
+            {
+                "start": seg_t["start"],
+                "end": seg_t["end"],
+                "language": result["language"],
+                "text": result["text"],
+                "seg-text": [x["text"] for x in result["segments"]],
+                "seg-start": [x["start"] for x in result["segments"]],
+                "seg-end": [x["end"] for x in result["segments"]],
+            }
+        )
+
+    output["language"] = output["segments"][0]["language"]
+
+    return output
+
+
+def post_process_results(
+    result_list,
+    duration_list,
+    offset_list,
+    input_stride,
+    language,
+    tokenizer,
+    no_speech_threshold = None,
+    logprob_threshold = None,
+    verbose: Optional[bool] = None,
+    ):
+
+    seek = 0
+    time_precision = (
+        input_stride * HOP_LENGTH / SAMPLE_RATE
+    )  # time per output token: 0.02 (seconds)
+    all_tokens = []
+    all_segments = []
+    outputs = []
+
+    def add_segment(
+        *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+    ):
+        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
+        if len(text.strip()) == 0:  # skip empty text output
+            return
+
+        all_segments.append(
+            {
+                "id": len(all_segments),
+                "seek": seek,
+                "start": start,
+                "end": end,
+                "text": text,
+                "tokens": text_tokens.tolist(),
+                "temperature": result.temperature,
+                "avg_logprob": result.avg_logprob,
+                "compression_ratio": result.compression_ratio,
+                "no_speech_prob": result.no_speech_prob,
+            }
+        )
+        if verbose:
+            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
+
+    # process the output
+    for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
+        all_tokens = []
+        all_segments = []
+
+        # segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+        segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
+        tokens = torch.tensor(result.tokens)
+
+        if no_speech_threshold is not None:
+            # no voice activity check
+            should_skip = result.no_speech_prob > no_speech_threshold
+            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                # don't skip if the logprob is high enough, despite the no_speech_prob
+                should_skip = False
+
+            if should_skip:
+                seek += segment_shape  # fast-forward to the next segment boundary
+                continue
+
+        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+
+        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+            last_slice = 0
+            for current_slice in consecutive:
+                sliced_tokens = tokens[last_slice:current_slice]
+                start_timestamp_position = (
+                    sliced_tokens[0].item() - tokenizer.timestamp_begin
+                )
+                end_timestamp_position = (
+                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                )
+                add_segment(
+                    start=timestamp_offset + start_timestamp_position * time_precision,
+                    end=timestamp_offset + end_timestamp_position * time_precision,
+                    text_tokens=sliced_tokens[1:-1],
+                    result=result,
+                )
+                last_slice = current_slice
+            last_timestamp_position = (
+                tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+            )
+            seek += last_timestamp_position * input_stride
+            all_tokens.extend(tokens[: last_slice + 1].tolist())
+        else:
+            duration = segment_duration
+            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+            if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                # no consecutive timestamps but it has a timestamp; use the last one.
+                # single timestamp at the end means no speech after the last timestamp.
+                last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                duration = last_timestamp_position * time_precision
+
+            add_segment(
+                start=timestamp_offset,
+                end=timestamp_offset + duration,
+                text_tokens=tokens,
+                result=result,
+            )
+
+            seek += segment_shape
+            all_tokens.extend(tokens.tolist())
+
+        outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))
+
+    return outputs
+
+
 def cli():
     from . import available_models
@@ -362,6 +586,7 @@ def cli():
     # vad params
     parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
     parser.add_argument("--vad_input", default=None, type=str)
+    parser.add_argument("--parallel_bs", default=-1, type=int, help="Batch size for parallel transcription; enables batched transcribing when > 1")
     # diarization params
     parser.add_argument("--diarize", action='store_true')
     parser.add_argument("--min_speakers", default=None, type=int)
@@ -408,6 +633,7 @@ def cli():
     hf_token: str = args.pop("hf_token")
     vad_filter: bool = args.pop("vad_filter")
     vad_input: bool = args.pop("vad_input")
+    parallel_bs: int = args.pop("parallel_bs")
 
     diarize: bool = args.pop("diarize")
     min_speakers: int = args.pop("min_speakers")
@@ -454,8 +680,12 @@ def cli():
     for audio_path in args.pop("audio"):
         if vad_filter:
-            print("Performing VAD...")
-            result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
+            if parallel_bs > 1:
+                print("Performing VAD and parallel transcribing ...")
+                result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
+            else:
+                print("Performing VAD...")
+                result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
         else:
             print("Performing transcription...")
             result = transcribe(model, audio_path, temperature=temperature, **args)
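Usage note: on the CLI, the batched path is only taken when `--vad_filter` is set together with `--parallel_bs` greater than 1. Below is a minimal sketch of calling the new entry point directly from Python instead; only the `transcribe_with_vad_parallel` signature and its export in `whisperx/__init__.py` come from this patch, while the model loader, the pyannote pipeline construction, the audio path, and the token are assumptions/placeholders.

```python
import whisperx
from pyannote.audio import Pipeline  # assumed way to build the VAD pipeline

# assumed loaders; only transcribe_with_vad_parallel(...) is introduced by this patch
model = whisperx.load_model("large", device="cuda")
vad_pipeline = Pipeline.from_pretrained(
    "pyannote/voice-activity-detection", use_auth_token="YOUR_HF_TOKEN"  # placeholder token
)

result = whisperx.transcribe_with_vad_parallel(
    model,
    "audio.wav",        # placeholder path
    vad_pipeline,
    batch_size=16,      # same role as --parallel_bs on the CLI
    temperature=0,      # single-temperature decoding only, no fallback sweeping
    language="en",
    task="transcribe",
)
print(result["segments"][0]["text"])
```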
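The heart of the batching is the seconds-to-mel-frame conversion and right-padding inside `transcribe_with_vad_parallel`: each merged VAD segment is sliced out of the mel spectrogram and padded to the 30-second frame count so all chunks can be stacked into one tensor and split by `batch_size`. A standalone sketch of that arithmetic, assuming whisper's usual constants (HOP_LENGTH = 160, SAMPLE_RATE = 16000, i.e. 100 frames per second, so `max_length` works out to 3000); the toy spectrogram and segment times are made up for illustration.

```python
import torch

HOP_LENGTH, SAMPLE_RATE = 160, 16000                  # whisper's standard constants
frames_per_second = SAMPLE_RATE / HOP_LENGTH          # 100 mel frames per second of audio
max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))   # 3000 frames = one 30 s window

# toy mel spectrogram: 80 mel bins x 60 s of audio (values do not matter here)
mel = torch.randn(80, 60 * int(frames_per_second))

# one VAD segment from 12.3 s to 29.8 s, converted from seconds to frame indices
start_ts = round(12.3 * frames_per_second)            # 1230
end_ts = round(29.8 * frames_per_second)              # 2980
chunk = mel[:, start_ts:end_ts]                       # shape (80, 1750)
chunk = torch.nn.functional.pad(chunk, (0, max_length - chunk.shape[-1]))
assert chunk.shape == (80, 3000)                      # every chunk shares this shape after padding

# padded chunks can then be stacked and split into GPU-sized batches
batch = torch.stack([chunk, chunk], dim=0)            # (2, 80, 3000)
batches = torch.split(batch, split_size_or_sections=2)
```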