fix new vad paths

This commit is contained in:
Max Bain
2025-01-12 12:50:15 +00:00
parent 7a98456321
commit 6695426a85

View File

@ -23,12 +23,12 @@ VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weight
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()
vad_dir = os.path.dirname(os.path.abspath(__file__))
main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.makedirs(model_dir, exist_ok = True)
if model_fp is None:
# Dynamically resolve the path to the model file
model_fp = os.path.join(vad_dir, "assets", "pytorch_model.bin")
model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
else:
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
@ -243,44 +243,7 @@ class Pyannote(Vad):
def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
print(">>Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
model_dir = torch.hub._get_torch_home()
os.makedirs(model_dir, exist_ok=True)
if model_fp is None:
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
raise RuntimeError(f"{model_fp} exists and is not a regular file")
if not os.path.isfile(model_fp):
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(model_fp, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
hyperparameters = {"onset": kwargs['vad_onset'],
"offset": kwargs['vad_offset'],
"min_duration_on": 0.1,
"min_duration_off": 0.1}
self.vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
self.vad_pipeline.instantiate(hyperparameters)
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
def __call__(self, audio: AudioFile, **kwargs):
return self.vad_pipeline(audio)