diff --git a/whisperx/__init__.py b/whisperx/__init__.py index 8ab8a26..6356fdf 100644 --- a/whisperx/__init__.py +++ b/whisperx/__init__.py @@ -1,7 +1,40 @@ -from whisperx.alignment import load_align_model as load_align_model, align as align -from whisperx.asr import load_model as load_model -from whisperx.audio import load_audio as load_audio -from whisperx.diarize import ( - assign_word_speakers as assign_word_speakers, - DiarizationPipeline as DiarizationPipeline, -) +import importlib + + +def _lazy_import(name): + module = importlib.import_module(f"whisperx.{name}") + return module + + +def load_align_model(*args, **kwargs): + alignment = _lazy_import("alignment") + return alignment.load_align_model(*args, **kwargs) + + +def align(*args, **kwargs): + alignment = _lazy_import("alignment") + return alignment.align(*args, **kwargs) + + +def load_model(*args, **kwargs): + asr = _lazy_import("asr") + return asr.load_model(*args, **kwargs) + + +def load_audio(*args, **kwargs): + audio = _lazy_import("audio") + return audio.load_audio(*args, **kwargs) + + +def assign_word_speakers(*args, **kwargs): + diarize = _lazy_import("diarize") + return diarize.assign_word_speakers(*args, **kwargs) + + +class DiarizationPipeline: + def __init__(self, *args, **kwargs): + diarize = _lazy_import("diarize") + self._pipeline = diarize.DiarizationPipeline(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._pipeline, name)