mirror of
https://github.com/m-bain/whisperX.git
synced 2025-07-01 18:17:27 -04:00
BIN
models/pytorch_model.bin
Normal file
BIN
models/pytorch_model.bin
Normal file
Binary file not shown.
@ -15,33 +15,29 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from .diarize import Segment as SegmentX
|
from .diarize import Segment as SegmentX
|
||||||
|
|
||||||
|
# deprecated
|
||||||
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
||||||
|
|
||||||
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
|
||||||
model_dir = torch.hub._get_torch_home()
|
model_dir = torch.hub._get_torch_home()
|
||||||
|
|
||||||
|
vad_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok = True)
|
os.makedirs(model_dir, exist_ok = True)
|
||||||
if model_fp is None:
|
if model_fp is None:
|
||||||
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
# Dynamically resolve the path to the model file
|
||||||
|
model_fp = os.path.join(vad_dir, "..", "models", "pytorch_model.bin")
|
||||||
|
model_fp = os.path.abspath(model_fp) # Ensure the path is absolute
|
||||||
|
else:
|
||||||
|
model_fp = os.path.abspath(model_fp) # Ensure any provided path is absolute
|
||||||
|
|
||||||
|
# Check if the resolved model file exists
|
||||||
|
if not os.path.exists(model_fp):
|
||||||
|
raise FileNotFoundError(f"Model file not found at {model_fp}")
|
||||||
|
|
||||||
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
||||||
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
||||||
|
|
||||||
if not os.path.isfile(model_fp):
|
|
||||||
with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
|
|
||||||
with tqdm(
|
|
||||||
total=int(source.info().get("Content-Length")),
|
|
||||||
ncols=80,
|
|
||||||
unit="iB",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1024,
|
|
||||||
) as loop:
|
|
||||||
while True:
|
|
||||||
buffer = source.read(8192)
|
|
||||||
if not buffer:
|
|
||||||
break
|
|
||||||
|
|
||||||
output.write(buffer)
|
|
||||||
loop.update(len(buffer))
|
|
||||||
|
|
||||||
model_bytes = open(model_fp, "rb").read()
|
model_bytes = open(model_fp, "rb").read()
|
||||||
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
Reference in New Issue
Block a user