Skip to content

Commit

Permalink
generalize provider selection
Browse files Browse the repository at this point in the history
Signed-off-by: pyalex <[email protected]>
  • Loading branch information
pyalex committed Nov 5, 2021
1 parent 2b7efd5 commit bfa4b92
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
from feast.repo_config import RepoConfig
from feast.type_map import python_value_to_proto_value

PROVIDERS_CLASS_FOR_TYPE = {
"gcp": "feast.infra.gcp.GcpProvider",
"aws": "feast.infra.aws.AwsProvider",
"local": "feast.infra.local.LocalProvider",
}


class Provider(abc.ABC):
@abc.abstractmethod
Expand Down Expand Up @@ -158,30 +164,20 @@ def get_feature_server_endpoint(self) -> Optional[str]:

def get_provider(config: RepoConfig, repo_path: Path) -> Provider:
if "." not in config.provider:
if config.provider in {"gcp", "aws", "local"}:
if config.provider == "aws":
from feast.infra.aws import AwsProvider

return AwsProvider(config)

if config.provider == "gcp":
from feast.infra.gcp import GcpProvider

return GcpProvider(config)

from feast.infra.local import LocalProvider

return LocalProvider(config)
else:
if config.provider not in PROVIDERS_CLASS_FOR_TYPE:
raise errors.FeastProviderNotImplementedError(config.provider)

provider = PROVIDERS_CLASS_FOR_TYPE[config.provider]
else:
# Split provider into module and class names by finding the right-most dot.
# For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider'
module_name, class_name = config.provider.rsplit(".", 1)
provider = config.provider

# Split provider into module and class names by finding the right-most dot.
# For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider'
module_name, class_name = provider.rsplit(".", 1)

cls = importer.get_class_from_type(module_name, class_name, "Provider")
cls = importer.get_class_from_type(module_name, class_name, "Provider")

return cls(config)
return cls(config)


def _get_requested_feature_views_to_features_dict(
Expand Down

0 comments on commit bfa4b92

Please sign in to comment.