diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index f6a66bea5a..615069e579 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -103,14 +103,16 @@ def __init__(self, feature_server_type: str): class FeastModuleImportError(Exception): - def __init__(self, module_name: str, module_type: str): - super().__init__(f"Could not import {module_type} module '{module_name}'") + def __init__(self, module_name: str, class_name: str): + super().__init__( + f"Could not import module '{module_name}' while attempting to load class '{class_name}'" + ) class FeastClassImportError(Exception): - def __init__(self, module_name, class_name, class_type="provider"): + def __init__(self, module_name: str, class_name: str): super().__init__( - f"Could not import {class_type} '{class_name}' from module '{module_name}'" + f"Could not import class '{class_name}' from module '{module_name}'" ) @@ -168,11 +170,10 @@ def __init__(self, online_store_class_name: str): ) -class FeastClassInvalidName(Exception): +class FeastInvalidBaseClass(Exception): def __init__(self, class_name: str, class_type: str): super().__init__( - f"Config Class '{class_name}' " - f"should end with the string `{class_type}`.'" + f"Class '{class_name}' should have `{class_type}` as a base class." ) diff --git a/sdk/python/feast/importer.py b/sdk/python/feast/importer.py index 5dcd7c71c1..bbd592101a 100644 --- a/sdk/python/feast/importer.py +++ b/sdk/python/feast/importer.py @@ -1,28 +1,47 @@ import importlib -from feast import errors +from feast.errors import ( + FeastClassImportError, + FeastInvalidBaseClass, + FeastModuleImportError, +) -def get_class_from_type(module_name: str, class_name: str, class_type: str): - if not class_name.endswith(class_type): - raise errors.FeastClassInvalidName(class_name, class_type) +def import_class(module_name: str, class_name: str, class_type: str = None): + """ + Dynamically loads and returns a class from a module. - # Try importing the module that contains the custom provider + Args: + module_name: The name of the module. + class_name: The name of the class. + class_type: Optional name of a base class of the class. + + Raises: + FeastInvalidBaseClass: If the class name does not end with the specified suffix. + FeastModuleImportError: If the module cannot be imported. + FeastClassImportError: If the class cannot be imported. + """ + # Try importing the module. try: module = importlib.import_module(module_name) except Exception as e: # The original exception can be anything - either module not found, # or any other kind of error happening during the module import time. # So we should include the original error as well in the stack trace. - raise errors.FeastModuleImportError(module_name, class_type) from e + raise FeastModuleImportError(module_name, class_name) from e - # Try getting the provider class definition + # Try getting the class. try: _class = getattr(module, class_name) except AttributeError: # This can only be one type of error, when class_name attribute does not exist in the module # So we don't have to include the original exception here - raise errors.FeastClassImportError( - module_name, class_name, class_type=class_type - ) from None + raise FeastClassImportError(module_name, class_name) from None + + # Check if the class is a subclass of the base class. + if class_type and not any( + base_class.__name__ == class_type for base_class in _class.mro() + ): + raise FeastInvalidBaseClass(class_name, class_type) + return _class diff --git a/sdk/python/feast/infra/infra_object.py b/sdk/python/feast/infra/infra_object.py index f1eda19581..3cd00899fe 100644 --- a/sdk/python/feast/infra/infra_object.py +++ b/sdk/python/feast/infra/infra_object.py @@ -15,7 +15,7 @@ from dataclasses import dataclass, field from typing import Any, List -from feast.importer import get_class_from_type +from feast.importer import import_class from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto @@ -106,4 +106,4 @@ def from_proto(cls, infra_proto: InfraProto): def _get_infra_object_class_from_type(infra_object_class_type: str): module_name, infra_object_class_name = infra_object_class_type.rsplit(".", 1) - return get_class_from_type(module_name, infra_object_class_name, "Object") + return import_class(module_name, infra_object_class_name) diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 6debe14ca0..0b60c3493d 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -1,4 +1,3 @@ -import importlib import uuid from dataclasses import asdict, dataclass from datetime import datetime, timedelta @@ -12,11 +11,10 @@ import feast from feast.errors import ( EntityTimestampInferenceException, - FeastClassImportError, FeastEntityDFMissingColumnsError, - FeastModuleImportError, ) from feast.feature_view import FeatureView +from feast.importer import import_class from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.provider import _get_requested_feature_views_to_features_dict from feast.registry import Registry @@ -204,27 +202,10 @@ def get_temp_entity_table_name() -> str: return "feast_entity_df_" + uuid.uuid4().hex -def get_offline_store_from_config(offline_store_config: Any,) -> OfflineStore: - """Get the offline store from offline store config""" - +def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore: + """Creates an offline store corresponding to the given offline store config.""" module_name = offline_store_config.__module__ qualified_name = type(offline_store_config).__name__ - store_class_name = qualified_name.replace("Config", "") - try: - module = importlib.import_module(module_name) - except Exception as e: - # The original exception can be anything - either module not found, - # or any other kind of error happening during the module import time. - # So we should include the original error as well in the stack trace. - raise FeastModuleImportError(module_name, "OfflineStore") from e - - # Try getting the provider class definition - try: - offline_store_class = getattr(module, store_class_name) - except AttributeError: - # This can only be one type of error, when class_name attribute does not exist in the module - # So we don't have to include the original exception here - raise FeastClassImportError( - module_name, store_class_name, class_type="OfflineStore" - ) from None + class_name = qualified_name.replace("Config", "") + offline_store_class = import_class(module_name, class_name, "OfflineStore") return offline_store_class() diff --git a/sdk/python/feast/infra/online_stores/helpers.py b/sdk/python/feast/infra/online_stores/helpers.py index 5e01ddb263..b206c08b7c 100644 --- a/sdk/python/feast/infra/online_stores/helpers.py +++ b/sdk/python/feast/infra/online_stores/helpers.py @@ -1,10 +1,9 @@ -import importlib import struct from typing import Any, List import mmh3 -from feast import errors +from feast.importer import import_class from feast.infra.key_encoding_utils import ( serialize_entity_key, serialize_entity_key_prefix, @@ -13,29 +12,12 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -def get_online_store_from_config(online_store_config: Any,) -> OnlineStore: - """Get the online store from online store config""" - +def get_online_store_from_config(online_store_config: Any) -> OnlineStore: + """Creates an online store corresponding to the given online store config.""" module_name = online_store_config.__module__ qualified_name = type(online_store_config).__name__ - store_class_name = qualified_name.replace("Config", "") - try: - module = importlib.import_module(module_name) - except Exception as e: - # The original exception can be anything - either module not found, - # or any other kind of error happening during the module import time. - # So we should include the original error as well in the stack trace. - raise errors.FeastModuleImportError(module_name, "OnlineStore") from e - - # Try getting the provider class definition - try: - online_store_class = getattr(module, store_class_name) - except AttributeError: - # This can only be one type of error, when class_name attribute does not exist in the module - # So we don't have to include the original exception here - raise errors.FeastClassImportError( - module_name, store_class_name, class_type="OnlineStore" - ) from None + class_name = qualified_name.replace("Config", "") + online_store_class = import_class(module_name, class_name, "OnlineStore") return online_store_class() diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 32bca8f7d7..3c761f1195 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -8,9 +8,10 @@ import pyarrow from tqdm import tqdm -from feast import errors, importer +from feast import errors from feast.entity import Entity from feast.feature_view import DUMMY_ENTITY_ID, FeatureView +from feast.importer import import_class from feast.infra.offline_stores.offline_store import RetrievalJob from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -172,7 +173,7 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider: # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' module_name, class_name = provider.rsplit(".", 1) - cls = importer.get_class_from_type(module_name, class_name, "Provider") + cls = import_class(module_name, class_name, "Provider") return cls(config) diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index d5c4c08048..0c058a0d46 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -23,7 +23,6 @@ from google.protobuf.json_format import MessageToDict from proto import Message -from feast import importer from feast.base_feature_view import BaseFeatureView from feast.diff.FcoDiff import ( FcoDiff, @@ -42,6 +41,7 @@ ) from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.importer import import_class from feast.infra.infra_object import Infra from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto @@ -75,9 +75,7 @@ def get_registry_store_class_from_type(registry_store_type: str): registry_store_type = REGISTRY_STORE_CLASS_FOR_TYPE[registry_store_type] module_name, registry_store_class_name = registry_store_type.rsplit(".", 1) - return importer.get_class_from_type( - module_name, registry_store_class_name, "RegistryStore" - ) + return import_class(module_name, registry_store_class_name, "RegistryStore") def get_registry_store_class_from_scheme(registry_path: str): diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 70e64c845c..26309fe9d7 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -20,7 +20,7 @@ FeastFeatureServerTypeSetError, FeastProviderNotSetError, ) -from feast.importer import get_class_from_type +from feast.importer import import_class from feast.usage import log_exceptions # These dict exists so that: @@ -302,7 +302,7 @@ def __repr__(self) -> str: def get_data_source_class_from_type(data_source_type: str): module_name, config_class_name = data_source_type.rsplit(".", 1) - return get_class_from_type(module_name, config_class_name, "Source") + return import_class(module_name, config_class_name, "DataSource") def get_online_config_from_type(online_store_type: str): @@ -313,7 +313,7 @@ def get_online_config_from_type(online_store_type: str): module_name, online_store_class_type = online_store_type.rsplit(".", 1) config_class_name = f"{online_store_class_type}Config" - return get_class_from_type(module_name, config_class_name, config_class_name) + return import_class(module_name, config_class_name, config_class_name) def get_offline_config_from_type(offline_store_type: str): @@ -324,7 +324,7 @@ def get_offline_config_from_type(offline_store_type: str): module_name, offline_store_class_type = offline_store_type.rsplit(".", 1) config_class_name = f"{offline_store_class_type}Config" - return get_class_from_type(module_name, config_class_name, config_class_name) + return import_class(module_name, config_class_name, config_class_name) def get_feature_server_config_from_type(feature_server_type: str): @@ -334,7 +334,7 @@ def get_feature_server_config_from_type(feature_server_type: str): feature_server_type = FEATURE_SERVER_CONFIG_CLASS_FOR_TYPE[feature_server_type] module_name, config_class_name = feature_server_type.rsplit(".", 1) - return get_class_from_type(module_name, config_class_name, config_class_name) + return import_class(module_name, config_class_name, config_class_name) def load_repo_config(repo_path: Path) -> RepoConfig: diff --git a/sdk/python/tests/integration/registration/test_cli.py b/sdk/python/tests/integration/registration/test_cli.py index f05674ea5c..0fe73316ad 100644 --- a/sdk/python/tests/integration/registration/test_cli.py +++ b/sdk/python/tests/integration/registration/test_cli.py @@ -211,14 +211,14 @@ def test_3rd_party_providers() -> None: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import Provider module 'feast_foo'" + b"Could not import module 'feast_foo' while attempting to load class 'Provider'" ) # Check with incorrect third-party provider name (with dots) with setup_third_party_provider_repo("foo.FooProvider") as repo_path: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import Provider 'FooProvider' from module 'foo'" + b"Could not import class 'FooProvider' from module 'foo'" ) # Check with correct third-party provider name with setup_third_party_provider_repo("foo.provider.FooProvider") as repo_path: @@ -243,14 +243,14 @@ def test_3rd_party_registry_store() -> None: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import RegistryStore module 'feast_foo'" + b"Could not import module 'feast_foo' while attempting to load class 'RegistryStore'" ) # Check with incorrect third-party registry store name (with dots) with setup_third_party_registry_store_repo("foo.FooRegistryStore") as repo_path: return_code, output = runner.run_with_output(["apply"], cwd=repo_path) assertpy.assert_that(return_code).is_equal_to(1) assertpy.assert_that(output).contains( - b"Could not import RegistryStore 'FooRegistryStore' from module 'foo'" + b"Could not import class 'FooRegistryStore' from module 'foo'" ) # Check with correct third-party registry store name with setup_third_party_registry_store_repo(