diff --git a/models/pytorch_model.bin b/models/pytorch_model.bin new file mode 100644 index 0000000..75c92f0 Binary files /dev/null and b/models/pytorch_model.bin differ diff --git a/whisperx/vad.py b/whisperx/vad.py index ab2c7bb..7f9aae3 100644 --- a/whisperx/vad.py +++ b/whisperx/vad.py @@ -15,33 +15,29 @@ from tqdm import tqdm from .diarize import Segment as SegmentX +# deprecated VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): model_dir = torch.hub._get_torch_home() + + vad_dir = os.path.dirname(os.path.abspath(__file__)) + os.makedirs(model_dir, exist_ok = True) if model_fp is None: - model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin") + # Dynamically resolve the path to the model file + model_fp = os.path.join(vad_dir, "..", "models", "pytorch_model.bin") + model_fp = os.path.abspath(model_fp) # Ensure the path is absolute + else: + model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute + + # Check if the resolved model file exists + if not os.path.exists(model_fp): + raise FileNotFoundError(f"Model file not found at {model_fp}") + if os.path.exists(model_fp) and not os.path.isfile(model_fp): raise RuntimeError(f"{model_fp} exists and is not a regular file") - if not os.path.isfile(model_fp): - with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output: - with tqdm( - total=int(source.info().get("Content-Length")), - ncols=80, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - model_bytes = open(model_fp, "rb").read() if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]: raise RuntimeError(