diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 9e3ec42177..f80d03dbcd 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -52,8 +52,8 @@ def get_online_features(body=Depends(get_body)): raise HTTPException(status_code=500, detail="Uneven number of columns") response_proto = store._get_online_features( - features, - request_proto.entities, + features=features, + entity_values=request_proto.entities, full_feature_names=full_feature_names, native_entity_values=False, ).proto diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 63db151412..1aa4cef602 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -82,7 +82,10 @@ from feast.repo_contents import RepoContents from feast.request_feature_view import RequestFeatureView from feast.saved_dataset import SavedDataset, SavedDatasetStorage -from feast.type_map import python_values_to_proto_values +from feast.type_map import ( + feast_value_type_to_python_type, + python_values_to_proto_values, +) from feast.usage import log_exceptions, log_exceptions_and_usage, set_usage_attribute from feast.value_type import ValueType from feast.version import get_version @@ -135,6 +138,7 @@ def __init__( self._registry = Registry(registry_config, repo_path=self.repo_path) self._registry._initialize_registry() self._provider = get_provider(self.config, self.repo_path) + self._go_server = None @log_exceptions def version(self) -> str: @@ -1284,7 +1288,29 @@ def get_online_features( except KeyError as e: raise ValueError("All entity_rows must have the same keys.") from e - # If Go feature server is enabled, send request to it instead of going through a regular Python logic + return self._get_online_features( + features=features, + entity_values=columnar, + full_feature_names=full_feature_names, + native_entity_values=True, + ) + + def _get_online_features( + self, + features: Union[List[str], FeatureService], + entity_values: Mapping[ + str, Union[Sequence[Any], Sequence[Value], RepeatedValue] + ], + full_feature_names: bool = False, + native_entity_values: bool = True, + ): + # Extract Sequence from RepeatedValue Protobuf. + entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = { + k: list(v) if isinstance(v, Sequence) else list(v.val) + for k, v in entity_values.items() + } + + # If Go feature server is enabled, send request to it instead of going through regular Python logic if self.config.go_feature_server: from feast.embedded_go.online_features_service import ( EmbeddedOnlineFeatureServer, @@ -1296,32 +1322,31 @@ def get_online_features( str(self.repo_path.absolute()), self.config, self ) + entity_native_values: Dict[str, List[Any]] + if not native_entity_values: + # Convert proto types to native types since Go feature server currently + # only handles native types. + # TODO(felixwang9817): Remove this logic once native types are supported. + entity_native_values = { + k: [ + feast_value_type_to_python_type(proto_value) + for proto_value in v + ] + for k, v in entity_value_lists.items() + } + else: + entity_native_values = entity_value_lists + return self._go_server.get_online_features( features_refs=features if isinstance(features, list) else [], feature_service=features if isinstance(features, FeatureService) else None, - entities=columnar, + entities=entity_native_values, request_data={}, # TODO: add request data parameter to public API full_feature_names=full_feature_names, ) - return self._get_online_features( - features=features, - entity_values=columnar, - full_feature_names=full_feature_names, - native_entity_values=True, - ) - - def _get_online_features( - self, - features: Union[List[str], FeatureService], - entity_values: Mapping[ - str, Union[Sequence[Any], Sequence[Value], RepeatedValue] - ], - full_feature_names: bool = False, - native_entity_values: bool = True, - ): _feature_refs = self._get_features(features, allow_cache=True) ( requested_feature_views, @@ -1344,12 +1369,6 @@ def _get_online_features( join_keys_set, ) = self._get_entity_maps(requested_feature_views) - # Extract Sequence from RepeatedValue Protobuf. - entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = { - k: list(v) if isinstance(v, Sequence) else list(v.val) - for k, v in entity_values.items() - } - entity_proto_values: Dict[str, List[Value]] if native_entity_values: # Convert values to Protobuf once. diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index a99ba95953..4da02bcacb 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -118,6 +118,11 @@ IntegrationTestRepoConfig( online_store=REDIS_CONFIG, go_feature_server=True, ), + IntegrationTestRepoConfig( + online_store=REDIS_CONFIG, + python_feature_server=True, + go_feature_server=True, + ), ] ) full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME)