From b8e39ea4cd2d990f2422c60bf39d8d940ecc9522 Mon Sep 17 00:00:00 2001 From: Achal Shah Date: Tue, 16 Aug 2022 16:28:49 -0700 Subject: [PATCH] fix: Register BatchFeatureView in feature repos correctly (#3092) * fix: Registry BatchFeatureView in feature repos correctly Signed-off-by: Achal Shah * tests Signed-off-by: Achal Shah Signed-off-by: Achal Shah --- sdk/python/feast/repo_operations.py | 9 ++++ .../example_feature_repo_with_bfvs.py | 52 +++++++++++++++++++ .../unit/local_feast_tests/test_e2e_local.py | 10 ++++ sdk/python/tests/utils/cli_repo_creator.py | 8 ++- 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 sdk/python/tests/example_repos/example_feature_repo_with_bfvs.py diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py index 007d87bf75..bbdde5bf3f 100644 --- a/sdk/python/feast/repo_operations.py +++ b/sdk/python/feast/repo_operations.py @@ -172,6 +172,15 @@ def parse_repo(repo_root: Path) -> RepoContents: assert stream_source if not any((stream_source is ds) for ds in res.data_sources): res.data_sources.append(stream_source) + elif isinstance(obj, BatchFeatureView) and not any( + (obj is bfv) for bfv in res.feature_views + ): + res.feature_views.append(obj) + + # Handle batch sources defined with feature views. + batch_source = obj.batch_source + if not any((batch_source is ds) for ds in res.data_sources): + res.data_sources.append(batch_source) elif isinstance(obj, Entity) and not any( (obj is entity) for entity in res.entities ): diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs.py b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs.py new file mode 100644 index 0000000000..e0f75c0c6f --- /dev/null +++ b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs.py @@ -0,0 +1,52 @@ +from datetime import timedelta + +from feast import BatchFeatureView, Entity, Field, FileSource +from feast.types import Float32, Int32, Int64 + +driver_hourly_stats = FileSource( + path="%PARQUET_PATH%", # placeholder to be replaced by the test + timestamp_field="event_timestamp", + created_timestamp_column="created", +) + +driver = Entity( + name="driver_id", + description="driver id", +) + + +driver_hourly_stats_view = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=1), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + source=driver_hourly_stats, + tags={}, +) + + +global_daily_stats = FileSource( + path="%PARQUET_PATH_GLOBAL%", # placeholder to be replaced by the test + timestamp_field="event_timestamp", + created_timestamp_column="created", +) + + +global_stats_feature_view = BatchFeatureView( + name="global_daily_stats", + entities=None, + ttl=timedelta(days=1), + schema=[ + Field(name="num_rides", dtype=Int32), + Field(name="avg_ride_length", dtype=Float32), + ], + online=True, + source=global_daily_stats, + tags={}, +) diff --git a/sdk/python/tests/unit/local_feast_tests/test_e2e_local.py b/sdk/python/tests/unit/local_feast_tests/test_e2e_local.py index fe6c187835..1ead69f52a 100644 --- a/sdk/python/tests/unit/local_feast_tests/test_e2e_local.py +++ b/sdk/python/tests/unit/local_feast_tests/test_e2e_local.py @@ -51,6 +51,16 @@ def test_e2e_local() -> None: runner, store, start_date, end_date, driver_df ) + with runner.local_repo( + get_example_repo("example_feature_repo_with_bfvs.py") + .replace("%PARQUET_PATH%", driver_stats_path) + .replace("%PARQUET_PATH_GLOBAL%", global_stats_path), + "file", + ) as store: + _test_materialize_and_online_retrieval( + runner, store, start_date, end_date, driver_df + ) + with runner.local_repo( get_example_repo("example_feature_repo_with_ttl_0.py") .replace("%PARQUET_PATH%", driver_stats_path) diff --git a/sdk/python/tests/utils/cli_repo_creator.py b/sdk/python/tests/utils/cli_repo_creator.py index 66f67384f9..92b6dd992a 100644 --- a/sdk/python/tests/utils/cli_repo_creator.py +++ b/sdk/python/tests/utils/cli_repo_creator.py @@ -88,7 +88,9 @@ def local_repo(self, example_repo_py: str, offline_store: str): stderr = result.stderr.decode("utf-8") print(f"Apply stdout:\n{stdout}") print(f"Apply stderr:\n{stderr}") - assert result.returncode == 0 + assert ( + result.returncode == 0 + ), f"stdout: {result.stdout}\nstderr: {result.stderr}" yield FeatureStore(repo_path=str(repo_path), config=None) @@ -97,4 +99,6 @@ def local_repo(self, example_repo_py: str, offline_store: str): stderr = result.stderr.decode("utf-8") print(f"Apply stdout:\n{stdout}") print(f"Apply stderr:\n{stderr}") - assert result.returncode == 0 + assert ( + result.returncode == 0 + ), f"stdout: {result.stdout}\nstderr: {result.stderr}"