Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize helper for loading module from source #2862

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions api/core/extension/extensible.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import enum
import importlib.util
import json
import logging
import os
from typing import Any, Optional

from pydantic import BaseModel

from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import sort_to_dict_by_position_map


Expand Down Expand Up @@ -73,17 +73,9 @@ def scan_extensions(cls):

# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break

if not extension_class:
try:
extension_class = load_single_subclass_from_source(extension_name, py_path, cls)
except Exception:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue

Expand Down
17 changes: 5 additions & 12 deletions api/core/model_runtime/model_providers/__base/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os
from abc import ABC, abstractmethod

Expand All @@ -7,6 +6,7 @@
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


class ModelProvider(ABC):
Expand Down Expand Up @@ -104,17 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel:

# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break

mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')

Expand Down
15 changes: 5 additions & 10 deletions api/core/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import logging
import os
from typing import Optional
Expand All @@ -10,6 +9,7 @@
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -229,15 +229,10 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:

# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_provider_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
model_provider_class = obj
break
model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
script_path=py_path,
parent_type=ModelProvider)

if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
Expand Down
17 changes: 6 additions & 11 deletions api/core/tools/provider/builtin_tool_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
from abc import abstractmethod
from os import listdir, path
from typing import Any
Expand All @@ -16,6 +15,7 @@
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.utils.module_import_helper import load_single_subclass_from_source


class BuiltinToolProviderController(ToolProviderController):
Expand Down Expand Up @@ -63,16 +63,11 @@ def _get_builtin_tools(self) -> list[Tool]:
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
]
assistant_tool_class = classes[0]
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tools.append(assistant_tool_class(**tool))

self.tools = tools
Expand Down
43 changes: 11 additions & 32 deletions api/core/tools/tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import json
import logging
import mimetypes
Expand Down Expand Up @@ -34,6 +33,7 @@
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from core.utils.module_import_helper import load_single_subclass_from_source
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider

Expand Down Expand Up @@ -72,21 +72,11 @@ def invoke(

if provider_entity is None:
# fetch the provider from .provider.builtin
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [ x for _, x in vars(mod).items()
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

provider_entity = classes[0]()
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
parent_type=ToolProviderController)
provider_entity = provider_class()

return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)

Expand Down Expand Up @@ -330,23 +320,12 @@ def list_builtin_providers() -> list[BuiltinToolProviderController]:
if provider.startswith('__'):
continue

py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# load all classes
classes = [
obj for name, obj in vars(mod).items()
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

# init provider
provider_class = classes[0]
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())

# cache the builtin providers
Expand Down
62 changes: 62 additions & 0 deletions api/core/utils/module_import_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr


def import_module_from_source(
    module_name: str,
    py_file_path: AnyStr,
    use_lazy_loader: bool = False
) -> ModuleType:
    """
    Import a module directly from a source file.

    If ``module_name`` can already be resolved by the regular import system,
    that spec is reused; in that case ``py_file_path`` and ``use_lazy_loader``
    are not consulted and ``sys.modules`` is left untouched. Otherwise a spec
    is built from ``py_file_path`` and the new module is registered in
    ``sys.modules`` under ``module_name`` before it is executed.

    :param module_name: name to give the module
    :param py_file_path: path of the source file to load
    :param use_lazy_loader: wrap the file loader in ``importlib.util.LazyLoader``
        so the source is only executed on first attribute access
    :return: the loaded (or lazily loadable) module
    :raises ImportError: when no spec or loader can be obtained for the module
    :raises Exception: any other error raised while resolving or executing the
        module is logged and re-raised unchanged
    """
    registered = False
    try:
        existed_spec = importlib.util.find_spec(module_name)
        if existed_spec:
            spec = existed_spec
        else:
            # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            spec = importlib.util.spec_from_file_location(module_name, py_file_path)
            if spec is None:
                # Happens when no loader handles the file (e.g. unknown suffix).
                raise ImportError(f'Cannot create a module spec for {module_name} from {py_file_path}')
            if use_lazy_loader:
                # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
                spec.loader = importlib.util.LazyLoader(spec.loader)
        if spec.loader is None:
            raise ImportError(f'No loader available for module {module_name}')
        module = importlib.util.module_from_spec(spec)
        if not existed_spec:
            sys.modules[module_name] = module
            registered = True
        spec.loader.exec_module(module)
        return module
    except Exception as e:
        if registered:
            # Do not leave a half-initialised module behind for later imports.
            sys.modules.pop(module_name, None)
        logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
        raise


