Skip to content

Commit

Permalink
Support on demand feature views in feature services
Browse files Browse the repository at this point in the history
Signed-off-by: Achal Shah <[email protected]>
  • Loading branch information
achals committed Sep 10, 2021
1 parent 2d52ce7 commit f2e6a90
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 5 deletions.
11 changes: 9 additions & 2 deletions sdk/python/feast/feature_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from feast.feature_table import FeatureTable
from feast.feature_view import FeatureView
from feast.feature_view_projection import FeatureViewProjection
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.protos.feast.core.FeatureService_pb2 import (
FeatureService as FeatureServiceProto,
)
Expand Down Expand Up @@ -38,7 +39,9 @@ class FeatureService:
def __init__(
self,
name: str,
features: List[Union[FeatureTable, FeatureView, FeatureViewProjection]],
features: List[
Union[FeatureTable, FeatureView, OnDemandFeatureView, FeatureViewProjection]
],
tags: Optional[Dict[str, str]] = None,
description: Optional[str] = None,
):
Expand All @@ -51,7 +54,11 @@ def __init__(
self.name = name
self.features = []
for feature in features:
if isinstance(feature, FeatureTable) or isinstance(feature, FeatureView):
if (
isinstance(feature, FeatureTable)
or isinstance(feature, FeatureView)
or isinstance(feature, OnDemandFeatureView)
):
self.features.append(FeatureViewProjection.from_definition(feature))
elif isinstance(feature, FeatureViewProjection):
self.features.append(feature)
Expand Down
11 changes: 11 additions & 0 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from feast.errors import RegistryInferenceFailure
from feast.feature import Feature
from feast.feature_view import FeatureView
from feast.feature_view_projection import FeatureViewProjection
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
OnDemandFeatureView as OnDemandFeatureViewProto,
)
Expand Down Expand Up @@ -132,6 +133,16 @@ def get_transformed_features_df(
df_with_features.drop(columns=columns_to_cleanup, inplace=True)
return df_with_transformed_features

def __getitem__(self, item) -> FeatureViewProjection:
assert isinstance(item, list)

referenced_features = []
for feature in self.features:
if feature.name in item:
referenced_features.append(feature)

return FeatureViewProjection(self.name, referenced_features)

def infer_features_from_batch_source(self, config: RepoConfig):
"""
Infers the set of features associated to this feature view from the input source.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def conv_rate_plus_100_feature_view(
)


def create_driver_hourly_stats_feature_view(source, infer_features: bool = True):
def create_driver_hourly_stats_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = FeatureView(
name="driver_stats",
entities=["driver"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytz import utc

from feast import utils
from feast.feature_service import FeatureService
from feast.feature_view import FeatureView
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
Expand Down Expand Up @@ -183,9 +184,22 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
feature_views["global"],
)

feature_service = FeatureService(
"convrate_plus100",
features=[feature_views["driver"][["conv_rate"]], feature_views["driver_odfv"]],
)

feast_objects = []
feast_objects.extend(
[customer_fv, driver_fv, driver_odfv, global_fv, driver(), customer()]
[
customer_fv,
driver_fv,
driver_odfv,
global_fv,
driver(),
customer(),
feature_service,
]
)
store.apply(feast_objects)

Expand Down Expand Up @@ -312,6 +326,14 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
assert_frame_equal(
expected_df, actual_df_from_df_entities, check_dtype=False,
)
assert_feature_service_correctness(
store,
feature_service,
full_feature_names,
orders_df,
expected_df,
event_timestamp,
)

# on demand features is only plumbed through to to_df for now.
table_from_df_entities: pd.DataFrame = job_from_df.to_arrow().to_pandas()
Expand All @@ -330,3 +352,62 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
.reset_index(drop=True)
)
assert_frame_equal(actual_df_from_df_entities_for_table, table_from_df_entities)


def response_feature_name(feature: str, full_feature_names: bool) -> str:
if (
feature in {"current_balance", "avg_passenger_count", "lifetime_trip_count"}
and full_feature_names
):
return f"customer_profile__{feature}"

if feature in {"conv_rate", "avg_daily_trips"} and full_feature_names:
return f"driver_stats__{feature}"

if feature in {"conv_rate_plus_100"} and full_feature_names:
return f"conv_rate_plus_100__{feature}"

if feature in {"num_rides", "avg_ride_length"} and full_feature_names:
return f"global_stats__{feature}"
return feature


def assert_feature_service_correctness(
store, feature_service, full_feature_names, orders_df, expected_df, event_timestamp
):

job_from_df = store.get_historical_features(
entity_df=orders_df,
features=feature_service,
full_feature_names=full_feature_names,
)

actual_df_from_df_entities = job_from_df.to_df()

expected_df: pd.DataFrame = (
expected_df.sort_values(
by=[event_timestamp, "order_id", "driver_id", "customer_id"]
)
.drop_duplicates()
.reset_index(drop=True)
)
expected_df = expected_df[
[
event_timestamp,
"order_id",
"driver_id",
"customer_id",
response_feature_name("conv_rate", full_feature_names),
"conv_rate_plus_100",
]
]
actual_df_from_df_entities = (
actual_df_from_df_entities[expected_df.columns]
.sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"])
.drop_duplicates()
.reset_index(drop=True)
)

assert_frame_equal(
expected_df, actual_df_from_df_entities, check_dtype=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest

from feast import FeatureService
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
)
Expand All @@ -17,9 +18,15 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name
fs = environment.feature_store
entities, datasets, data_sources = universal_data_sources
feature_views = construct_universal_feature_views(data_sources)

feature_service = FeatureService(
"convrate_plus100",
features=[feature_views["driver"][["conv_rate"]], feature_views["driver_odfv"]],
)

feast_objects = []
feast_objects.extend(feature_views.values())
feast_objects.extend([driver(), customer()])
feast_objects.extend([driver(), customer(), feature_service])
fs.apply(feast_objects)
fs.materialize(environment.start_date, environment.end_date)

Expand Down Expand Up @@ -114,6 +121,16 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name
][0]
)

assert_feature_service_correctness(
fs,
feature_service,
entity_rows,
full_feature_names,
drivers_df,
customers_df,
global_df,
)


def response_feature_name(feature: str, full_feature_names: bool) -> str:
if (
Expand Down Expand Up @@ -147,3 +164,38 @@ def get_latest_feature_values_from_dataframes(
latest_global_row = global_df.loc[global_df["event_timestamp"].idxmax()].to_dict()

return {**latest_customer_row, **latest_driver_row, **latest_global_row}


def assert_feature_service_correctness(
fs,
feature_service,
entity_rows,
full_feature_names,
drivers_df,
customers_df,
global_df,
):
feature_service_response = fs.get_online_features(
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
)
assert feature_service_response is not None

feature_service_online_features_dict = feature_service_response.to_dict()
feature_service_keys = feature_service_online_features_dict.keys()

assert (
len(feature_service_keys) == len(feature_service.features) + 2
) # Add two for the driver id and the customer id entity keys.

for i, entity_row in enumerate(entity_rows):
df_features = get_latest_feature_values_from_dataframes(
drivers_df, customers_df, global_df, entity_row
)
assert (
feature_service_online_features_dict[
response_feature_name("conv_rate_plus_100", full_feature_names)
][i]
== df_features["conv_rate"] + 100
)

0 comments on commit f2e6a90

Please sign in to comment.