mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
add vad model external dl
This commit is contained in:
@ -1,19 +1,51 @@
|
||||
import os
|
||||
import urllib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
import hashlib
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, Callable, Union, Text
|
||||
from pyannote.audio.core.io import AudioFile
|
||||
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||
from pyannote.audio.pipelines.utils import PipelineModel
|
||||
from pyannote.audio import Model, Pipeline
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||
from .diarize import Segment as SegmentX
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
# URL that load_vad_model downloads the VAD segmentation weights from when no
# local copy exists. The second-to-last path component is also used as the
# expected SHA256 hex digest of the downloaded file, so changing the URL
# changes the checksum that the download is verified against.
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||
|
||||
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
|
||||
vad_model = Model.from_pretrained("pyannote/segmentation", use_auth_token=use_auth_token)
|
||||
model_dir = torch.hub._get_torch_home()
|
||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||
|
||||
if not os.path.isfile(model_fp):
|
||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(model_fp, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
||||
hyperparameters = {"onset": vad_onset,
|
||||
"offset": vad_offset,
|
||||
"min_duration_on": 0.1,
|
||||
|
Reference in New Issue
Block a user