From c6dbac76c87b0bbc0e507b8759732931db89d9f1 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Sat, 28 Jan 2023 00:01:39 +0000 Subject: [PATCH] cut up vad segments when too long to prevent OOM --- whisperx/transcribe.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 4acc9c5..b3cf16b 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -12,6 +12,7 @@ from .decoding import DecodingOptions, DecodingResult from .diarize import assign_word_speakers, Segment from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv +from .vad import Binarize import pandas as pd if TYPE_CHECKING: @@ -266,7 +267,16 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH): merged_segments = [] seg_idxs = [] speaker_idxs = [] - for sdx, seg in enumerate(segments): + + max_duration = chunk_size // 2 + assert max_duration > 0 + binarize = Binarize(max_duration=chunk_size//2) + segments = binarize(segments) + segments_list = [] + for speech_turn in segments.get_timeline(): + segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN")) + + for sdx, seg in enumerate(segments_list): if seg.end - curr_start > chunk_size and curr_end-curr_start > 0: merged_segments.append({ "start": curr_start, @@ -306,12 +316,9 @@ def transcribe_with_vad( prev = 0 output = {"segments": []} - vad_segments_list = [] vad_segments = vad_pipeline(audio) - for speech_turn in vad_segments.get_timeline().support(): - vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN")) # merge segments to approx 30s inputs to make whisper most appropraite - vad_segments = merge_chunks(vad_segments_list) + vad_segments = merge_chunks(vad_segments) for sdx, seg_t in enumerate(vad_segments): if verbose: @@ -411,9 +418,10 @@ def cli(): if vad_input is not None: vad_input = pd.read_csv(vad_input, header=None, sep= " ") elif vad_filter: - from pyannote.audio import Pipeline - vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection", - use_auth_token=hf_token) + from pyannote.audio import Inference + vad_pipeline = Inference("pyannote/segmentation", + pre_aggregation_hook=lambda segmentation: segmentation, + use_auth_token=hf_token) diarize_pipeline = None if diarize: