From da458863d7ee608385d1ac4de843d7755a6865a5 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Fri, 14 Apr 2023 21:40:36 +0100 Subject: [PATCH] allow custom model_dir for torchaudio models --- whisperx/alignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 783a540..c15310b 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -40,7 +40,7 @@ DEFAULT_ALIGN_MODELS_HF = { } -def load_align_model(language_code, device, model_name=None): +def load_align_model(language_code, device, model_name=None, model_dir=None): if model_name is None: # use default model if language_code in DEFAULT_ALIGN_MODELS_TORCH: @@ -55,7 +55,7 @@ def load_align_model(language_code, device, model_name=None): if model_name in torchaudio.pipelines.__all__: pipeline_type = "torchaudio" bundle = torchaudio.pipelines.__dict__[model_name] - align_model = bundle.get_model().to(device) + align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device) labels = bundle.get_labels() align_dictionary = {c.lower(): i for i, c in enumerate(labels)} else: