From 1f101b2bfd324f12e8b7ae18149406d782a52a33 Mon Sep 17 00:00:00 2001 From: Tsotne Tabidze Date: Fri, 23 Apr 2021 16:43:56 -0700 Subject: [PATCH 1/2] Add support for third party providers Signed-off-by: Tsotne Tabidze --- sdk/python/feast/errors.py | 21 +++++++++++++++++++++ sdk/python/feast/infra/provider.py | 27 ++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index a09a6899c1..ca3fb0378d 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -30,3 +30,24 @@ def __init__(self, name, project=None): class FeastProviderLoginError(Exception): """Error class that indicates a user has not authenticated with their provider.""" + def __init__(self, project, name): + super().__init__(f"Feature table {name} does not exist in project {project}") + + +class ProviderNameParsingException(Exception): + def __init__(self, provider_name): + super().__init__( + f"Could not parse provider name '{provider_name}' into module and class names" + ) + + +class ProviderModuleImportError(Exception): + def __init__(self, module_name): + super().__init__(f"Could not import provider module '{module_name}'") + + +class ProviderClassImportError(Exception): + def __init__(self, module_name, class_name): + super().__init__( + f"Could not import provider '{class_name}' from module '{module_name}'" + ) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index aacdbc90ce..1c7634fd45 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -1,4 +1,5 @@ import abc +import importlib from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -6,6 +7,7 @@ import pandas import pyarrow +from feast import errors from feast.entity import Entity from feast.feature_table import FeatureTable from feast.feature_view import FeatureView @@ -144,7 +146,30 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider: return LocalProvider(config, repo_path) else: - raise ValueError(config) + if "." not in config.provider: + raise errors.ProviderNameParsingException(config.provider) + # Split provider into module and class names by finding the right-most dot. + # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' + module_name, class_name = config.provider.rsplit(".", 1) + + # Try importing the module that contains the custom provider + try: + module = importlib.import_module(module_name) + except Exception as e: + # The original exception can be anything - either module not found, + # or any other kind of error happening during the module import time. + # So we should include the original error as well in the stack trace. + raise errors.ProviderModuleImportError(module_name) from e + + # Try getting the provider class definition + try: + ProviderCls = getattr(module, class_name) + except AttributeError: + # This can only be one type of error, when class_name attribute does not exist in the module + # So we don't have to include the original exception here + raise errors.ProviderClassImportError(module_name, class_name) from None + + return ProviderCls(config) def _get_requested_feature_views_to_features_dict( From d3e145c78400fe25c3750a8ade2de3192b49f0f6 Mon Sep 17 00:00:00 2001 From: Tsotne Tabidze Date: Fri, 23 Apr 2021 17:48:11 -0700 Subject: [PATCH 2/2] Add unit tests & assume providers without dots in name refers to builtin providers Signed-off-by: Tsotne Tabidze --- sdk/python/feast/errors.py | 12 ++--- sdk/python/feast/infra/provider.py | 25 +++++----- sdk/python/tests/cli_utils.py | 15 +++++- sdk/python/tests/foo_provider.py | 75 ++++++++++++++++++++++++++++++ sdk/python/tests/test_cli_local.py | 56 ++++++++++++++++++++++ 5 files changed, 163 insertions(+), 20 deletions(-) create mode 100644 sdk/python/tests/foo_provider.py diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index ca3fb0378d..2b08868a45 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -30,23 +30,19 @@ def __init__(self, name, project=None): class FeastProviderLoginError(Exception): """Error class that indicates a user has not authenticated with their provider.""" - def __init__(self, project, name): - super().__init__(f"Feature table {name} does not exist in project {project}") -class ProviderNameParsingException(Exception): +class FeastProviderNotImplementedError(Exception): def __init__(self, provider_name): - super().__init__( - f"Could not parse provider name '{provider_name}' into module and class names" - ) + super().__init__(f"Provider '{provider_name}' is not implemented") -class ProviderModuleImportError(Exception): +class FeastProviderModuleImportError(Exception): def __init__(self, module_name): super().__init__(f"Could not import provider module '{module_name}'") -class ProviderClassImportError(Exception): +class FeastProviderClassImportError(Exception): def __init__(self, module_name, class_name): super().__init__( f"Could not import provider '{class_name}' from module '{module_name}'" diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 1c7634fd45..325ea0433e 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -137,17 +137,18 @@ def online_read( def get_provider(config: RepoConfig, repo_path: Path) -> Provider: - if config.provider == "gcp": - from feast.infra.gcp import GcpProvider + if "." not in config.provider: + if config.provider == "gcp": + from feast.infra.gcp import GcpProvider - return GcpProvider(config) - elif config.provider == "local": - from feast.infra.local import LocalProvider + return GcpProvider(config) + elif config.provider == "local": + from feast.infra.local import LocalProvider - return LocalProvider(config, repo_path) + return LocalProvider(config, repo_path) + else: + raise errors.FeastProviderNotImplementedError(config.provider) else: - if "." not in config.provider: - raise errors.ProviderNameParsingException(config.provider) # Split provider into module and class names by finding the right-most dot. # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' module_name, class_name = config.provider.rsplit(".", 1) @@ -159,7 +160,7 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider: # The original exception can be anything - either module not found, # or any other kind of error happening during the module import time. # So we should include the original error as well in the stack trace. - raise errors.ProviderModuleImportError(module_name) from e + raise errors.FeastProviderModuleImportError(module_name) from e # Try getting the provider class definition try: @@ -167,9 +168,11 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider: except AttributeError: # This can only be one type of error, when class_name attribute does not exist in the module # So we don't have to include the original exception here - raise errors.ProviderClassImportError(module_name, class_name) from None + raise errors.FeastProviderClassImportError( + module_name, class_name + ) from None - return ProviderCls(config) + return ProviderCls(config, repo_path) def _get_requested_feature_views_to_features_dict( diff --git a/sdk/python/tests/cli_utils.py b/sdk/python/tests/cli_utils.py index bb90469e93..11de6ace80 100644 --- a/sdk/python/tests/cli_utils.py +++ b/sdk/python/tests/cli_utils.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from pathlib import Path from textwrap import dedent -from typing import List +from typing import List, Tuple from feast import cli from feast.feature_store import FeatureStore @@ -26,6 +26,19 @@ class CliRunner: def run(self, args: List[str], cwd: Path) -> subprocess.CompletedProcess: return subprocess.run([sys.executable, cli.__file__] + args, cwd=cwd) + def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]: + try: + return ( + 0, + subprocess.check_output( + [sys.executable, cli.__file__] + args, + cwd=cwd, + stderr=subprocess.STDOUT, + ), + ) + except subprocess.CalledProcessError as e: + return e.returncode, e.output + @contextmanager def local_repo(self, example_repo_py: str): """ diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py new file mode 100644 index 0000000000..658abbffd1 --- /dev/null +++ b/sdk/python/tests/foo_provider.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import pandas + +from feast import Entity, FeatureTable, FeatureView, RepoConfig +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.provider import Provider +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.registry import Registry + + +class FooProvider(Provider): + def update_infra( + self, + project: str, + tables_to_delete: Sequence[Union[FeatureTable, FeatureView]], + tables_to_keep: Sequence[Union[FeatureTable, FeatureView]], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, + ): + pass + + def teardown_infra( + self, + project: str, + tables: Sequence[Union[FeatureTable, FeatureView]], + entities: Sequence[Entity], + ): + pass + + def online_write_batch( + self, + project: str, + table: Union[FeatureTable, FeatureView], + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + pass + + def materialize_single_feature_view( + self, + feature_view: FeatureView, + start_date: datetime, + end_date: datetime, + registry: Registry, + project: str, + ) -> None: + pass + + @staticmethod + def get_historical_features( + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pandas.DataFrame, str], + registry: Registry, + project: str, + ) -> RetrievalJob: + pass + + def online_read( + self, + project: str, + table: Union[FeatureTable, FeatureView], + entity_keys: List[EntityKeyProto], + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + pass + + def __init__(self, config, repo_path): + pass diff --git a/sdk/python/tests/test_cli_local.py b/sdk/python/tests/test_cli_local.py index 6444492ad2..115cfaaf07 100644 --- a/sdk/python/tests/test_cli_local.py +++ b/sdk/python/tests/test_cli_local.py @@ -1,4 +1,5 @@ import tempfile +from contextlib import contextmanager from pathlib import Path from textwrap import dedent @@ -110,3 +111,58 @@ def test_non_local_feature_repo() -> None: result = runner.run(["teardown"], cwd=repo_path) assert result.returncode == 0 + + +@contextmanager +def setup_third_party_provider_repo(provider_name: str): + with tempfile.TemporaryDirectory() as repo_dir_name: + + # Construct an example repo in a temporary dir + repo_path = Path(repo_dir_name) + + repo_config = repo_path / "feature_store.yaml" + + repo_config.write_text( + dedent( + f""" + project: foo + registry: data/registry.db + provider: {provider_name} + online_store: + path: data/online_store.db + type: sqlite + """ + ) + ) + + (repo_path / "foo").mkdir() + repo_example = repo_path / "foo/provider.py" + repo_example.write_text((Path(__file__).parent / "foo_provider.py").read_text()) + + yield repo_path + + +def test_3rd_party_providers() -> None: + """ + Test running apply on third party providers + """ + runner = CliRunner() + # Check with incorrect built-in provider name (no dots) + with setup_third_party_provider_repo("feast123") as repo_path: + return_code, output = runner.run_with_output(["apply"], cwd=repo_path) + assert return_code == 1 + assert b"Provider 'feast123' is not implemented" in output + # Check with incorrect third-party provider name (with dots) + with setup_third_party_provider_repo("feast_foo.provider") as repo_path: + return_code, output = runner.run_with_output(["apply"], cwd=repo_path) + assert return_code == 1 + assert b"Could not import provider module 'feast_foo'" in output + # Check with incorrect third-party provider name (with dots) + with setup_third_party_provider_repo("foo.provider") as repo_path: + return_code, output = runner.run_with_output(["apply"], cwd=repo_path) + assert return_code == 1 + assert b"Could not import provider 'provider' from module 'foo'" in output + # Check with correct third-party provider name + with setup_third_party_provider_repo("foo.provider.FooProvider") as repo_path: + return_code, output = runner.run_with_output(["apply"], cwd=repo_path) + assert return_code == 0