mirror of https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
cut up vad segments when too long to prevent OOM

@@ -12,6 +12,7 @@ from .decoding import DecodingOptions, DecodingResult
 from .diarize import assign_word_speakers, Segment
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
 from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, interpolate_nans, write_txt, write_vtt, write_srt, write_ass, write_tsv
+from .vad import Binarize
 import pandas as pd
 
 if TYPE_CHECKING:

@@ -266,7 +267,16 @@ def merge_chunks(segments, chunk_size=CHUNK_LENGTH):
     merged_segments = []
     seg_idxs = []
     speaker_idxs = []
-    for sdx, seg in enumerate(segments):
+
+    max_duration = chunk_size // 2
+    assert max_duration > 0
+    binarize = Binarize(max_duration=chunk_size//2)
+    segments = binarize(segments)
+    segments_list = []
+    for speech_turn in segments.get_timeline():
+        segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+    for sdx, seg in enumerate(segments_list):
         if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
             merged_segments.append({
                 "start": curr_start,
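
This hunk is the heart of the fix. Previously a single VAD speech region longer than chunk_size became its own oversized chunk, and handing whisper a window much longer than 30s is what ran the encoder out of memory. Binarizing with max_duration = chunk_size // 2 guarantees that every segment entering the merge loop spans at most half a chunk, so every merged chunk stays bounded. A toy illustration of the length cap (not whisperX's actual Binarize, which works on the raw segmentation scores; the fixed-interval splitter and the 50s region below are made up):

    # Illustration only: the effect of capping segment length at
    # max_duration, using a hypothetical fixed-interval splitter.
    def split_long_segment(start, end, max_duration):
        """Cut one speech region into pieces no longer than max_duration."""
        pieces = []
        cur = start
        while end - cur > max_duration:
            pieces.append((cur, cur + max_duration))
            cur += max_duration
        pieces.append((cur, end))
        return pieces

    # A 50s speech region with chunk_size=30, so max_duration=15:
    print(split_long_segment(0.0, 50.0, 15))
    # [(0.0, 15.0), (15.0, 30.0), (30.0, 45.0), (45.0, 50.0)]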

@@ -306,12 +316,9 @@ def transcribe_with_vad(
     prev = 0
     output = {"segments": []}
 
-    vad_segments_list = []
     vad_segments = vad_pipeline(audio)
-    for speech_turn in vad_segments.get_timeline().support():
-        vad_segments_list.append(Segment(speech_turn.start, speech_turn.end, "UNKNOWN"))
     # merge segments to approx 30s inputs to make whisper most appropraite
-    vad_segments = merge_chunks(vad_segments_list)
+    vad_segments = merge_chunks(vad_segments)
 
     for sdx, seg_t in enumerate(vad_segments):
         if verbose:
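
With binarization and the Segment conversion now living inside merge_chunks, the caller shrinks to pipeline, merge, transcribe, and merge_chunks receives the raw pipeline output directly. A condensed sketch of the resulting flow (chunk keys other than "start" are assumed from how seg_t is indexed later in transcribe_with_vad):

    vad_segments = vad_pipeline(audio)         # raw per-frame speech scores
    vad_segments = merge_chunks(vad_segments)  # binarize, cap, merge to ~30s
    # e.g. [{"start": 0.0, "end": 29.4, ...}, {"start": 29.4, "end": 58.8, ...}]
    for sdx, seg_t in enumerate(vad_segments):
        ...  # transcribe one bounded chunk at a time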

@@ -411,9 +418,10 @@ def cli():
     if vad_input is not None:
         vad_input = pd.read_csv(vad_input, header=None, sep= " ")
     elif vad_filter:
-        from pyannote.audio import Pipeline
-        vad_pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection",
-                                                use_auth_token=hf_token)
+        from pyannote.audio import Inference
+        vad_pipeline = Inference("pyannote/segmentation",
+                            pre_aggregation_hook=lambda segmentation: segmentation,
+                            use_auth_token=hf_token)
 
     diarize_pipeline = None
     if diarize:
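
The CLI switch from the packaged voice-activity-detection Pipeline to raw Inference over pyannote/segmentation is what feeds Binarize the scores it needs: the identity pre_aggregation_hook stops pyannote from aggregating frame-level scores into final speech regions, so the caller gets the raw activations back. A minimal usage sketch (pyannote.audio 2.x API as shown in the diff; the file name and token are placeholders):

    from pyannote.audio import Inference

    # Identity hook: return the per-frame segmentation scores as-is
    # instead of letting Inference aggregate them into speech regions.
    vad_pipeline = Inference("pyannote/segmentation",
                             pre_aggregation_hook=lambda segmentation: segmentation,
                             use_auth_token="hf_xxx")  # placeholder token

    scores = vad_pipeline("audio.wav")  # raw scores over a sliding window
    # Binarize(max_duration=chunk_size // 2) then thresholds these scores,
    # cutting any speech region that would exceed max_duration.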