Skip to content

Commit

Permalink
Add support for third party providers (#1501)
Browse files Browse the repository at this point in the history
* Add support for third party providers

Signed-off-by: Tsotne Tabidze <[email protected]>

* Add unit tests & assume provider names without dots refer to built-in providers

Signed-off-by: Tsotne Tabidze <[email protected]>
  • Loading branch information
Tsotne Tabidze authored and woop committed Apr 27, 2021
1 parent f22ce2b commit 75c9f3f
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 8 deletions.
17 changes: 17 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@ def __init__(self, name, project=None):

class FeastProviderLoginError(Exception):
    """Signals that the user has not authenticated with their provider."""


class FeastProviderNotImplementedError(Exception):
    """Raised when a provider name does not correspond to an implemented provider.

    Args:
        provider_name: The provider name that could not be resolved.
    """

    def __init__(self, provider_name):
        super().__init__(f"Provider '{provider_name}' is not implemented")
        # Keep the raw argument so the exception can be rebuilt from it.
        self.provider_name = provider_name

    def __reduce__(self):
        # The default Exception reduction re-calls __init__ with ``self.args``
        # (the already formatted message), which would wrap the message twice
        # when the exception is copied or pickled.
        return (type(self), (self.provider_name,))


class FeastProviderModuleImportError(Exception):
    """Raised when the module that should contain a provider cannot be imported.

    Args:
        module_name: Dotted name of the module whose import failed.
    """

    def __init__(self, module_name):
        super().__init__(f"Could not import provider module '{module_name}'")
        # Keep the raw argument so the exception can be rebuilt from it.
        self.module_name = module_name

    def __reduce__(self):
        # The default Exception reduction re-calls __init__ with ``self.args``
        # (the already formatted message), which would wrap the message twice
        # when the exception is copied or pickled.
        return (type(self), (self.module_name,))


class FeastProviderClassImportError(Exception):
    """Raised when a provider class cannot be found in its (importable) module.

    Args:
        module_name: Dotted name of the module that was searched.
        class_name: Name of the attribute that was expected in that module.
    """

    def __init__(self, module_name, class_name):
        super().__init__(
            f"Could not import provider '{class_name}' from module '{module_name}'"
        )
        # Keep the raw arguments so the exception can be rebuilt from them.
        self.module_name = module_name
        self.class_name = class_name

    def __reduce__(self):
        # The default Exception reduction re-calls __init__ with ``self.args``
        # (a single formatted message), which raises TypeError here because
        # __init__ requires two arguments.
        return (type(self), (self.module_name, self.class_name))
42 changes: 35 additions & 7 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import abc
import importlib
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import pandas
import pyarrow

from feast import errors
from feast.entity import Entity
from feast.feature_table import FeatureTable
from feast.feature_view import FeatureView
Expand Down Expand Up @@ -135,16 +137,42 @@ def online_read(


def get_provider(config: RepoConfig, repo_path: Path) -> Provider:
    """Instantiate the provider named by ``config.provider``.

    A name without a dot selects a built-in provider ("gcp" or "local").
    A dotted name is treated as ``<module path>.<class name>`` and loaded
    dynamically, e.g. 'foo.bar.MyProvider' -> module 'foo.bar', class
    'MyProvider'.

    Args:
        config: Repo configuration; its ``provider`` attribute selects the provider.
        repo_path: Repository path, passed to the local and third-party providers.

    Returns:
        The constructed provider instance.

    Raises:
        errors.FeastProviderNotImplementedError: Dot-less name that is not built in.
        errors.FeastProviderModuleImportError: The provider module failed to import.
        errors.FeastProviderClassImportError: The module lacks the requested class.
    """
    provider_spec = config.provider

    # Built-in providers: imported lazily so only the selected one is loaded.
    if "." not in provider_spec:
        if provider_spec == "gcp":
            from feast.infra.gcp import GcpProvider

            return GcpProvider(config)
        if provider_spec == "local":
            from feast.infra.local import LocalProvider

            return LocalProvider(config, repo_path)
        raise errors.FeastProviderNotImplementedError(provider_spec)

    # Third-party provider: everything before the right-most dot is the module.
    module_name, _, class_name = provider_spec.rpartition(".")

    try:
        provider_module = importlib.import_module(module_name)
    except Exception as import_failure:
        # Anything can go wrong at import time (missing module, or an error
        # raised while the module executes), so chain the original exception
        # to keep it in the stack trace.
        raise errors.FeastProviderModuleImportError(module_name) from import_failure

    try:
        provider_cls = getattr(provider_module, class_name)
    except AttributeError:
        # The only failure mode is a missing attribute, so the original
        # exception adds nothing and is suppressed.
        raise errors.FeastProviderClassImportError(module_name, class_name) from None

    return provider_cls(config, repo_path)


def _get_requested_feature_views_to_features_dict(
Expand Down
15 changes: 14 additions & 1 deletion sdk/python/tests/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from pathlib import Path
from textwrap import dedent
from typing import List
from typing import List, Tuple

from feast import cli
from feast.feature_store import FeatureStore
Expand All @@ -26,6 +26,19 @@ class CliRunner:
def run(self, args: List[str], cwd: Path) -> subprocess.CompletedProcess:
    """Run the CLI module with ``args`` in directory ``cwd``.

    The CLI is launched as a subprocess using the current Python interpreter.

    Returns:
        The ``subprocess.CompletedProcess`` for the finished command.
    """
    command = [sys.executable, cli.__file__, *args]
    return subprocess.run(command, cwd=cwd)

def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]:
    """Run the CLI module with ``args`` in ``cwd`` and capture its output.

    Returns:
        A ``(return_code, output)`` pair, where ``output`` holds the bytes
        written to stdout with stderr merged into the same stream.
    """
    completed = subprocess.run(
        [sys.executable, cli.__file__] + args,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return completed.returncode, completed.stdout

@contextmanager
def local_repo(self, example_repo_py: str):
"""
Expand Down
75 changes: 75 additions & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import pandas

from feast import Entity, FeatureTable, FeatureView, RepoConfig
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.infra.provider import Provider
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.registry import Registry


class FooProvider(Provider):
    """Stub provider used as a test fixture: every operation is a no-op."""

    def __init__(self, config, repo_path):
        """Accept the standard constructor arguments and ignore them."""

    def update_infra(
        self,
        project: str,
        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        """No-op stub."""

    def teardown_infra(
        self,
        project: str,
        tables: Sequence[Union[FeatureTable, FeatureView]],
        entities: Sequence[Entity],
    ):
        """No-op stub."""

    def online_write_batch(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """No-op stub."""

    def materialize_single_feature_view(
        self,
        feature_view: FeatureView,
        start_date: datetime,
        end_date: datetime,
        registry: Registry,
        project: str,
    ) -> None:
        """No-op stub."""

    @staticmethod
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pandas.DataFrame, str],
        registry: Registry,
        project: str,
    ) -> RetrievalJob:
        """No-op stub; returns None despite the annotated return type."""

    def online_read(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """No-op stub; returns None despite the annotated return type."""
56 changes: 56 additions & 0 deletions sdk/python/tests/test_cli_local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from textwrap import dedent

Expand Down Expand Up @@ -110,3 +111,58 @@ def test_non_local_feature_repo() -> None:

result = runner.run(["teardown"], cwd=repo_path)
assert result.returncode == 0


@contextmanager
def setup_third_party_provider_repo(provider_name: str):
    """Yield the path of a temporary feature repo configured with ``provider_name``.

    The repo contains a ``feature_store.yaml`` naming the given provider and a
    ``foo/provider.py`` module copied from the ``foo_provider.py`` file that
    sits next to this file. The directory is removed when the context exits.
    """
    feature_store_yaml = dedent(
        f"""
        project: foo
        registry: data/registry.db
        provider: {provider_name}
        online_store:
            path: data/online_store.db
            type: sqlite
        """
    )

    with tempfile.TemporaryDirectory() as repo_dir_name:
        repo_path = Path(repo_dir_name)

        # Repo configuration pointing at the provider under test
        (repo_path / "feature_store.yaml").write_text(feature_store_yaml)

        # Make the example provider importable as ``foo.provider``
        provider_package = repo_path / "foo"
        provider_package.mkdir()
        provider_source = Path(__file__).parent / "foo_provider.py"
        (repo_path / "foo/provider.py").write_text(provider_source.read_text())

        yield repo_path


def test_3rd_party_providers() -> None:
    """
    Test running apply on third party providers
    """
    failing_cases = [
        # Incorrect built-in provider name (no dots)
        ("feast123", b"Provider 'feast123' is not implemented"),
        # Third-party provider whose module cannot be imported
        ("feast_foo.provider", b"Could not import provider module 'feast_foo'"),
        # Third-party provider whose module exists but lacks the named class
        ("foo.provider", b"Could not import provider 'provider' from module 'foo'"),
    ]

    runner = CliRunner()

    for provider_name, expected_message in failing_cases:
        with setup_third_party_provider_repo(provider_name) as repo_path:
            return_code, output = runner.run_with_output(["apply"], cwd=repo_path)
            assert return_code == 1
            assert expected_message in output

    # Correct third-party provider name: apply must succeed
    with setup_third_party_provider_repo("foo.provider.FooProvider") as repo_path:
        return_code, output = runner.run_with_output(["apply"], cwd=repo_path)
        assert return_code == 0

0 comments on commit 75c9f3f

Please sign in to comment.