cut up vad segments when too long to prevent OOM

This commit is contained in:
Max Bain
2023-01-28 00:01:39 +00:00
parent 69673eb39b
commit c6dbac76c8

View File

@ -12,6 +12,7 @@ from .decoding import DecodingOptions, DecodingResult
from .diarize import assign_word_speakers, Segment
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
from .vad import Binarize
import pandas as pd
if TYPE_CHECKING:
@ -266,7 +267,16 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
merged_segments = []
seg_idxs = []
speaker_idxs = []
for sdx, seg in enumerate(segments):
max_duration = chunk_size // 2
assert max_duration > 0
binarize = Binarize(max_duration=chunk_size//2)
segments = binarize(segments)
segments_list = []
for speech_turn in segments.get_timeline():
segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
for sdx, seg in enumerate(segments_list):
if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
merged_segments.append({
"start": curr_start,
@ -306,12 +316,9 @@ def transcribe_with_vad(
prev = 0
output = {"segments": []}
vad_segments_list = []
vad_segments = vad_pipeline(audio)
for speech_turn in vad_segments.get_timeline().support():
vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
# merge segments to approx 30s inputs to make whisper most appropraite
vad_segments = merge_chunks(vad_segments_list)
vad_segments = merge_chunks(vad_segments)
for sdx, seg_t in enumerate(vad_segments):
if verbose:
@ -411,9 +418,10 @@ def cli():
if vad_input is not None:
vad_input = pd.read_csv(vad_input, header=None, sep= " ")
elif vad_filter:
from pyannote.audio import Pipeline
vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection",
use_auth_token=hf_token)
from pyannote.audio import Inference
vad_pipeline = Inference("pyannote/segmentation",
pre_aggregation_hook=lambda segmentation: segmentation,
use_auth_token=hf_token)
diarize_pipeline = None
if diarize: