diff --git a/README.md b/README.md
index fc7254c..26f8db0 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
New🚨
+- Batch processing: add `--vad_filter --parallel_bs [int]` to transcribe long audio files in batches (only supported with VAD filtering). Replace `[int]` with a batch size that fits your GPU memory, e.g. `--parallel_bs 16`; see the usage example below.
- VAD filtering: Voice Activity Detection (VAD) from [Pyannote.audio](https://huggingface.co/pyannote/voice-activity-detection) is used as a preprocessing step to remove reliance on Whisper's timestamps and to transcribe only the audio segments that contain speech. Add the `--vad_filter` flag; this increases timestamp accuracy and robustness (requires more GPU memory due to 30s inputs to wav2vec2).
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
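+
+A minimal example of a batched run (the `whisperx` entry point, model choice, and audio path below are illustrative placeholders; `--vad_filter` and `--parallel_bs` are the flags described above, and `--hf_token` is needed for the gated pyannote VAD model):
+
+```bash
+# transcribe a long recording in batches of 16 VAD chunks
+whisperx long_audio.wav --model large --vad_filter --parallel_bs 16 --hf_token $HF_TOKEN
+```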
diff --git a/whisperx/__init__.py b/whisperx/__init__.py
index b897f01..1d361e4 100644
--- a/whisperx/__init__.py
+++ b/whisperx/__init__.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
-from .transcribe import transcribe, transcribe_with_vad
+from .transcribe import transcribe, transcribe_with_vad, transcribe_with_vad_parallel
from .alignment import load_align_model, align
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 25fbcad..44df48a 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -346,6 +346,230 @@ def transcribe_with_vad(
return output
+def transcribe_with_vad_parallel(
+ model: "Whisper",
+ audio: Union[str, np.ndarray, torch.Tensor],
+ vad_pipeline,
+ mel = None,
+ verbose: Optional[bool] = None,
+ batch_size = -1,
+ **kwargs
+):
+ """
+    Transcribe audio per VAD segment, batching ~30s chunks for parallel decoding.
+ """
+
+ if mel is None:
+ mel = log_mel_spectrogram(audio)
+
+ output = {"segments": []}
+
+ vad_segments = vad_pipeline(audio)
+    # merge VAD segments into ~30s chunks, the input length Whisper works best with
+ vad_segments = merge_chunks(vad_segments)
+
+ ################################
+ ### START of parallelization ###
+ ################################
+
+    # pad mel chunks to the same length
+ start_seconds = [i['start'] for i in vad_segments]
+ end_seconds = [i['end'] for i in vad_segments]
+ duration_list = np.array(end_seconds) - np.array(start_seconds)
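+    # number of mel frames in Whisper's full 30s input window (HOP_LENGTH / SAMPLE_RATE seconds per frame)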
+ max_length = round(30 / (HOP_LENGTH / SAMPLE_RATE))
+ offset_list = np.array(start_seconds)
+ chunks = []
+
+ for start_ts, end_ts in zip(start_seconds, end_seconds):
+ start_ts = round(start_ts / (HOP_LENGTH / SAMPLE_RATE))
+ end_ts = round(end_ts / (HOP_LENGTH / SAMPLE_RATE))
+ chunk = mel[:, start_ts:end_ts]
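+        # zero-pad each chunk on the right up to the 30s window so the chunks can be stacked into one batch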
+ chunk = torch.nn.functional.pad(chunk, (0, max_length-chunk.shape[-1]))
+ chunks.append(chunk)
+
+ mel_chunk = torch.stack(chunks, dim=0).to(model.device)
+    # NOTE: batched decoding only supports a single temperature (no fallback decoding)
+
+    # pop transcribe-level options from kwargs so that only DecodingOptions fields remain
+ temperatures = kwargs.pop("temperature", None)
+ compression_ratio_threshold = kwargs.pop("compression_ratio_threshold", None)
+ logprob_threshold = kwargs.pop("logprob_threshold", None)
+ no_speech_threshold = kwargs.pop("no_speech_threshold", None)
+ condition_on_previous_text = kwargs.pop("condition_on_previous_text", None)
+ initial_prompt = kwargs.pop("initial_prompt", None)
+
+    t = 0  # TODO: does not support temperature sweeping (fallback decoding)
+ if t > 0:
+ # disable beam_size and patience when t > 0
+ kwargs.pop("beam_size", None)
+ kwargs.pop("patience", None)
+ else:
+ # disable best_of when t == 0
+ kwargs.pop("best_of", None)
+
+ options = DecodingOptions(**kwargs, temperature=t)
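+    # decode the stacked mel chunks in batches of batch_size, one forward pass per batch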
+ mel_chunk_batches = torch.split(mel_chunk, split_size_or_sections=batch_size)
+ decode_result = []
+ for mel_chunk_batch in mel_chunk_batches:
+ decode_result.extend(model.decode(mel_chunk_batch, options))
+
+ ##############################
+ ### END of parallelization ###
+ ##############################
+
+    # post processing: get segments from batch-decoded results
+ input_stride = exact_div(
+ N_FRAMES, model.dims.n_audio_ctx
+ ) # mel frames per output token: 2
+ language = kwargs["language"]
+ task = kwargs["task"]
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+ result_segments = post_process_results(
+ decode_result,
+ duration_list,
+ offset_list,
+ input_stride,
+ language,
+ tokenizer,
+ no_speech_threshold=no_speech_threshold,
+ logprob_threshold=logprob_threshold,
+ verbose=verbose)
+
+ # post processing: collect outputs
+ assert len(result_segments) == len(vad_segments)
+ for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
+ seg_t["text"] = result["text"]
+ output["segments"].append(
+ {
+ "start": seg_t["start"],
+ "end": seg_t["end"],
+ "language": result["language"],
+ "text": result["text"],
+ "seg-text": [x["text"] for x in result["segments"]],
+ "seg-start": [x["start"] for x in result["segments"]],
+ "seg-end": [x["end"] for x in result["segments"]],
+ }
+ )
+
+ output["language"] = output["segments"][0]["language"]
+
+ return output
+
+
+def post_process_results(
+ result_list,
+ duration_list,
+ offset_list,
+ input_stride,
+ language,
+ tokenizer,
+ no_speech_threshold = None,
+ logprob_threshold = None,
+ verbose: Optional[bool] = None,
+ ):
+
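+    # this largely mirrors the segment-parsing loop of Whisper's transcribe(), applied per batch-decoded chunk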
+ seek = 0
+ time_precision = (
+ input_stride * HOP_LENGTH / SAMPLE_RATE
+ ) # time per output token: 0.02 (seconds)
+ all_tokens = []
+ all_segments = []
+ outputs = []
+
+ def add_segment(
+ *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
+ ):
+ text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
+ if len(text.strip()) == 0: # skip empty text output
+ return
+
+ all_segments.append(
+ {
+ "id": len(all_segments),
+ "seek": seek,
+ "start": start,
+ "end": end,
+ "text": text,
+ "tokens": text_tokens.tolist(),
+ "temperature": result.temperature,
+ "avg_logprob": result.avg_logprob,
+ "compression_ratio": result.compression_ratio,
+ "no_speech_prob": result.no_speech_prob,
+ }
+ )
+ if verbose:
+ print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
+
+    # process each decoded result (one per VAD chunk)
+ for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
+ all_tokens = []
+ all_segments = []
+
+        # length of this VAD chunk in mel frames
+ segment_shape = int(segment_duration / (HOP_LENGTH / SAMPLE_RATE))
+ tokens = torch.tensor(result.tokens)
+
+ if no_speech_threshold is not None:
+ # no voice activity check
+ should_skip = result.no_speech_prob > no_speech_threshold
+ if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+ # don't skip if the logprob is high enough, despite the no_speech_prob
+ should_skip = False
+
+ if should_skip:
+ seek += segment_shape # fast-forward to the next segment boundary
+ continue
+
+ timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+ consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
+
+ if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
+ last_slice = 0
+ for current_slice in consecutive:
+ sliced_tokens = tokens[last_slice:current_slice]
+ start_timestamp_position = (
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
+ )
+ end_timestamp_position = (
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
+ )
+ add_segment(
+ start=timestamp_offset + start_timestamp_position * time_precision,
+ end=timestamp_offset + end_timestamp_position * time_precision,
+ text_tokens=sliced_tokens[1:-1],
+ result=result,
+ )
+ last_slice = current_slice
+ last_timestamp_position = (
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+ )
+ seek += last_timestamp_position * input_stride
+ all_tokens.extend(tokens[: last_slice + 1].tolist())
+ else:
+ duration = segment_duration
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+ if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+ # no consecutive timestamps but it has a timestamp; use the last one.
+ # single timestamp at the end means no speech after the last timestamp.
+ last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+ duration = last_timestamp_position * time_precision
+
+ add_segment(
+ start=timestamp_offset,
+ end=timestamp_offset + duration,
+ text_tokens=tokens,
+ result=result,
+ )
+
+ seek += segment_shape
+ all_tokens.extend(tokens.tolist())
+
+ outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))
+
+ return outputs
+
+
def cli():
from . import available_models
@@ -362,6 +586,7 @@ def cli():
# vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
parser.add_argument("--vad_input", default=None, type=str)
+    parser.add_argument("--parallel_bs", default=-1, type=int, help="Batch size for batched transcription over VAD segments; parallel decoding is enabled when > 1")
# diarization params
parser.add_argument("--diarize", action='store_true')
parser.add_argument("--min_speakers", default=None, type=int)
@@ -408,6 +633,7 @@ def cli():
hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
vad_input: bool = args.pop("vad_input")
+ parallel_bs: int = args.pop("parallel_bs")
diarize: bool = args.pop("diarize")
min_speakers: int = args.pop("min_speakers")
@@ -454,8 +680,12 @@ def cli():
for audio_path in args.pop("audio"):
if vad_filter:
- print("Performing VAD...")
- result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
+ if parallel_bs > 1:
+                print("Performing VAD and batched transcription...")
+ result = transcribe_with_vad_parallel(model, audio_path, vad_pipeline, temperature=temperature, batch_size=parallel_bs, **args)
+ else:
+ print("Performing VAD...")
+ result = transcribe_with_vad(model, audio_path, vad_pipeline, temperature=temperature, **args)
else:
print("Performing transcription...")
result = transcribe(model, audio_path, temperature=temperature, **args)