diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 2b5d9ec4c7..d26021829a 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -110,16 +110,19 @@ def refresh_registry(self): ) self._registry.refresh() - def list_entities(self) -> List[Entity]: + def list_entities(self, allow_cache: bool = False) -> List[Entity]: """ Retrieve a list of entities from the registry + Args: + allow_cache (bool): Whether to allow returning entities from a cached registry + Returns: List of entities """ self._tele.log("list_entities") - return self._registry.list_entities(self.project) + return self._registry.list_entities(self.project, allow_cache=allow_cache) def list_feature_views(self) -> List[FeatureView]: """ @@ -418,9 +421,11 @@ def get_online_features( """ self._tele.log("get_online_features") - entity_name_to_join_key_map = self._registry.get_entity_name_to_join_key_map( - self.project, allow_cache=True - ) + provider = self._get_provider() + entities = self.list_entities(allow_cache=True) + entity_name_to_join_key_map = {} + for entity in entities: + entity_name_to_join_key_map[entity.name] = entity.join_key join_key_rows = [] for row in entity_rows: @@ -435,31 +440,14 @@ def get_online_features( join_key_row[join_key] = entity_value join_key_rows.append(join_key_row) - response = self._get_online_features( - feature_refs=feature_refs, - entity_rows=_infer_online_entity_rows(join_key_rows), - project=self.config.project, - join_key_map=entity_name_to_join_key_map, - ) - - return OnlineResponse(response) - - def _get_online_features( - self, - entity_rows: List[GetOnlineFeaturesRequestV2.EntityRow], - feature_refs: List[str], - project: str, - join_key_map: Dict[str, str], - ) -> GetOnlineFeaturesResponse: - - provider = self._get_provider() + entity_row_proto_list = _infer_online_entity_rows(join_key_rows) union_of_entity_keys = [] result_rows: List[GetOnlineFeaturesResponse.FieldValues] = [] - for row in entity_rows: - union_of_entity_keys.append(_entity_row_to_key(row)) - result_rows.append(_entity_row_to_field_values(row)) + for entity_row_proto in entity_row_proto_list: + union_of_entity_keys.append(_entity_row_to_key(entity_row_proto)) + result_rows.append(_entity_row_to_field_values(entity_row_proto)) all_feature_views = self._registry.list_feature_views( project=self.config.project, allow_cache=True @@ -468,10 +456,10 @@ def _get_online_features( grouped_refs = _group_refs(feature_refs, all_feature_views) for table, requested_features in grouped_refs: entity_keys = _get_table_entity_keys( - table, union_of_entity_keys, join_key_map + table, union_of_entity_keys, entity_name_to_join_key_map ) read_rows = provider.online_read( - project=project, table=table, entity_keys=entity_keys, + project=self.project, table=table, entity_keys=entity_keys, ) for row_idx, read_row in enumerate(read_rows): row_ts, feature_data = read_row @@ -494,7 +482,7 @@ def _get_online_features( feature_ref ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT - return GetOnlineFeaturesResponse(field_values=result_rows) + return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows)) def _entity_row_to_key(row: GetOnlineFeaturesRequestV2.EntityRow) -> EntityKeyProto: diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index 29647718e5..03a214edff 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -17,7 +17,7 @@ from datetime import datetime, timedelta from pathlib import Path from tempfile import TemporaryFile -from typing import Callable, Dict, List, Optional +from typing import Callable, List, Optional from urllib.parse import urlparse from google.auth.exceptions import DefaultCredentialsError @@ -92,6 +92,7 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity] Retrieve a list of entities from the registry Args: + allow_cache: Whether to allow returning entities from a cached registry project: Filter entities based on project name Returns: @@ -307,15 +308,6 @@ def updater(registry_proto: RegistryProto): self._registry_store.update_registry_proto(updater) - def get_entity_name_to_join_key_map( - self, project: str, allow_cache: bool = False - ) -> Dict[str, str]: - entities = self.list_entities(project, allow_cache=allow_cache) - entity_name_to_join_key_map = {} - for entity in entities: - entity_name_to_join_key_map[entity.name] = entity.join_key - return entity_name_to_join_key_map - def refresh(self): """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" self._get_registry_proto(allow_cache=False)