Cut up VAD segments when too long to prevent OOM

Max Bain
2023-01-28 00:01:39 +00:00
parent 69673eb39b
commit c6dbac76c8


@@ -12,6 +12,7 @@ from .decoding import DecodingOptions, DecodingResult
 from .diarize import assign_word_speakers, Segment
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
+from .vad import Binarize
 import pandas as pd

 if TYPE_CHECKING:
@@ -266,7 +267,16 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
     merged_segments = []
     seg_idxs = []
     speaker_idxs = []
-    for sdx, seg in enumerate(segments):
+    max_duration = chunk_size // 2
+    assert max_duration > 0
+    binarize = Binarize(max_duration=chunk_size//2)
+    segments = binarize(segments)
+    segments_list = []
+    for speech_turn in segments.get_timeline():
+        segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+    for sdx, seg in enumerate(segments_list):
         if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
             merged_segments.append({
                 "start": curr_start,
@@ -306,12 +316,9 @@ def transcribe_with_vad(
     prev = 0
     output = {"segments": []}
-    vad_segments_list = []
     vad_segments = vad_pipeline(audio)
-    for speech_turn in vad_segments.get_timeline().support():
-        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
     # merge segments to approx 30s inputs to make whisper most appropriate
-    vad_segments = merge_chunks(vad_segments_list)
+    vad_segments = merge_chunks(vad_segments)
     for sdx, seg_t in enumerate(vad_segments):
         if verbose:
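
For context, merge_chunks now receives the raw VAD output and greedily packs the binarized segments into ~30s windows (CHUNK_LENGTH), the input size Whisper expects. A simplified sketch of that packing loop, assuming plain (start, end) tuples already capped at chunk_size // 2:

    def merge_chunks_sketch(segments, chunk_size=30):
        merged = []
        curr_start, curr_end = segments[0]
        for start, end in segments[1:]:
            # close the current chunk once adding this turn would exceed ~30s
            if end - curr_start > chunk_size:
                merged.append({"start": curr_start, "end": curr_end})
                curr_start = start
            curr_end = end
        merged.append({"start": curr_start, "end": curr_end})
        return merged

    # short turns pack into one chunk; the 40-55s turn opens a second chunk
    print(merge_chunks_sketch([(0, 5), (6, 12), (14, 28), (40, 55)]))
    # -> [{'start': 0, 'end': 28}, {'start': 40, 'end': 55}]

Because Binarize has already capped every segment at chunk_size // 2, no single speech turn can exceed a whole chunk on its own, which is what makes this greedy merge safe.
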
@@ -411,9 +418,10 @@ def cli():
     if vad_input is not None:
         vad_input = pd.read_csv(vad_input, header=None, sep= " ")
     elif vad_filter:
-        from pyannote.audio import Pipeline
-        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection",
-                                                use_auth_token=hf_token)
+        from pyannote.audio import Inference
+        vad_pipeline = Inference("pyannote/segmentation",
+                                 pre_aggregation_hook=lambda segmentation: segmentation,
+                                 use_auth_token=hf_token)

     diarize_pipeline = None
     if diarize:
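
For reference, the VAD front end switches from the packaged voice-activity-detection Pipeline to a raw Inference over pyannote/segmentation; passing a pre_aggregation_hook makes Inference skip aggregation and return frame-level scores, which the new Binarize-based merge_chunks consumes directly. A hedged usage sketch, assuming pyannote.audio 2.x (the token value is a placeholder):

    from pyannote.audio import Inference

    vad_pipeline = Inference(
        "pyannote/segmentation",
        # identity hook: keep the raw per-window segmentation scores
        pre_aggregation_hook=lambda segmentation: segmentation,
        use_auth_token="hf_xxx",  # placeholder; pass your Hugging Face token
    )

    # returns a SlidingWindowFeature of frame-level speech scores
    scores = vad_pipeline("audio.wav")
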