From ae4a9de3072593868d34de1c44d5f8adf82d0493 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Thu, 30 Mar 2023 18:57:55 +0100
Subject: [PATCH] add vad model external dl

---
 whisperx/asr.py        |  3 +--
 whisperx/transcribe.py | 25 ++++++++++++++++++-------
 whisperx/vad.py        | 40 ++++++++++++++++++++++++++++++++++++++---
 3 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/whisperx/asr.py b/whisperx/asr.py
index 661a28e..27213ef 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,10 +1,9 @@
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
-import tempfile
 import numpy as np
 import torch
 import tqdm
-
+import ffmpeg
 from whisper.audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 4461dfb..2c4da32 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -2,22 +2,22 @@ import argparse
 import os
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
-
 import numpy as np
 import torch
-
+import tempfile
+import ffmpeg
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
+from whisper.audio import SAMPLE_RATE
 from whisper.utils import (
     optional_float,
     optional_int,
     str2bool,
 )
-from .utils import get_writer
-
-from .asr import transcribe, transcribe_with_vad
 from .alignment import load_align_model, align
+from .asr import transcribe, transcribe_with_vad
 from .diarize import DiarizationPipeline
+from .utils import get_writer
 from .vad import load_vad_model
 
 
 def cli():
@@ -74,7 +74,7 @@ def cli():
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-    parser.add_argument("--model_flush", action="store_true", help="Flush memory of each stage after use, more GPU memory efficient, but slower when there are multiple audio files")
+    parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use; reduces GPU memory usage but slows processing of multiple audio files.")
     # fmt: on
 
     args = parser.parse_args().__dict__
@@ -148,8 +148,19 @@ def cli():
 
     for audio_path in args.pop("audio"):
         if vad_model is not None:
+            tfile = None  # set only when a temporary .wav conversion is needed
+            if not audio_path.endswith(".wav"):
+                print("VAD requires .wav format, converting to wav as a tempfile...")
+                tfile = tempfile.NamedTemporaryFile(delete=True, suffix=".wav")
+                ffmpeg.input(audio_path, threads=0).output(tfile.name, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
+                vad_audio_path = tfile.name
+            else:
+                vad_audio_path = audio_path
             print("Performing VAD...")
-            result = transcribe_with_vad(model, audio_path, vad_model, temperature=temperature, **args)
+            result = transcribe_with_vad(model, vad_audio_path, vad_model, temperature=temperature, **args)
+
+            if tfile is not None:
+                tfile.close()
         else:
             print("Performing transcription...")
             result = transcribe(model, audio_path, temperature=temperature, **args)
diff --git a/whisperx/vad.py b/whisperx/vad.py
index 69ad8ee..fc291dd 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -1,19 +1,53 @@
 import os
+import urllib.request
 import pandas as pd
 import numpy as np
 import torch
+import hashlib
+from tqdm import tqdm
 from typing import Optional, Callable, Union, Text
 from pyannote.audio.core.io import AudioFile
 from pyannote.core import Annotation, Segment, SlidingWindowFeature
 from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.audio import Model, Pipeline
+from pyannote.audio import Model
 from pyannote.audio.pipelines import VoiceActivityDetection
 from .diarize import Segment as SegmentX
-
 from typing import List, Tuple, Optional
 
+VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
+
 def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
-    vad_model = Model.from_pretrained("pyannote/segmentation", use_auth_token=use_auth_token)
+    model_dir = torch.hub._get_torch_home()
+    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
+        raise RuntimeError(f"{model_fp} exists and is not a regular file")
+
+    # download the weights on first use, streaming to disk with a progress bar
+    if not os.path.isfile(model_fp):
+        with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
+            with tqdm(
+                total=int(source.info().get("Content-Length")),
+                ncols=80,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+    # verify the weights against the SHA256 checksum embedded in the URL path
+    model_bytes = open(model_fp, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
     hyperparameters = {"onset": vad_onset, "offset": vad_offset, "min_duration_on": 0.1,