From 26f6b69b0e2c8a4ea37b43e3d1eaa4cdb8c085a9 Mon Sep 17 00:00:00 2001
From: Felix Wang
Date: Tue, 26 Jul 2022 15:47:15 -0700
Subject: [PATCH] fix: Fix file offline store logic for feature views without
 ttl (#2971)

* Add new test for historical retrieval with feature views with no ttl

Signed-off-by: Felix Wang

* Fix no ttl logic

Signed-off-by: Felix Wang
---
 sdk/python/feast/infra/offline_stores/file.py |   8 ++
 .../test_universal_historical_retrieval.py    | 103 ++++++++++++++++--
 2 files changed, 104 insertions(+), 7 deletions(-)

diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py
index 1af98c1437..829bd36c3d 100644
--- a/sdk/python/feast/infra/offline_stores/file.py
+++ b/sdk/python/feast/infra/offline_stores/file.py
@@ -635,6 +635,14 @@ def _filter_ttl(
             )
         ]
 
+        df_to_join = df_to_join.persist()
+    else:
+        df_to_join = df_to_join[
+            # do not drop entity rows if one of the sources returns NaNs
+            df_to_join[timestamp_field].isna()
+            | (df_to_join[timestamp_field] <= df_to_join[entity_df_event_timestamp_col])
+        ]
+
         df_to_join = df_to_join.persist()
 
     return df_to_join
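
Note: the new `else` branch above gives feature views without a ttl an explicit
point-in-time filter. Previously such views skipped the filter entirely, so
feature rows dated after an entity row's event timestamp could leak into the
training dataframe. A minimal sketch of the intended semantics in plain pandas
(the store itself operates on dask dataframes; the dataframe and column names
here are illustrative, not Feast's internals):

    import pandas as pd

    # Three feature rows joined against entity rows whose event timestamp
    # is 2022-07-10.
    df = pd.DataFrame(
        {
            "event_timestamp": pd.to_datetime(["2022-07-01", "2022-07-20", None]),
            "entity_event_timestamp": pd.to_datetime(["2022-07-10"] * 3),
            "value": [1, 2, 3],
        }
    )

    # No ttl: keep rows with a missing timestamp (so entity rows are not
    # dropped when one source returns NaNs) or a timestamp at or before the
    # entity event timestamp. The future row (value 2) is excluded.
    filtered = df[
        df["event_timestamp"].isna()
        | (df["event_timestamp"] <= df["entity_event_timestamp"])
    ]
    print(filtered["value"].tolist())  # [1, 3]
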
diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index 32e6e52d18..87bf59fe9f 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -115,13 +115,17 @@ def get_expected_training_df(
         entity_df.to_dict("records"), event_timestamp
     )
 
+    # Set sufficiently large ttl that it effectively functions as infinite for the calculations below.
+    default_ttl = timedelta(weeks=52)
+
     # Manually do point-in-time join of driver, customer, and order records against
     # the entity df
     for entity_row in entity_rows:
         customer_record = find_asof_record(
             customer_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - customer_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(customer_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id"],
             filter_values=[entity_row["customer_id"]],
@@ -129,7 +133,8 @@ def get_expected_training_df(
         driver_record = find_asof_record(
             driver_records,
             ts_key=driver_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - driver_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(driver_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["driver_id"],
             filter_values=[entity_row["driver_id"]],
@@ -137,7 +142,8 @@ def get_expected_training_df(
         order_record = find_asof_record(
             order_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - order_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(order_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id", "driver_id"],
             filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
@@ -145,7 +151,8 @@ def get_expected_training_df(
         origin_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["origin_id"]],
@@ -153,7 +160,8 @@ def get_expected_training_df(
         destination_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["destination_id"]],
@@ -161,14 +169,16 @@ def get_expected_training_df(
         global_record = find_asof_record(
             global_records,
             ts_key=global_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - global_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(global_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
         field_mapping_record = find_asof_record(
             field_mapping_records,
             ts_key=field_mapping_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - field_mapping_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(field_mapping_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
@@ -666,6 +676,78 @@ def test_historical_features_persisting(
     )
 
 
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
+def test_historical_features_with_no_ttl(
+    environment, universal_data_sources, full_feature_names
+):
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    # Remove ttls.
+    feature_views.customer.ttl = timedelta(seconds=0)
+    feature_views.order.ttl = timedelta(seconds=0)
+    feature_views.global_fv.ttl = timedelta(seconds=0)
+    feature_views.field_mapping.ttl = timedelta(seconds=0)
+
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=[
+            "customer_profile:current_balance",
+            "customer_profile:avg_passenger_count",
+            "customer_profile:lifetime_trip_count",
+            "order:order_is_success",
+            "global_stats:num_rides",
+            "global_stats:avg_ride_length",
+            "field_mapping:feature_name",
+        ],
+        full_feature_names=full_feature_names,
+    )
+
+    event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
+    expected_df = get_expected_training_df(
+        datasets.customer_df,
+        feature_views.customer,
+        datasets.driver_df,
+        feature_views.driver,
+        datasets.orders_df,
+        feature_views.order,
+        datasets.location_df,
+        feature_views.location,
+        datasets.global_df,
+        feature_views.global_fv,
+        datasets.field_mapping_df,
+        feature_views.field_mapping,
+        entity_df,
+        event_timestamp,
+        full_feature_names,
+    ).drop(
+        columns=[
+            response_feature_name("conv_rate_plus_100", full_feature_names),
+            response_feature_name("conv_rate_plus_100_rounded", full_feature_names),
+            response_feature_name("avg_daily_trips", full_feature_names),
+            response_feature_name("conv_rate", full_feature_names),
+            "origin__temperature",
+            "destination__temperature",
+        ]
+    )
+
+    assert_frame_equal(
+        expected_df,
+        job.to_df(),
+        keys=[event_timestamp, "driver_id", "customer_id"],
+    )
+
+
 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 def test_historical_features_from_bigquery_sources_containing_backfills(environment):
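
Note: `get_expected_training_df` mirrors the store's point-in-time join with an
explicit per-record lookback window, ts_start = event_timestamp - ttl and
ts_end = event_timestamp, where the 52-week default stands in for "no ttl" and
is effectively unbounded for these test datasets. A rough sketch of the window
arithmetic (dates are illustrative):

    from datetime import datetime, timedelta

    default_ttl = timedelta(weeks=52)  # the test's stand-in for an infinite ttl
    event_ts = datetime(2022, 7, 10)

    for ttl in (timedelta(days=2), timedelta(seconds=0)):
        # Mirrors get_feature_view_ttl below: a zero ttl falls back to the default.
        lookback = ttl if ttl else default_ttl
        print(event_ts - lookback, "->", event_ts)
    # 2022-07-08 00:00:00 -> 2022-07-10 00:00:00
    # 2021-07-11 00:00:00 -> 2022-07-10 00:00:00
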
@@ -781,6 +863,13 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
     return feature
 
 
+def get_feature_view_ttl(
+    feature_view: FeatureView, default_ttl: timedelta
+) -> timedelta:
+    """Returns the ttl of a feature view if it is non-zero. Otherwise returns the specified default."""
+    return feature_view.ttl if feature_view.ttl else default_ttl
+
+
 def assert_feature_service_correctness(
     store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
 ):
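
Note: `get_feature_view_ttl` leans on timedelta truthiness, so both ways a
feature view can lack a ttl (a ttl of None, and the `timedelta(seconds=0)`
used by the test above) fall through to the supplied default:

    from datetime import timedelta

    assert not timedelta(seconds=0)  # a zero timedelta is falsy
    assert timedelta(days=2)         # any non-zero timedelta is truthy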