From c5539fd9874fed3f69b0aaebc4d1d32e624bd041 Mon Sep 17 00:00:00 2001 From: Danny Chiao Date: Tue, 17 May 2022 15:53:14 -0400 Subject: [PATCH] fix: Fix on demand feature view crash from inference when it uses df.apply (#2713) * fix: Fix on demand feature view crash from inference when transformation uses df.apply Signed-off-by: Danny Chiao * Fix inference Signed-off-by: Danny Chiao * Fix test Signed-off-by: Danny Chiao --- sdk/python/feast/on_demand_feature_view.py | 18 ++++- .../on_demand_feature_view_repo.py | 48 ++++++++++++ .../integration/registration/test_cli.py | 31 ++++++++ .../integration/registration/test_registry.py | 74 +++++++++++++++++++ 4 files changed, 168 insertions(+), 3 deletions(-) create mode 100644 sdk/python/tests/example_repos/on_demand_feature_view_repo.py diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index 1cddc0b881..f2048e7f5e 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -1,8 +1,9 @@ import copy import functools import warnings +from datetime import datetime from types import MethodType -from typing import Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import dill import pandas as pd @@ -442,6 +443,15 @@ def infer_features(self): Raises: RegistryInferenceFailure: The set of features could not be inferred. """ + rand_df_value: Dict[str, Any] = { + "float": 1.0, + "int": 1, + "str": "hello world", + "bytes": str.encode("hello world"), + "bool": True, + "datetime64[ns]": datetime.utcnow(), + } + df = pd.DataFrame() for feature_view_projection in self.source_feature_view_projections.values(): for feature in feature_view_projection.features: @@ -449,11 +459,13 @@ def infer_features(self): df[f"{feature_view_projection.name}__{feature.name}"] = pd.Series( dtype=dtype ) - df[f"{feature.name}"] = pd.Series(dtype=dtype) + sample_val = rand_df_value[dtype] if dtype in rand_df_value else None + df[f"{feature.name}"] = pd.Series(data=sample_val, dtype=dtype) for request_data in self.source_request_sources.values(): for field in request_data.schema: dtype = feast_value_type_to_pandas_type(field.dtype.to_value_type()) - df[f"{field.name}"] = pd.Series(dtype=dtype) + sample_val = rand_df_value[dtype] if dtype in rand_df_value else None + df[f"{field.name}"] = pd.Series(sample_val, dtype=dtype) output_df: pd.DataFrame = self.udf.__call__(df) inferred_features = [] for f, dt in zip(output_df.columns, output_df.dtypes): diff --git a/sdk/python/tests/example_repos/on_demand_feature_view_repo.py b/sdk/python/tests/example_repos/on_demand_feature_view_repo.py new file mode 100644 index 0000000000..453158b9dc --- /dev/null +++ b/sdk/python/tests/example_repos/on_demand_feature_view_repo.py @@ -0,0 +1,48 @@ +from datetime import timedelta + +import pandas as pd + +from feast import FeatureView, Field, FileSource +from feast.on_demand_feature_view import on_demand_feature_view +from feast.types import Float32, String + +driver_stats = FileSource( + name="driver_stats_source", + path="data/driver_stats_lat_lon.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created", + description="A table describing the stats of a driver based on hourly logs", + owner="test2@gmail.com", +) + +driver_daily_features_view = FeatureView( + name="driver_daily_features", + entities=["driver"], + ttl=timedelta(seconds=8640000000), + schema=[ + Field(name="daily_miles_driven", dtype=Float32), + Field(name="lat", dtype=Float32), + Field(name="lon", dtype=Float32), + Field(name="string_feature", dtype=String), + ], + online=True, + source=driver_stats, + tags={"production": "True"}, + owner="test2@gmail.com", +) + + +@on_demand_feature_view( + sources=[driver_daily_features_view], + schema=[ + Field(name="first_char", dtype=String), + Field(name="concat_string", dtype=String), + ], +) +def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["concat_string"] = inputs.apply( + lambda x: x.string_feature + "hello", axis=1 + ).astype("string") + df["first_char"] = inputs["string_feature"].str[:1].astype("string") + return df diff --git a/sdk/python/tests/integration/registration/test_cli.py b/sdk/python/tests/integration/registration/test_cli.py index ce23ed66a6..ecc17fc06c 100644 --- a/sdk/python/tests/integration/registration/test_cli.py +++ b/sdk/python/tests/integration/registration/test_cli.py @@ -201,6 +201,37 @@ def test_nullable_online_store(test_nullable_online_store) -> None: runner.run(["teardown"], cwd=repo_path) +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_odfv_apply(environment) -> None: + project = f"test_odfv_apply{str(uuid.uuid4()).replace('-', '')[:8]}" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as repo_dir_name: + try: + repo_path = Path(repo_dir_name) + feature_store_yaml = make_feature_store_yaml( + project, environment.test_repo_config, repo_path + ) + + repo_config = repo_path / "feature_store.yaml" + + repo_config.write_text(dedent(feature_store_yaml)) + + repo_example = repo_path / "example.py" + repo_example.write_text(get_example_repo("on_demand_feature_view_repo.py")) + result = runner.run(["apply"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + + # entity & feature view list commands should succeed + result = runner.run(["entities", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["on-demand-feature-views", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + finally: + runner.run(["teardown"], cwd=repo_path) + + @contextmanager def setup_third_party_provider_repo(provider_name: str): with tempfile.TemporaryDirectory() as repo_dir_name: diff --git a/sdk/python/tests/integration/registration/test_registry.py b/sdk/python/tests/integration/registration/test_registry.py index f011d73d2d..46e9a19544 100644 --- a/sdk/python/tests/integration/registration/test_registry.py +++ b/sdk/python/tests/integration/registration/test_registry.py @@ -234,6 +234,80 @@ def test_apply_feature_view_success(test_registry): test_registry._get_registry_proto() +@pytest.mark.parametrize( + "test_registry", [lazy_fixture("local_registry")], +) +def test_apply_on_demand_feature_view_success(test_registry): + # Create Feature Views + driver_stats = FileSource( + name="driver_stats_source", + path="data/driver_stats_lat_lon.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created", + description="A table describing the stats of a driver based on hourly logs", + owner="test2@gmail.com", + ) + + driver_daily_features_view = FeatureView( + name="driver_daily_features", + entities=["driver"], + ttl=timedelta(seconds=8640000000), + schema=[ + Field(name="daily_miles_driven", dtype=Float32), + Field(name="lat", dtype=Float32), + Field(name="lon", dtype=Float32), + Field(name="string_feature", dtype=String), + ], + online=True, + source=driver_stats, + tags={"production": "True"}, + owner="test2@gmail.com", + ) + + @on_demand_feature_view( + sources=[driver_daily_features_view], + schema=[Field(name="first_char", dtype=String)], + ) + def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["first_char"] = inputs["string_feature"].str[:1].astype("string") + return df + + project = "project" + + # Register Feature View + test_registry.apply_feature_view(location_features_from_push, project) + + feature_views = test_registry.list_on_demand_feature_views(project) + + # List Feature Views + assert ( + len(feature_views) == 1 + and feature_views[0].name == "location_features_from_push" + and feature_views[0].features[0].name == "first_char" + and feature_views[0].features[0].dtype == String + ) + + feature_view = test_registry.get_on_demand_feature_view( + "location_features_from_push", project + ) + assert ( + feature_view.name == "location_features_from_push" + and feature_view.features[0].name == "first_char" + and feature_view.features[0].dtype == String + ) + + test_registry.delete_feature_view("location_features_from_push", project) + feature_views = test_registry.list_on_demand_feature_views(project) + assert len(feature_views) == 0 + + test_registry.teardown() + + # Will try to reload registry, which will fail because the file has been deleted + with pytest.raises(FileNotFoundError): + test_registry._get_registry_proto() + + @pytest.mark.parametrize( "test_registry", [lazy_fixture("local_registry")], )