Skip to content

Commit

Permalink
Optimize Python FeatureServer
Browse files Browse the repository at this point in the history
Signed-off-by: Judah Rand <[email protected]>
  • Loading branch information
judahrand committed Jan 7, 2022
1 parent f969e53 commit 179f272
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 108 deletions.
16 changes: 5 additions & 11 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import feast
from feast import proto_json
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest
from feast.type_map import feast_value_type_to_python_type


def get_app(store: "feast.FeatureStore"):
Expand Down Expand Up @@ -41,16 +40,11 @@ def get_online_features(body=Depends(get_body)):
if any(batch_size != num_entities for batch_size in batch_sizes):
raise HTTPException(status_code=500, detail="Uneven number of columns")

entity_rows = [
{
k: feast_value_type_to_python_type(v.val[idx])
for k, v in request_proto.entities.items()
}
for idx in range(num_entities)
]

response_proto = store.get_online_features(
features, entity_rows, full_feature_names=full_feature_names
response_proto = store._get_online_features(
features,
request_proto.entities,
full_feature_names=full_feature_names,
native_entity_values=False,
).proto

# Convert the Protobuf object to JSON and return it
Expand Down
239 changes: 142 additions & 97 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,18 @@ def _list_feature_views(
return feature_views

@log_exceptions_and_usage
def list_on_demand_feature_views(self) -> List[OnDemandFeatureView]:
def list_on_demand_feature_views(
self, allow_cache: bool = False
) -> List[OnDemandFeatureView]:
"""
Retrieves the list of on demand feature views from the registry.
Returns:
A list of on demand feature views.
"""
return self._registry.list_on_demand_feature_views(self.project)
return self._registry.list_on_demand_feature_views(
self.project, allow_cache=allow_cache
)

@log_exceptions_and_usage
def get_entity(self, name: str) -> Entity:
Expand Down Expand Up @@ -1049,6 +1053,25 @@ def get_online_features(
... )
>>> online_response_dict = online_response.to_dict()
"""
columnar = defaultdict(list)
for entity_row in entity_rows:
for key, value in entity_row.items():
columnar[key].append(value)

return self._get_online_features(
features=features,
entity_values=columnar,
full_feature_names=full_feature_names,
native_entity_values=True,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Dict[str, List[Any]],
full_feature_names: bool = False,
native_entity_values: bool = True,
):
_feature_refs = self._get_features(features, allow_cache=True)
(
requested_feature_views,
Expand All @@ -1058,6 +1081,22 @@ def get_online_features(
features=features, allow_cache=True, hide_dummy_entity=False
)

entity_name_to_join_key_map, entity_type_map = self._get_entity_maps(
requested_feature_views
)

if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values: Dict[str, List[Value]] = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_values.items()
}
else:
entity_proto_values = entity_values

num_rows = _validate_entity_values(entity_proto_values)
_validate_feature_refs(_feature_refs, full_feature_names)
(
grouped_refs,
Expand All @@ -1083,101 +1122,65 @@ def get_online_features(
}

feature_views = list(view for view, _ in grouped_refs)
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]

provider = self._get_provider()
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
join_key_to_entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
join_key_to_entity_type_map[entity.join_key] = entity.value_type
for feature_view in requested_feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
join_key_to_entity_type_map[join_key] = entity.value_type

needed_request_data, needed_request_fv_features = self.get_needed_request_data(
grouped_odfv_refs, grouped_request_fv_refs
)

join_key_rows = []
request_data_features: Dict[str, List[Any]] = defaultdict(list)
join_key_values: Dict[str, List[Value]] = {}
request_data_features: Dict[str, List[Value]] = {}
# Entity rows may be either entities or request data.
for row in entity_rows:
join_key_row = {}
for entity_name, entity_value in row.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name].append(entity_value)
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_row[join_key] = entity_value
if entityless_case:
join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
if len(join_key_row) > 0:
# May be empty if this entity row was request data
join_key_rows.append(join_key_row)
for entity_name, values in entity_proto_values.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name] = values
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_values[join_key] = values

self.ensure_request_data_values_exist(
needed_request_data, needed_request_fv_features, request_data_features
)

# Convert join_key_rows from rowise to columnar.
join_key_python_values: Dict[str, List[Value]] = defaultdict(list)
for join_key_row in join_key_rows:
for join_key, value in join_key_row.items():
join_key_python_values[join_key].append(value)

# Convert all join key values to Protobuf Values
join_key_proto_values = {
k: python_values_to_proto_values(v, join_key_to_entity_type_map[k])
for k, v in join_key_python_values.items()
}
# Populate result rows with join keys and request data features
result_rows = [GetOnlineFeaturesResponse.FieldValues() for _ in range(num_rows)]
self._populate_result_rows_from_columnar(
data=dict(**join_key_values, **request_data_features),
result_rows=result_rows,
)

# Populate result rows with join keys
result_rows = [
GetOnlineFeaturesResponse.FieldValues() for _ in range(len(entity_rows))
# Add the Entityless case after populating result rows to avoid having to remove
# it later.
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]
for key, values in join_key_proto_values.items():
for row_idx, result_row in enumerate(result_rows):
result_row.fields[key].CopyFrom(values[row_idx])
result_row.statuses[key] = FieldStatus.PRESENT
if entityless_case:
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
)

# Initialize the set of EntityKeyProtos once and reuse them for each FeatureView
# to avoid initialization overhead.
entity_keys = [EntityKeyProto() for _ in range(len(join_key_rows))]
entity_keys = [EntityKeyProto() for _ in range(num_rows)]
provider = self._get_provider()
for table, requested_features in grouped_refs:
# Get the correct set of entity values with the correct join keys.
entity_values = self._get_table_entity_values(
table, entity_name_to_join_key_map, join_key_proto_values,
table, entity_name_to_join_key_map, join_key_values,
)

# Set the EntityKeyProtos inplace.
Expand All @@ -1195,10 +1198,6 @@ def get_online_features(
table,
)

self._populate_request_data_features(
request_data_features, result_rows,
)

if grouped_odfv_refs:
self._augment_response_with_on_demand_transforms(
_feature_refs,
Expand All @@ -1207,11 +1206,55 @@ def get_online_features(
result_rows,
)

self._drop_unneeded_columns(
requested_result_row_names, result_rows,
)
self._drop_unneeded_columns(
requested_result_row_names, result_rows,
)
return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))

@staticmethod
def _get_columnar_entity_values(
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
) -> Dict[str, List[Any]]:
if (rowise is None and columnar is None) or (
rowise is not None and columnar is not None
):
raise ValueError(
"Exactly one of `columnar_entity_values` and `rowise_entity_values` must be set."
)

if rowise is not None:
# Convert entity_rows from rowise to columnar.
res = defaultdict(list)
for entity_row in rowise:
for key, value in entity_row.items():
res[key].append(value)
return res
return cast(Dict[str, List[Any]], columnar)

def _get_entity_maps(self, feature_views):
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
entity_type_map[entity.name] = entity.value_type
for feature_view in feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
entity_type_map[join_key] = entity.value_type
return entity_name_to_join_key_map, entity_type_map

@staticmethod
def _get_table_entity_values(
table: FeatureView,
Expand Down Expand Up @@ -1252,20 +1295,15 @@ def _set_table_entity_keys(
entity_key.entity_values.extend(next(rowise_values))

@staticmethod
def _populate_request_data_features(
request_data_features: Dict[str, List[Any]],
def _populate_result_rows_from_columnar(
data: Dict[str, List[Value]],
result_rows: List[GetOnlineFeaturesResponse.FieldValues],
):
# Add more feature values to the existing result rows for the request data features
for feature_name, feature_values in request_data_features.items():
proto_values = python_values_to_proto_values(
feature_values, ValueType.UNKNOWN
)

for row_idx, proto_value in enumerate(proto_values):
result_row = result_rows[row_idx]
result_row.fields[feature_name].CopyFrom(proto_value)
result_row.statuses[feature_name] = FieldStatus.PRESENT
# Add more values to the existing result rows
for name, values in data.items():
for row_idx, result_row in enumerate(result_rows):
result_row.fields[name].CopyFrom(values[row_idx])
result_row.statuses[name] = FieldStatus.PRESENT

@staticmethod
def get_needed_request_data(
Expand Down Expand Up @@ -1528,6 +1566,13 @@ def serve_transformations(self, port: int) -> None:
transformation_server.start_server(self, port)


def _validate_entity_values(join_key_values: Dict[str, List[Value]]):
set_of_row_lengths = {len(v) for v in join_key_values.values()}
if len(set_of_row_lengths) > 1:
raise ValueError("All entity rows must have the same columns.")
return set_of_row_lengths.pop()


def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False):
collided_feature_refs = []

Expand Down

0 comments on commit 179f272

Please sign in to comment.