Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python FeatureServer optimization #2202

Merged
merged 5 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import feast
from feast import proto_json
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest
from feast.type_map import feast_value_type_to_python_type


def get_app(store: "feast.FeatureStore"):
Expand Down Expand Up @@ -41,16 +40,11 @@ def get_online_features(body=Depends(get_body)):
if any(batch_size != num_entities for batch_size in batch_sizes):
raise HTTPException(status_code=500, detail="Uneven number of columns")

entity_rows = [
{
k: feast_value_type_to_python_type(v.val[idx])
for k, v in request_proto.entities.items()
}
for idx in range(num_entities)
]

response_proto = store.get_online_features(
features, entity_rows, full_feature_names=full_feature_names
response_proto = store._get_online_features(
features,
request_proto.entities,
full_feature_names=full_feature_names,
native_entity_values=False,
).proto

# Convert the Protobuf object to JSON and return it
Expand Down
251 changes: 152 additions & 99 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
List,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
Expand Down Expand Up @@ -72,7 +73,7 @@
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value
from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value
from feast.registry import Registry
from feast.repo_config import RepoConfig, load_repo_config
from feast.request_feature_view import RequestFeatureView
Expand Down Expand Up @@ -267,14 +268,18 @@ def _list_feature_views(
return feature_views

@log_exceptions_and_usage
def list_on_demand_feature_views(self) -> List[OnDemandFeatureView]:
def list_on_demand_feature_views(
self, allow_cache: bool = False
) -> List[OnDemandFeatureView]:
"""
Retrieves the list of on demand feature views from the registry.

Returns:
A list of on demand feature views.
"""
return self._registry.list_on_demand_feature_views(self.project)
return self._registry.list_on_demand_feature_views(
self.project, allow_cache=allow_cache
)

@log_exceptions_and_usage
def get_entity(self, name: str) -> Entity:
Expand Down Expand Up @@ -1067,6 +1072,25 @@ def get_online_features(
... )
>>> online_response_dict = online_response.to_dict()
"""
columnar = defaultdict(list)
for entity_row in entity_rows:
for key, value in entity_row.items():
columnar[key].append(value)

return self._get_online_features(
features=features,
entity_values=dict(columnar),
judahrand marked this conversation as resolved.
Show resolved Hide resolved
full_feature_names=full_feature_names,
native_entity_values=True,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Dict[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]],
adchia marked this conversation as resolved.
Show resolved Hide resolved
full_feature_names: bool = False,
native_entity_values: bool = True,
):
_feature_refs = self._get_features(features, allow_cache=True)
(
requested_feature_views,
Expand All @@ -1076,6 +1100,29 @@ def get_online_features(
features=features, allow_cache=True, hide_dummy_entity=False
)

entity_name_to_join_key_map, entity_type_map = self._get_entity_maps(
requested_feature_views
)

# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)
adchia marked this conversation as resolved.
Show resolved Hide resolved
_validate_feature_refs(_feature_refs, full_feature_names)
(
grouped_refs,
Expand All @@ -1101,111 +1148,72 @@ def get_online_features(
}

feature_views = list(view for view, _ in grouped_refs)
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]

provider = self._get_provider()
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
join_key_to_entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
join_key_to_entity_type_map[entity.join_key] = entity.value_type
for feature_view in requested_feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
join_key_to_entity_type_map[join_key] = entity.value_type

needed_request_data, needed_request_fv_features = self.get_needed_request_data(
grouped_odfv_refs, grouped_request_fv_refs
)

join_key_rows = []
request_data_features: Dict[str, List[Any]] = defaultdict(list)
join_key_values: Dict[str, List[Value]] = {}
request_data_features: Dict[str, List[Value]] = {}
# Entity rows may be either entities or request data.
for row in entity_rows:
join_key_row = {}
for entity_name, entity_value in row.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name].append(entity_value)
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_row[join_key] = entity_value
if entityless_case:
join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
if len(join_key_row) > 0:
# May be empty if this entity row was request data
join_key_rows.append(join_key_row)
for entity_name, values in entity_proto_values.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name] = values
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_values[join_key] = values

self.ensure_request_data_values_exist(
needed_request_data, needed_request_fv_features, request_data_features
)

# Convert join_key_rows from rowise to columnar.
join_key_python_values: Dict[str, List[Value]] = defaultdict(list)
for join_key_row in join_key_rows:
for join_key, value in join_key_row.items():
join_key_python_values[join_key].append(value)

# Convert all join key values to Protobuf Values
join_key_proto_values = {
k: python_values_to_proto_values(v, join_key_to_entity_type_map[k])
for k, v in join_key_python_values.items()
}

# Populate online features response proto with join keys
# Populate online features response proto with join keys and request data features
online_features_response = GetOnlineFeaturesResponse(
results=[
GetOnlineFeaturesResponse.FeatureVector()
for _ in range(len(entity_rows))
]
results=[GetOnlineFeaturesResponse.FeatureVector() for _ in range(num_rows)]
)
for key, values in join_key_proto_values.items():
online_features_response.metadata.feature_names.val.append(key)
for row_idx, result_row in enumerate(online_features_response.results):
result_row.values.append(values[row_idx])
result_row.statuses.append(FieldStatus.PRESENT)
result_row.event_timestamps.append(Timestamp())
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data=dict(**join_key_values, **request_data_features),
)

# Add the Entityless case after populating result rows to avoid having to remove
# it later.
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]
if entityless_case:
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
)

# Initialize the set of EntityKeyProtos once and reuse them for each FeatureView
# to avoid initialization overhead.
entity_keys = [EntityKeyProto() for _ in range(len(join_key_rows))]
entity_keys = [EntityKeyProto() for _ in range(num_rows)]
provider = self._get_provider()
for table, requested_features in grouped_refs:
# Get the correct set of entity values with the correct join keys.
entity_values = self._get_table_entity_values(
table, entity_name_to_join_key_map, join_key_proto_values,
table_entity_values = self._get_table_entity_values(
table, entity_name_to_join_key_map, join_key_values,
)

# Set the EntityKeyProtos inplace.
self._set_table_entity_keys(
entity_values, entity_keys,
table_entity_values, entity_keys,
)

# Populate the result_rows with the Features from the OnlineStore inplace.
Expand All @@ -1218,10 +1226,6 @@ def get_online_features(
table,
)

self._populate_request_data_features(
online_features_response, request_data_features
)

if grouped_odfv_refs:
self._augment_response_with_on_demand_transforms(
online_features_response,
Expand All @@ -1235,6 +1239,50 @@ def get_online_features(
)
return OnlineResponse(online_features_response)

@staticmethod
def _get_columnar_entity_values(
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
) -> Dict[str, List[Any]]:
if (rowise is None and columnar is None) or (
rowise is not None and columnar is not None
):
raise ValueError(
"Exactly one of `columnar_entity_values` and `rowise_entity_values` must be set."
)

if rowise is not None:
# Convert entity_rows from rowise to columnar.
res = defaultdict(list)
for entity_row in rowise:
for key, value in entity_row.items():
res[key].append(value)
return res
return cast(Dict[str, List[Any]], columnar)

def _get_entity_maps(self, feature_views):
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
entity_type_map[entity.name] = entity.value_type
for feature_view in feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
entity_type_map[join_key] = entity.value_type
return entity_name_to_join_key_map, entity_type_map

@staticmethod
def _get_table_entity_values(
table: FeatureView,
Expand Down Expand Up @@ -1275,23 +1323,21 @@ def _set_table_entity_keys(
entity_key.entity_values.extend(next(rowise_values))

@staticmethod
def _populate_request_data_features(
def _populate_result_rows_from_columnar(
online_features_response: GetOnlineFeaturesResponse,
request_data_features: Dict[str, List[Any]],
data: Dict[str, List[Value]],
):
# Add more feature values to the existing result rows for the request data features
for feature_name, feature_values in request_data_features.items():
proto_values = python_values_to_proto_values(
feature_values, ValueType.UNKNOWN
)
timestamp = Timestamp() # Only initialize this timestamp once.
# Add more values to the existing result rows
for feature_name, feature_values in data.items():

online_features_response.metadata.feature_names.val.append(feature_name)

for row_idx, proto_value in enumerate(proto_values):
for row_idx, proto_value in enumerate(feature_values):
result_row = online_features_response.results[row_idx]
result_row.values.append(proto_value)
result_row.statuses.append(FieldStatus.PRESENT)
result_row.event_timestamps.append(Timestamp())
result_row.event_timestamps.append(timestamp)

@staticmethod
def get_needed_request_data(
Expand Down Expand Up @@ -1567,6 +1613,13 @@ def serve_transformations(self, port: int) -> None:
transformation_server.start_server(self, port)


def _validate_entity_values(join_key_values: Dict[str, List[Value]]):
set_of_row_lengths = {len(v) for v in join_key_values.values()}
if len(set_of_row_lengths) > 1:
raise ValueError("All entity rows must have the same columns.")
return set_of_row_lengths.pop()


def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False):
collided_feature_refs = []

Expand Down