diff --git a/litgpt/__init__.py b/litgpt/__init__.py index 6c8dc3d820..af29843106 100644 --- a/litgpt/__init__.py +++ b/litgpt/__init__.py @@ -1,13 +1,59 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import importlib import logging import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from litgpt.api import LLM + from litgpt.config import Config + from litgpt.model import GPT + from litgpt.prompts import PromptStyle + from litgpt.tokenizer import Tokenizer + + +_LAZY_IMPORTS = { + "LLM": "litgpt.api", + "Config": "litgpt.config", + "GPT": "litgpt.model", + "PromptStyle": "litgpt.prompts", + "Tokenizer": "litgpt.tokenizer", + "api": "litgpt.api", + "chat": "litgpt.chat", + "config": "litgpt.config", + "generate": "litgpt.generate", + "lora": "litgpt.lora", + "model": "litgpt.model", + "prompts": "litgpt.prompts", + "scripts": "litgpt.scripts", + "tokenizer": "litgpt.tokenizer", + "utils": "litgpt.utils", +} + + +def __getattr__(name): + """ + Allow for lazy imports of all litgpt submodules, + as well as some selected top-level attributes. + """ + if name in _LAZY_IMPORTS: + module_name = _LAZY_IMPORTS[name] + module = importlib.import_module(module_name) + if not module_name.endswith(name): + return getattr(module, name) + return module + + # If the attribute is not found, raise an AttributeError + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + """ + Return a list of all attributes in the litgpt module. + """ + return sorted(list(_LAZY_IMPORTS.keys()) + list(globals().keys())) -from litgpt.api import LLM -from litgpt.config import Config -from litgpt.model import GPT # needs to be imported before config -from litgpt.prompts import PromptStyle -from litgpt.tokenizer import Tokenizer # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 pattern = re.compile(".*Profiler function .* will be ignored")