allow custom path for vad model

This commit is contained in:
Max Bain
2023-04-14 15:02:58 +01:00
parent 6a72b61564
commit cf252a8592

View File

@ -16,10 +16,11 @@ from typing import List, Tuple, Optional
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok = True)
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")