diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index f8bd52a6e4..5546984f71 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -20,6 +20,12 @@ from feast.repo_config import RepoConfig from feast.type_map import python_value_to_proto_value +PROVIDERS_CLASS_FOR_TYPE = { + "gcp": "feast.infra.gcp.GcpProvider", + "aws": "feast.infra.aws.AwsProvider", + "local": "feast.infra.local.LocalProvider", +} + class Provider(abc.ABC): @abc.abstractmethod @@ -158,30 +164,20 @@ def get_feature_server_endpoint(self) -> Optional[str]: def get_provider(config: RepoConfig, repo_path: Path) -> Provider: if "." not in config.provider: - if config.provider in {"gcp", "aws", "local"}: - if config.provider == "aws": - from feast.infra.aws import AwsProvider - - return AwsProvider(config) - - if config.provider == "gcp": - from feast.infra.gcp import GcpProvider - - return GcpProvider(config) - - from feast.infra.local import LocalProvider - - return LocalProvider(config) - else: + if config.provider not in PROVIDERS_CLASS_FOR_TYPE: raise errors.FeastProviderNotImplementedError(config.provider) + + provider = PROVIDERS_CLASS_FOR_TYPE[config.provider] else: - # Split provider into module and class names by finding the right-most dot. - # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' - module_name, class_name = config.provider.rsplit(".", 1) + provider = config.provider + + # Split provider into module and class names by finding the right-most dot. + # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' + module_name, class_name = provider.rsplit(".", 1) - cls = importer.get_class_from_type(module_name, class_name, "Provider") + cls = importer.get_class_from_type(module_name, class_name, "Provider") - return cls(config) + return cls(config) def _get_requested_feature_views_to_features_dict(