Skip to content

Commit

Permalink
Simplify online retrieval code
Browse files (browse the repository at this point in the history)
Signed-off-by: Willem Pienaar <[email protected]>
  • Loading branch information
woop committed Apr 10, 2021
1 parent 2bc9acb commit bd85028
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 39 deletions.
46 changes: 17 additions & 29 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,19 @@ def refresh_registry(self):
)
self._registry.refresh()

def list_entities(self) -> List[Entity]:
def list_entities(self, allow_cache: bool = False) -> List[Entity]:
"""
Retrieve a list of entities from the registry
Args:
allow_cache (bool): Whether to allow returning entities from a cached registry
Returns:
List of entities
"""
self._tele.log("list_entities")

return self._registry.list_entities(self.project)
return self._registry.list_entities(self.project, allow_cache=allow_cache)

def list_feature_views(self) -> List[FeatureView]:
"""
Expand Down Expand Up @@ -418,9 +421,11 @@ def get_online_features(
"""
self._tele.log("get_online_features")

entity_name_to_join_key_map = self._registry.get_entity_name_to_join_key_map(
self.project, allow_cache=True
)
provider = self._get_provider()
entities = self.list_entities(allow_cache=True)
entity_name_to_join_key_map = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key

join_key_rows = []
for row in entity_rows:
Expand All @@ -435,31 +440,14 @@ def get_online_features(
join_key_row[join_key] = entity_value
join_key_rows.append(join_key_row)

response = self._get_online_features(
feature_refs=feature_refs,
entity_rows=_infer_online_entity_rows(join_key_rows),
project=self.config.project,
join_key_map=entity_name_to_join_key_map,
)

return OnlineResponse(response)

def _get_online_features(
self,
entity_rows: List[GetOnlineFeaturesRequestV2.EntityRow],
feature_refs: List[str],
project: str,
join_key_map: Dict[str, str],
) -> GetOnlineFeaturesResponse:

provider = self._get_provider()
entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

union_of_entity_keys = []
result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []

for row in entity_rows:
union_of_entity_keys.append(_entity_row_to_key(row))
result_rows.append(_entity_row_to_field_values(row))
for entity_row_proto in entity_row_proto_list:
union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
result_rows.append(_entity_row_to_field_values(entity_row_proto))

all_feature_views = self._registry.list_feature_views(
project=self.config.project, allow_cache=True
Expand All @@ -468,10 +456,10 @@ def _get_online_features(
grouped_refs = _group_refs(feature_refs, all_feature_views)
for table, requested_features in grouped_refs:
entity_keys = _get_table_entity_keys(
table, union_of_entity_keys, join_key_map
table, union_of_entity_keys, entity_name_to_join_key_map
)
read_rows = provider.online_read(
project=project, table=table, entity_keys=entity_keys,
project=self.project, table=table, entity_keys=entity_keys,
)
for row_idx, read_row in enumerate(read_rows):
row_ts, feature_data = read_row
Expand All @@ -494,7 +482,7 @@ def _get_online_features(
feature_ref
] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

return GetOnlineFeaturesResponse(field_values=result_rows)
return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))


def _entity_row_to_key(row: GetOnlineFeaturesRequestV2.EntityRow) -> EntityKeyProto:
Expand Down
12 changes: 2 additions & 10 deletions sdk/python/feast/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from datetime import datetime, timedelta
from pathlib import Path
from tempfile import TemporaryFile
from typing import Callable, Dict, List, Optional
from typing import Callable, List, Optional
from urllib.parse import urlparse

from google.auth.exceptions import DefaultCredentialsError
Expand Down Expand Up @@ -92,6 +92,7 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]
Retrieve a list of entities from the registry
Args:
allow_cache: Whether to allow returning entities from a cached registry
project: Filter entities based on project name
Returns:
Expand Down Expand Up @@ -307,15 +308,6 @@ def updater(registry_proto: RegistryProto):

self._registry_store.update_registry_proto(updater)

def get_entity_name_to_join_key_map(
self, project: str, allow_cache: bool = False
) -> Dict[str, str]:
entities = self.list_entities(project, allow_cache=allow_cache)
entity_name_to_join_key_map = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
return entity_name_to_join_key_map

def refresh(self):
"""Refreshes the state of the registry cache by fetching the registry state from the remote registry store."""
self._get_registry_proto(allow_cache=False)
Expand Down

0 comments on commit bd85028

Please sign in to comment.