Skip to content

Commit

Permalink
[API][Python] Lazy import tools module. (#85)
Browse files Browse the repository at this point in the history
Signed-off-by: Duyi-Wang <[email protected]>
  • Loading branch information
Duyi-Wang authored Nov 27, 2023
1 parent 08d8169 commit 475d438
Showing 1 changed file with 51 additions and 5 deletions.
56 changes: 51 additions & 5 deletions src/xfastertransformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,59 @@
# ============================================================================
import torch
import os
import sys
from types import ModuleType
from typing import Any
from typing import TYPE_CHECKING

torch.classes.load_library(os.path.dirname(os.path.abspath(__file__)) + "/libxfastertransformer_pt.so")

from .automodel import AutoModel

from .tools import LlamaConvert
from .tools import ChatGLMConvert
from .tools import ChatGLM2Convert
from .tools import OPTConvert
from .tools import BaichuanConvert
_import_structure = {"tools": ["LlamaConvert", "ChatGLMConvert", "ChatGLM2Convert", "OPTConvert", "BaichuanConvert"]}

if TYPE_CHECKING:
from .tools import LlamaConvert
from .tools import ChatGLMConvert
from .tools import ChatGLM2Convert
from .tools import OPTConvert
from .tools import BaichuanConvert
else:
# This LazyImportModule is refer to optuna.integration._IntegrationModule
# Source code url https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
class _LazyImportModule(ModuleType):
"""
This class applies lazy import under `xfastertransformer` excluding `AutoModel`, where submodules are imported
when they are actually accessed. Otherwise, `import xfastertransformer` will import some unnecessary dependencise.
"""

__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]

_modules = set(_import_structure.keys())
_class_to_module = {}
for key, values in _import_structure.items():
for value in values:
_class_to_module[value] = key

def __getattr__(self, name: str) -> Any:
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError("module {} has no attribute {}".format(self.__name__, name))

setattr(self, name, value)
return value

def _get_module(self, module_name: str) -> ModuleType:
import importlib

try:
return importlib.import_module("." + module_name, self.__name__)
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Fail to import module {module_name}.")

sys.modules[__name__] = _LazyImportModule(__name__)

0 comments on commit 475d438

Please sign in to comment.