add vad model external dl

Max Bain
2023-03-30 18:57:55 +01:00
parent 18b63d46e2
commit ae4a9de307
3 changed files with 53 additions and 12 deletions

View File

@@ -1,10 +1,9 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import tempfile
import numpy as np
import torch
import tqdm
import ffmpeg
from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,

View File

@@ -2,22 +2,22 @@ import argparse
import os
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tempfile
import ffmpeg
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.audio import SAMPLE_RATE
from whisper.utils import (
    optional_float,
    optional_int,
    str2bool,
)
from .alignment import load_align_model, align
from .asr import transcribe, transcribe_with_vad
from .diarize import DiarizationPipeline
from .utils import get_writer
from .vad import load_vad_model
def cli():
@@ -74,7 +74,7 @@ def cli():
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
parser.add_argument("--model_flush", action="store_true", help="Flush memory of each stage after use, more GPU memory efficient, but slower when there are multiple audio files")
parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
# fmt: on
args = parser.parse_args().__dict__
@@ -148,8 +148,18 @@ def cli():
    for audio_path in args.pop("audio"):
        if vad_model is not None:
            tfile = None
            if not audio_path.endswith(".wav"):
                print("VAD requires .wav format, converting to wav as a tempfile...")
                tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
                ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
                vad_audio_path = tfile.name
            else:
                vad_audio_path = audio_path
            print("Performing VAD...")
            result = transcribe_with_vad(model, vad_audio_path, vad_model, temperature=temperature, **args)
            if tfile is not None:
                tfile.close()
        else:
            print("Performing transcription...")
            result = transcribe(model, audio_path, temperature=temperature, **args)
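
For reference, a minimal standalone sketch of the wav-conversion step above, assuming ffmpeg-python is installed and SAMPLE_RATE comes from whisper.audio; the helper name ensure_wav is illustrative and not part of this commit:

import tempfile
import ffmpeg
from whisper.audio import SAMPLE_RATE  # 16000 Hz, the sample rate Whisper expects

def ensure_wav(audio_path: str):
    """Return (wav_path, tempfile_or_None); convert non-wav input to a mono 16 kHz wav tempfile."""
    if audio_path.endswith(".wav"):
        return audio_path, None
    tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
    # Decode with ffmpeg to mono (ac=1) at SAMPLE_RATE (ar=16000), overwriting the empty tempfile.
    ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"], overwrite_output=True)
    return tfile.name, tfile

The caller is responsible for closing the returned tempfile (which deletes it) once VAD-based transcription has finished.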

View File

@@ -1,19 +1,51 @@
import os
import urllib.request
import pandas as pd
import numpy as np
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
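# The URL embeds the weights' SHA256 digest as its parent path segment;
# load_vad_model below re-derives it via VAD_SEGMENTATION_URL.split('/')[-2]
# to verify the downloaded file.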
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
vad_model = Model.from_pretrained("pyannote/segmentation", use_auth_token=use_auth_token)
model_dir = torch.hub._get_torch_home()
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,