add vad model external dl

This commit is contained in:
Max Bain
2023-03-30 18:57:55 +01:00
parent 18b63d46e2
commit ae4a9de307
3 changed files with 53 additions and 12 deletions

View File

@ -1,19 +1,51 @@
import os
import urllib
import pandas as pd
import numpy as np
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model, Pipeline
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
vad_model = Model.from_pretrained("pyannote/segmentation", use_auth_token=use_auth_token)
model_dir = torch.hub._get_torch_home()
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,