def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
    """
    Collect every class bound in the module namespace that derives from
    ``parent_type``, excluding ``parent_type`` itself.

    :param mod: module whose namespace is inspected
    :param parent_type: base class the results must inherit from
    :return: matching classes, in module namespace order
    """
    found = []
    for candidate in vars(mod).values():
        # Only classes can be passed to issubclass(); skip everything else.
        if not isinstance(candidate, type):
            continue
        if candidate != parent_type and issubclass(candidate, parent_type):
            found.append(candidate)
    return found


def load_single_subclass_from_source(
    module_name: str,
    script_path: AnyStr,
    parent_type: type,
    use_lazy_loader: bool = False,
) -> type:
    """
    Import ``script_path`` as ``module_name`` and return the one class in it
    that derives from ``parent_type``.

    :param module_name: name to give the imported module
    :param script_path: path of the source file to import
    :param parent_type: base class the result must inherit from
    :param use_lazy_loader: forwarded to ``import_module_from_source``
    :return: the single subclass found in the module
    :raises Exception: when the module has no subclass, or more than one
    """
    loaded = import_module_from_source(module_name, script_path, use_lazy_loader)
    candidates = get_subclasses_from_module(loaded, parent_type)
    if not candidates:
        raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
    if len(candidates) > 1:
        raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')
    return candidates[0]
7 changes: 7 additions & 0 deletions api/tests/integration_tests/utils/child_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from tests.integration_tests.utils.parent_class import ParentClass


class ChildClass(ParentClass):
    """Fixture subclass of ParentClass used by the module import helper tests."""

    def __init__(self, name: str):
        # The parent initializer receives the name first; it is then assigned
        # on the instance again, exactly as before.
        super().__init__(name)
        self.name = name
7 changes: 7 additions & 0 deletions api/tests/integration_tests/utils/lazy_load_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from tests.integration_tests.utils.parent_class import ParentClass


class LazyLoadChildClass(ParentClass):
    """Fixture subclass of ParentClass used by the lazy-loading import test."""

    def __init__(self, name: str):
        # The parent initializer receives the name first; it is then assigned
        # on the instance again, exactly as before.
        super().__init__(name)
        self.name = name
6 changes: 6 additions & 0 deletions api/tests/integration_tests/utils/parent_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class ParentClass:
    """Minimal base class used as the parent type in the import helper tests."""

    def __init__(self, name):
        # Keep the given name on the instance; get_name() reads it back.
        self.name = name

    def get_name(self):
        """Return the current value of the ``name`` attribute."""
        return self.name
32 changes: 32 additions & 0 deletions api/tests/integration_tests/utils/test_module_import_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source
from tests.integration_tests.utils.parent_class import ParentClass


def test_loading_subclass_from_source():
    """The helper returns the single ParentClass subclass defined in child_class.py."""
    # Resolve the fixture relative to this file so the test does not depend on
    # the directory pytest happens to be launched from.
    current_path = os.path.dirname(os.path.abspath(__file__))
    clz = load_single_subclass_from_source(
        module_name='ChildClass',
        script_path=os.path.join(current_path, 'child_class.py'),
        parent_type=ParentClass)
    assert clz and clz.__name__ == 'ChildClass'


def test_load_import_module_from_source():
    """Importing child_class.py by path yields a module carrying the requested name."""
    # Resolve the fixture relative to this file so the test does not depend on
    # the directory pytest happens to be launched from.
    current_path = os.path.dirname(os.path.abspath(__file__))
    module = import_module_from_source(
        module_name='ChildClass',
        py_file_path=os.path.join(current_path, 'child_class.py'))
    assert module and module.__name__ == 'ChildClass'


def test_lazy_loading_subclass_from_source():
    """A class loaded through the lazy loader can be instantiated and used normally."""
    # Resolve the fixture relative to this file so the test does not depend on
    # the directory pytest happens to be launched from.
    current_path = os.path.dirname(os.path.abspath(__file__))
    clz = load_single_subclass_from_source(
        module_name='LazyLoadChildClass',
        script_path=os.path.join(current_path, 'lazy_load_class.py'),
        parent_type=ParentClass,
        use_lazy_loader=True)
    instance = clz('dify')
    assert instance.get_name() == 'dify'