From 6695426a85b47da1b6f9054a563002b406c87ca6 Mon Sep 17 00:00:00 2001
From: Max Bain
Date: Sun, 12 Jan 2025 12:50:15 +0000
Subject: [PATCH] fix new vad paths

---
 whisperx/vads/pyannote.py | 43 +++------------------------------------
 1 file changed, 3 insertions(+), 40 deletions(-)

diff --git a/whisperx/vads/pyannote.py b/whisperx/vads/pyannote.py
index 299f648..68d6a7f 100644
--- a/whisperx/vads/pyannote.py
+++ b/whisperx/vads/pyannote.py
@@ -23,12 +23,12 @@ VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weight
 
 
 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
     model_dir = torch.hub._get_torch_home()
-    vad_dir = os.path.dirname(os.path.abspath(__file__))
+    main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
     os.makedirs(model_dir, exist_ok = True)
     if model_fp is None:
         # Dynamically resolve the path to the model file
-        model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
+        model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
         model_fp = os.path.abspath(model_fp)  # Ensure the path is absolute
     else:
         model_fp = os.path.abspath(model_fp)  # Ensure any provided path is absolute
@@ -243,44 +243,7 @@ class Pyannote(Vad):
     def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
         print(">>Performing voice activity detection using Pyannote...")
         super().__init__(kwargs['vad_onset'])
-
-        model_dir = torch.hub._get_torch_home()
-        os.makedirs(model_dir, exist_ok=True)
-        if model_fp is None:
-            model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
-        if os.path.exists(model_fp) and not os.path.isfile(model_fp):
-            raise RuntimeError(f"{model_fp} exists and is not a regular file")
-
-        if not os.path.isfile(model_fp):
-            with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
-                with tqdm(
-                    total=int(source.info().get("Content-Length")),
-                    ncols=80,
-                    unit="iB",
-                    unit_scale=True,
-                    unit_divisor=1024,
-                ) as loop:
-                    while True:
-                        buffer = source.read(8192)
-                        if not buffer:
-                            break
-
-                        output.write(buffer)
-                        loop.update(len(buffer))
-
-        model_bytes = open(model_fp, "rb").read()
-        if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
-            raise RuntimeError(
-                "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
-            )
-
-        vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
-        hyperparameters = {"onset": kwargs['vad_onset'],
-                           "offset": kwargs['vad_offset'],
-                           "min_duration_on": 0.1,
-                           "min_duration_off": 0.1}
-        self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
-        self.vad_pipeline.instantiate(hyperparameters)
+        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
 
     def __call__(self, audio: AudioFile, **kwargs):
         return self.vad_pipeline(audio)
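
Note (not part of the patch): a minimal sketch of where the path resolution points before and after the main_dir change above, assuming whisperx/vads/pyannote.py is installed under a directory like <site-packages>/whisperx/vads/. The directory names in the sketch are illustrative, not taken from the patch.

    # Sketch only: illustrates the effect of replacing vad_dir with main_dir.
    import os

    # Hypothetical install location of this module (whisperx/vads/pyannote.py).
    module_path = "/site-packages/whisperx/vads/pyannote.py"

    vad_dir = os.path.dirname(os.path.abspath(module_path))   # old base: .../whisperx/vads
    main_dir = os.path.dirname(vad_dir)                        # new base: .../whisperx

    old_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")   # .../whisperx/vads/assets/pytorch_model.bin
    new_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")  # .../whisperx/assets/pytorch_model.bin
    print(old_fp, new_fp, sep="\n")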