diff --git a/sdk/python/tests/integration/e2e/test_universal_e2e.py b/sdk/python/tests/integration/e2e/test_universal_e2e.py index a58ea841d2..5c89d9b966 100644 --- a/sdk/python/tests/integration/e2e/test_universal_e2e.py +++ b/sdk/python/tests/integration/e2e/test_universal_e2e.py @@ -6,12 +6,24 @@ from pytz import utc from feast import FeatureStore, FeatureView -from tests.integration.feature_repos.test_repo_configuration import parametrize_e2e_test +from tests.integration.feature_repos.test_repo_configuration import ( + Environment, + parametrize_e2e_test, +) +from tests.integration.feature_repos.universal.entities import driver +from tests.integration.feature_repos.universal.feature_views import driver_feature_view @parametrize_e2e_test -def test_e2e_consistency(fs: FeatureStore): - run_offline_online_store_consistency_test(fs) +def test_e2e_consistency(test_environment: Environment): + fs, fv = ( + test_environment.feature_store, + driver_feature_view(test_environment.data_source), + ) + entity = driver() + fs.apply([fv, entity]) + + run_offline_online_store_consistency_test(fs, fv) def check_offline_and_online_features( @@ -63,10 +75,11 @@ def check_offline_and_online_features( assert math.isnan(df.to_dict()["value"][0]) -def run_offline_online_store_consistency_test(fs: FeatureStore,) -> None: +def run_offline_online_store_consistency_test( + fs: FeatureStore, fv: FeatureView +) -> None: now = datetime.utcnow() - fv = fs.get_feature_view("test_correctness") full_feature_names = True check_offline_store: bool = True diff --git a/sdk/python/tests/integration/feature_repos/test_repo_configuration.py b/sdk/python/tests/integration/feature_repos/test_repo_configuration.py index fec573abf9..c7a2046dca 100644 --- a/sdk/python/tests/integration/feature_repos/test_repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/test_repo_configuration.py @@ -1,24 +1,27 @@ import tempfile import uuid from contextlib import contextmanager +from dataclasses import dataclass, replace +from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import pytest -from attr import dataclass -from feast import FeatureStore, RepoConfig, importer +from feast import FeatureStore, FeatureView, RepoConfig, driver_test_data, importer +from feast.data_source import DataSource from tests.data.data_creator import create_dataset from tests.integration.feature_repos.universal.data_source_creator import ( DataSourceCreator, ) -from tests.integration.feature_repos.universal.entities import driver +from tests.integration.feature_repos.universal.entities import customer, driver from tests.integration.feature_repos.universal.feature_views import ( - correctness_feature_view, + create_customer_daily_profile_feature_view, + create_driver_hourly_stats_feature_view, ) -@dataclass +@dataclass(frozen=True, repr=True) class TestRepoConfig: """ This class should hold all possible parameters that may need to be varied by individual tests. 
@@ -30,20 +33,21 @@ class TestRepoConfig: offline_store_creator: str = "tests.integration.feature_repos.universal.data_sources.file.FileDataSourceCreator" full_feature_names: bool = True + infer_event_timestamp_col: bool = True FULL_REPO_CONFIGS: List[TestRepoConfig] = [ TestRepoConfig(), # Local - TestRepoConfig( - provider="aws", - offline_store_creator="tests.integration.feature_repos.universal.data_sources.redshift.RedshiftDataSourceCreator", - online_store={"type": "dynamodb", "region": "us-west-2"}, - ), TestRepoConfig( provider="gcp", offline_store_creator="tests.integration.feature_repos.universal.data_sources.bigquery.BigQueryDataSourceCreator", online_store="datastore", ), + TestRepoConfig( + provider="aws", + offline_store_creator="tests.integration.feature_repos.universal.data_sources.redshift.RedshiftDataSourceCreator", + online_store={"type": "dynamodb", "region": "us-west-2"}, + ), ] @@ -52,8 +56,128 @@ class TestRepoConfig: PROVIDERS: List[str] = [] +@dataclass +class Environment: + name: str + test_repo_config: TestRepoConfig + feature_store: FeatureStore + data_source: DataSource + data_source_creator: DataSourceCreator + + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) + start_date = end_date - timedelta(days=7) + before_start_date = end_date - timedelta(days=365) + after_end_date = end_date + timedelta(days=365) + + customer_entities = list(range(1001, 1110)) + customer_df = driver_test_data.create_customer_daily_profile_df( + customer_entities, start_date, end_date + ) + _customer_feature_view: Optional[FeatureView] = None + + driver_entities = list(range(5001, 5110)) + driver_df = driver_test_data.create_driver_hourly_stats_df( + driver_entities, start_date, end_date + ) + _driver_stats_feature_view: Optional[FeatureView] = None + + orders_df = driver_test_data.create_orders_df( + customers=customer_entities, + drivers=driver_entities, + start_date=before_start_date, + end_date=after_end_date, + order_count=1000, + ) + _orders_table: Optional[str] = None + + def customer_feature_view(self) -> FeatureView: + if self._customer_feature_view is None: + customer_table_id = self.data_source_creator.get_prefixed_table_name( + self.name, "customer_profile" + ) + ds = self.data_source_creator.create_data_sources( + customer_table_id, + self.customer_df, + event_timestamp_column="event_timestamp", + created_timestamp_column="created", + ) + self._customer_feature_view = create_customer_daily_profile_feature_view(ds) + return self._customer_feature_view + + def driver_stats_feature_view(self) -> FeatureView: + if self._driver_stats_feature_view is None: + driver_table_id = self.data_source_creator.get_prefixed_table_name( + self.name, "driver_hourly" + ) + ds = self.data_source_creator.create_data_sources( + driver_table_id, + self.driver_df, + event_timestamp_column="event_timestamp", + created_timestamp_column="created", + ) + self._driver_stats_feature_view = create_driver_hourly_stats_feature_view( + ds + ) + return self._driver_stats_feature_view + + def orders_table(self) -> Optional[str]: + if self._orders_table is None: + orders_table_id = self.data_source_creator.get_prefixed_table_name( + self.name, "orders" + ) + ds = self.data_source_creator.create_data_sources( + orders_table_id, + self.orders_df, + event_timestamp_column="event_timestamp", + created_timestamp_column="created", + ) + if hasattr(ds, "table_ref"): + self._orders_table = ds.table_ref + elif hasattr(ds, "table"): + self._orders_table = ds.table + return self._orders_table + + +def 
vary_full_feature_names(configs: List[TestRepoConfig]) -> List[TestRepoConfig]: + new_configs = [] + for c in configs: + true_c = replace(c, full_feature_names=True) + false_c = replace(c, full_feature_names=False) + new_configs.extend([true_c, false_c]) + return new_configs + + +def vary_infer_event_timestamp_col( + configs: List[TestRepoConfig], +) -> List[TestRepoConfig]: + new_configs = [] + for c in configs: + true_c = replace(c, infer_event_timestamp_col=True) + false_c = replace(c, infer_event_timestamp_col=False) + new_configs.extend([true_c, false_c]) + return new_configs + + +def vary_providers_for_offline_stores( + configs: List[TestRepoConfig], +) -> List[TestRepoConfig]: + new_configs = [] + for c in configs: + if "FileDataSourceCreator" in c.offline_store_creator: + new_configs.append(c) + elif "RedshiftDataSourceCreator" in c.offline_store_creator: + for p in ["local", "aws"]: + new_configs.append(replace(c, provider=p)) + elif "BigQueryDataSourceCreator" in c.offline_store_creator: + for p in ["local", "gcp"]: + new_configs.append(replace(c, provider=p)) + return new_configs + + @contextmanager -def construct_feature_store(test_repo_config: TestRepoConfig) -> FeatureStore: +def construct_test_environment( + test_repo_config: TestRepoConfig, create_and_apply: bool = False +) -> Environment: """ This method should take in the parameters from the test repo config and created a feature repo, apply it, and return the constructed feature store object to callers. @@ -74,8 +198,10 @@ def construct_feature_store(test_repo_config: TestRepoConfig) -> FeatureStore: offline_creator: DataSourceCreator = importer.get_class_from_type( module_name, config_class_name, "DataSourceCreator" - )() - ds = offline_creator.create_data_source(project, df) + )(project) + ds = offline_creator.create_data_sources( + project, df, field_mapping={"ts_1": "ts", "id": "driver_id"} + ) offline_store = offline_creator.create_offline_store_config() online_store = test_repo_config.online_store @@ -89,21 +215,76 @@ def construct_feature_store(test_repo_config: TestRepoConfig) -> FeatureStore: repo_path=repo_dir_name, ) fs = FeatureStore(config=config) - fv = correctness_feature_view(ds) - entity = driver() - fs.apply([fv, entity]) + environment = Environment( + name=project, + test_repo_config=test_repo_config, + feature_store=fs, + data_source=ds, + data_source_creator=offline_creator, + ) - yield fs + fvs = [] + entities = [] + try: + if create_and_apply: + entities.extend([driver(), customer()]) + fvs.extend( + [ + environment.driver_stats_feature_view(), + environment.customer_feature_view(), + ] + ) + fs.apply(fvs + entities) - fs.teardown() - offline_creator.teardown(project) + yield environment + finally: + offline_creator.teardown() + fs.teardown() def parametrize_e2e_test(e2e_test): + """ + This decorator should be used for end-to-end tests. These tests are expected to be parameterized, + and receive an empty feature repo created for all supported configurations. + + The decorator also ensures that sample data needed for the test is available in the relevant offline store. + + Decorated tests should create and apply the objects needed by the tests, and perform any operations needed + (such as materialization and looking up feature values). + + The decorator takes care of tearing down the feature store, as well as the sample data. 
+    """
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("config", FULL_REPO_CONFIGS, ids=lambda v: str(v))
+    def inner_test(config):
+        with construct_test_environment(config) as environment:
+            e2e_test(environment)
+
+    return inner_test
+
+
+def parametrize_offline_retrieval_test(offline_retrieval_test):
+    """
+    This decorator should be used for offline retrieval tests. These tests are expected to be parameterized,
+    and receive a feature repo created for all supported configurations, with entities and feature views applied.
+
+    The decorator also ensures that sample data needed for the test is available in the relevant offline store.
+
+    Decorated tests only need to run the retrieval operations they want to verify (such as retrieving
+    historical features) and assert on the results.
+
+    The decorator takes care of tearing down the feature store, as well as the sample data.
+    """
+
+    configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
+    configs = vary_full_feature_names(configs)
+    configs = vary_infer_event_timestamp_col(configs)
+
     @pytest.mark.integration
-    @pytest.mark.parametrize("config", FULL_REPO_CONFIGS, ids=lambda v: v.provider)
+    @pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
     def inner_test(config):
-        with construct_feature_store(config) as fs:
-            e2e_test(fs)
+        with construct_test_environment(config, create_and_apply=True) as environment:
+            offline_retrieval_test(environment)
 
     return inner_test
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py
index b85aeeeb39..fa5293c06d 100644
--- a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py
+++ b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Dict
 
 import pandas as pd
 
@@ -8,12 +9,13 @@ class DataSourceCreator(ABC):
     @abstractmethod
-    def create_data_source(
+    def create_data_sources(
         self,
-        name: str,
+        destination: str,
         df: pd.DataFrame,
         event_timestamp_column="ts",
         created_timestamp_column="created_ts",
+        field_mapping: Dict[str, str] = None,
     ) -> DataSource:
         ...
 
@@ -22,5 +24,9 @@ def create_offline_store_config(self) -> FeastConfigBaseModel:
         ...
 
     @abstractmethod
-    def teardown(self, name: str):
+    def teardown(self):
+        ...
+
+    @abstractmethod
+    def get_prefixed_table_name(self, name: str, suffix: str) -> str:
         ...
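For a sense of the resulting test matrix, the vary_* helpers in test_repo_configuration.py compose before parametrization; a rough, illustrative sketch (counts assume the three FULL_REPO_CONFIGS entries defined in this diff):

from tests.integration.feature_repos.test_repo_configuration import (
    FULL_REPO_CONFIGS,
    vary_full_feature_names,
    vary_infer_event_timestamp_col,
    vary_providers_for_offline_stores,
)

# File keeps its "local" provider; Redshift fans out to ["local", "aws"];
# BigQuery fans out to ["local", "gcp"].
configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)  # 3 configs -> 5
configs = vary_full_feature_names(configs)                      # 5 -> 10 (True/False)
configs = vary_infer_event_timestamp_col(configs)               # 10 -> 20 parametrized cases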
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py index 7776b31e6d..74804e7512 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py @@ -1,4 +1,4 @@ -import time +from typing import Dict import pandas as pd from google.cloud import bigquery @@ -12,43 +12,62 @@ class BigQueryDataSourceCreator(DataSourceCreator): - def teardown(self, name: str): - pass - - def __init__(self): + def __init__(self, project_name: str): self.client = bigquery.Client() + self.project_name = project_name + self.gcp_project = self.client.project + self.dataset_id = f"{self.gcp_project}.{project_name}" + self.dataset = bigquery.Dataset(self.dataset_id) + print(f"Creating dataset: {self.dataset_id}") + self.client.create_dataset(self.dataset, exists_ok=True) + self.dataset.default_table_expiration_ms = ( + 1000 * 60 * 60 * 24 * 14 + ) # 2 weeks in milliseconds + self.client.update_dataset(self.dataset, ["default_table_expiration_ms"]) + + self.tables = [] + + def teardown(self): + + for table in self.tables: + self.client.delete_table(table, not_found_ok=True) + + self.client.delete_dataset( + self.dataset_id, delete_contents=True, not_found_ok=True + ) + print(f"Deleted dataset '{self.dataset_id}'") def create_offline_store_config(self): return BigQueryOfflineStoreConfig() - def create_data_source( + def create_data_sources( self, - name: str, + destination: str, df: pd.DataFrame, event_timestamp_column="ts", created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, **kwargs, ) -> DataSource: - gcp_project = self.client.project - bigquery_dataset = "test_ingestion" - dataset = bigquery.Dataset(f"{gcp_project}.{bigquery_dataset}") - self.client.create_dataset(dataset, exists_ok=True) - dataset.default_table_expiration_ms = ( - 1000 * 60 * 60 * 24 * 14 - ) # 2 weeks in milliseconds - self.client.update_dataset(dataset, ["default_table_expiration_ms"]) job_config = bigquery.LoadJobConfig() - table_ref = f"{gcp_project}.{bigquery_dataset}.{name}_{int(time.time_ns())}" + if self.gcp_project not in destination: + destination = f"{self.gcp_project}.{self.project_name}.{destination}" + job = self.client.load_table_from_dataframe( - df, table_ref, job_config=job_config + df, destination, job_config=job_config ) job.result() + self.tables.append(destination) + return BigQuerySource( - table_ref=table_ref, + table_ref=destination, event_timestamp_column=event_timestamp_column, created_timestamp_column=created_timestamp_column, date_partition_column="", - field_mapping={"ts_1": "ts", "id": "driver_id"}, + field_mapping=field_mapping or {"ts_1": "ts"}, ) + + def get_prefixed_table_name(self, name: str, suffix: str) -> str: + return f"{self.client.project}.{name}.{suffix}" diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index 49618dd698..bb1957eb4a 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -1,5 +1,5 @@ import tempfile -from typing import Any +from typing import Any, Dict import pandas as pd @@ -16,12 +16,16 @@ class FileDataSourceCreator(DataSourceCreator): f: Any - def create_data_source( + def __init__(self, _: str): + pass + + def 
create_data_sources( self, - name: str, + destination: str, df: pd.DataFrame, event_timestamp_column="ts", created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, ) -> DataSource: self.f = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) df.to_parquet(self.f.name) @@ -31,11 +35,14 @@ def create_data_source( event_timestamp_column=event_timestamp_column, created_timestamp_column=created_timestamp_column, date_partition_column="", - field_mapping={"ts_1": "ts", "id": "driver_id"}, + field_mapping=field_mapping or {"ts_1": "ts", "id": "driver_id"}, ) + def get_prefixed_table_name(self, name: str, suffix: str) -> str: + return f"{name}.{suffix}" + def create_offline_store_config(self) -> FeastConfigBaseModel: return FileOfflineStoreConfig() - def teardown(self, name: str): + def teardown(self): self.f.close() diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py index f34490fc8a..f731b60bb3 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py @@ -1,6 +1,4 @@ -import random -import time -from typing import Optional +from typing import Dict import pandas as pd @@ -16,11 +14,11 @@ class RedshiftDataSourceCreator(DataSourceCreator): - table_name: Optional[str] = None - redshift_source: Optional[RedshiftSource] = None + tables = [] - def __init__(self) -> None: + def __init__(self, project_name: str): super().__init__() + self.project_name = project_name self.client = aws_utils.get_redshift_data_client("us-west-2") self.s3 = aws_utils.get_s3_resource("us-west-2") @@ -33,44 +31,49 @@ def __init__(self) -> None: iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role", ) - def create_data_source( + def create_data_sources( self, - name: str, + destination: str, df: pd.DataFrame, event_timestamp_column="ts", created_timestamp_column="created_ts", + field_mapping: Dict[str, str] = None, ) -> DataSource: - self.table_name = f"{name}_{time.time_ns()}_{random.randint(1000, 9999)}" + aws_utils.upload_df_to_redshift( self.client, self.offline_store_config.cluster_id, self.offline_store_config.database, self.offline_store_config.user, self.s3, - f"{self.offline_store_config.s3_staging_location}/copy/{self.table_name}.parquet", + f"{self.offline_store_config.s3_staging_location}/copy/{destination}.parquet", self.offline_store_config.iam_role, - self.table_name, + destination, df, ) - self.redshift_source = RedshiftSource( - table=self.table_name, + self.tables.append(destination) + + return RedshiftSource( + table=destination, event_timestamp_column=event_timestamp_column, created_timestamp_column=created_timestamp_column, date_partition_column="", - field_mapping={"ts_1": "ts", "id": "driver_id"}, + field_mapping=field_mapping or {"ts_1": "ts"}, ) - return self.redshift_source def create_offline_store_config(self) -> FeastConfigBaseModel: return self.offline_store_config - def teardown(self, name: str): - if self.table_name: + def get_prefixed_table_name(self, name: str, suffix: str) -> str: + return f"{name}_{suffix}" + + def teardown(self): + for table in self.tables: aws_utils.execute_redshift_statement( self.client, self.offline_store_config.cluster_id, self.offline_store_config.database, self.offline_store_config.user, - f"DROP TABLE {self.table_name}", + f"DROP TABLE IF EXISTS {table}", ) diff --git 
a/sdk/python/tests/integration/feature_repos/universal/entities.py b/sdk/python/tests/integration/feature_repos/universal/entities.py index 9b4352eb83..1db362043b 100644 --- a/sdk/python/tests/integration/feature_repos/universal/entities.py +++ b/sdk/python/tests/integration/feature_repos/universal/entities.py @@ -8,3 +8,7 @@ def driver(): description="driver id", join_key="driver_id", ) + + +def customer(): + return Entity(name="customer_id", value_type=ValueType.INT64) diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py index 94c80bb84c..b5a120f453 100644 --- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py +++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py @@ -4,11 +4,44 @@ from feast.data_source import DataSource -def correctness_feature_view(data_source: DataSource) -> FeatureView: +def driver_feature_view( + data_source: DataSource, name="test_correctness" +) -> FeatureView: return FeatureView( - name="test_correctness", + name=name, entities=["driver"], features=[Feature("value", ValueType.FLOAT)], ttl=timedelta(days=5), input=data_source, ) + + +def create_driver_hourly_stats_feature_view(source): + driver_stats_feature_view = FeatureView( + name="driver_stats", + entities=["driver"], + features=[ + Feature(name="conv_rate", dtype=ValueType.FLOAT), + Feature(name="acc_rate", dtype=ValueType.FLOAT), + Feature(name="avg_daily_trips", dtype=ValueType.INT32), + ], + batch_source=source, + ttl=timedelta(hours=2), + ) + return driver_stats_feature_view + + +def create_customer_daily_profile_feature_view(source): + customer_profile_feature_view = FeatureView( + name="customer_profile", + entities=["customer_id"], + features=[ + Feature(name="current_balance", dtype=ValueType.FLOAT), + Feature(name="avg_passenger_count", dtype=ValueType.FLOAT), + Feature(name="lifetime_trip_count", dtype=ValueType.INT32), + Feature(name="avg_daily_trips", dtype=ValueType.INT32), + ], + batch_source=source, + ttl=timedelta(days=2), + ) + return customer_profile_feature_view diff --git a/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py b/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py index 71954611a3..e97f7d0d84 100644 --- a/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py +++ b/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py @@ -28,9 +28,7 @@ from feast.repo_config import RepoConfig from feast.value_type import ValueType from tests.data.data_creator import create_dataset -from tests.integration.feature_repos.universal.feature_views import ( - correctness_feature_view, -) +from tests.integration.feature_repos.universal.feature_views import driver_feature_view @contextlib.contextmanager @@ -64,7 +62,7 @@ def prep_bq_fs_and_fv( field_mapping={"ts_1": "ts", "id": "driver_id"}, ) - fv = correctness_feature_view(bigquery_source) + fv = driver_feature_view(bigquery_source) e = Entity( name="driver", description="id for driver", @@ -127,7 +125,7 @@ def prep_redshift_fs_and_fv( field_mapping={"ts_1": "ts", "id": "driver_id"}, ) - fv = correctness_feature_view(redshift_source) + fv = driver_feature_view(redshift_source) e = Entity( name="driver", description="id for driver", @@ -175,7 +173,7 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: date_partition_column="", field_mapping={"ts_1": "ts", 
"id": "driver_id"}, ) - fv = correctness_feature_view(file_source) + fv = driver_feature_view(file_source) e = Entity( name="driver", description="id for driver", @@ -216,7 +214,7 @@ def prep_redis_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: date_partition_column="", field_mapping={"ts_1": "ts", "id": "driver_id"}, ) - fv = correctness_feature_view(file_source) + fv = driver_feature_view(file_source) e = Entity( name="driver", description="id for driver", @@ -258,7 +256,7 @@ def prep_dynamodb_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: date_partition_column="", field_mapping={"ts_1": "ts", "id": "driver_id"}, ) - fv = correctness_feature_view(file_source) + fv = driver_feature_view(file_source) e = Entity( name="driver", description="id for driver", diff --git a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py index ddaaee9282..c9ee7c210b 100644 --- a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py @@ -257,6 +257,7 @@ def __enter__(self): client = bigquery.Client() dataset = bigquery.Dataset(f"{client.project}.{self.name}") dataset.location = "US" + print(f"Creating dataset: {dataset}") dataset = client.create_dataset(dataset, exists_ok=True) return dataset diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py new file mode 100644 index 0000000000..7379c27a62 --- /dev/null +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -0,0 +1,268 @@ +from datetime import datetime +from typing import Any, Dict, List + +import numpy as np +import pandas as pd +from pandas.testing import assert_frame_equal +from pytz import utc + +from feast import utils +from feast.feature_view import FeatureView +from feast.infra.offline_stores.offline_utils import ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, +) +from tests.integration.feature_repos.test_repo_configuration import ( + Environment, + parametrize_offline_retrieval_test, +) + +np.random.seed(0) + + +def convert_timestamp_records_to_utc( + records: List[Dict[str, Any]], column: str +) -> List[Dict[str, Any]]: + for record in records: + record[column] = utils.make_tzaware(record[column]).astimezone(utc) + return records + + +# Find the latest record in the given time range and filter +def find_asof_record( + records: List[Dict[str, Any]], + ts_key: str, + ts_start: datetime, + ts_end: datetime, + filter_key: str, + filter_value: Any, +) -> Dict[str, Any]: + found_record = {} + for record in records: + if record[filter_key] == filter_value and ts_start <= record[ts_key] <= ts_end: + if not found_record or found_record[ts_key] < record[ts_key]: + found_record = record + return found_record + + +def get_expected_training_df( + customer_df: pd.DataFrame, + customer_fv: FeatureView, + driver_df: pd.DataFrame, + driver_fv: FeatureView, + orders_df: pd.DataFrame, + event_timestamp: str, + full_feature_names: bool = False, +): + # Convert all pandas dataframes into records with UTC timestamps + order_records = convert_timestamp_records_to_utc( + orders_df.to_dict("records"), event_timestamp + ) + driver_records = convert_timestamp_records_to_utc( + driver_df.to_dict("records"), driver_fv.batch_source.event_timestamp_column + ) + customer_records = convert_timestamp_records_to_utc( 
+ customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column + ) + + # Manually do point-in-time join of orders to drivers and customers records + for order_record in order_records: + driver_record = find_asof_record( + driver_records, + ts_key=driver_fv.batch_source.event_timestamp_column, + ts_start=order_record[event_timestamp] - driver_fv.ttl, + ts_end=order_record[event_timestamp], + filter_key="driver_id", + filter_value=order_record["driver_id"], + ) + customer_record = find_asof_record( + customer_records, + ts_key=customer_fv.batch_source.event_timestamp_column, + ts_start=order_record[event_timestamp] - customer_fv.ttl, + ts_end=order_record[event_timestamp], + filter_key="customer_id", + filter_value=order_record["customer_id"], + ) + + order_record.update( + { + (f"driver_stats__{k}" if full_feature_names else k): driver_record.get( + k, None + ) + for k in ("conv_rate", "avg_daily_trips") + } + ) + + order_record.update( + { + ( + f"customer_profile__{k}" if full_feature_names else k + ): customer_record.get(k, None) + for k in ( + "current_balance", + "avg_passenger_count", + "lifetime_trip_count", + ) + } + ) + + # Convert records back to pandas dataframe + expected_df = pd.DataFrame(order_records) + + # Move "event_timestamp" column to front + current_cols = expected_df.columns.tolist() + current_cols.remove(event_timestamp) + expected_df = expected_df[[event_timestamp] + current_cols] + + # Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects. + if full_feature_names: + expected_column_types = { + "order_is_success": "int32", + "driver_stats__conv_rate": "float32", + "customer_profile__current_balance": "float32", + "customer_profile__avg_passenger_count": "float32", + } + else: + expected_column_types = { + "order_is_success": "int32", + "conv_rate": "float32", + "current_balance": "float32", + "avg_passenger_count": "float32", + } + + for col, typ in expected_column_types.items(): + expected_df[col] = expected_df[col].astype(typ) + + return expected_df + + +@parametrize_offline_retrieval_test +def test_historical_features(environment: Environment): + store = environment.feature_store + + customer_df, customer_fv = ( + environment.customer_df, + environment.customer_feature_view(), + ) + driver_df, driver_fv = ( + environment.driver_df, + environment.driver_stats_feature_view(), + ) + orders_df = environment.orders_df + full_feature_names = environment.test_repo_config.full_feature_names + + entity_df_query = None + if environment.orders_table(): + entity_df_query = f"SELECT * FROM {environment.orders_table()}" + + event_timestamp = ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns + else "e_ts" + ) + expected_df = get_expected_training_df( + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + event_timestamp, + full_feature_names, + ) + + if entity_df_query: + job_from_sql = store.get_historical_features( + entity_df=entity_df_query, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=full_feature_names, + ) + + start_time = datetime.utcnow() + actual_df_from_sql_entities = job_from_sql.to_df() + end_time = datetime.utcnow() + print( + str(f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'") + ) + + assert sorted(expected_df.columns) 
== sorted( + actual_df_from_sql_entities.columns + ) + + actual_df_from_sql_entities = ( + actual_df_from_sql_entities[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .drop_duplicates() + .reset_index(drop=True) + ) + expected_df = ( + expected_df.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ) + .drop_duplicates() + .reset_index(drop=True) + ) + + assert_frame_equal( + actual_df_from_sql_entities, expected_df, check_dtype=False, + ) + + table_from_sql_entities = job_from_sql.to_arrow() + df_from_sql_entities = ( + table_from_sql_entities.to_pandas()[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .drop_duplicates() + .reset_index(drop=True) + ) + assert_frame_equal(actual_df_from_sql_entities, df_from_sql_entities) + + job_from_df = store.get_historical_features( + entity_df=orders_df, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=full_feature_names, + ) + + start_time = datetime.utcnow() + actual_df_from_df_entities = job_from_df.to_df() + + print(f"actual_df_from_df_entities shape: {actual_df_from_df_entities.shape}") + end_time = datetime.utcnow() + print(str(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n")) + + assert sorted(expected_df.columns) == sorted(actual_df_from_df_entities.columns) + expected_df = ( + expected_df.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ) + .drop_duplicates() + .reset_index(drop=True) + ) + actual_df_from_df_entities = ( + actual_df_from_df_entities[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .drop_duplicates() + .reset_index(drop=True) + ) + + assert_frame_equal( + expected_df, actual_df_from_df_entities, check_dtype=False, + ) + + table_from_df_entities = job_from_df.to_arrow().to_pandas() + table_from_df_entities = ( + table_from_df_entities[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .drop_duplicates() + .reset_index(drop=True) + ) + assert_frame_equal(actual_df_from_df_entities, table_from_df_entities)
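For intuition on the manual point-in-time join above, find_asof_record keeps only the latest record for the matching entity key inside the feature view's TTL window; a minimal sketch with made-up values (the import path is assumed from the test package layout):

from datetime import datetime, timedelta

from tests.integration.offline_store.test_universal_historical_retrieval import (
    find_asof_record,
)

driver_records = [
    {"ts": datetime(2021, 7, 1, 10), "driver_id": 5001, "conv_rate": 0.1},
    {"ts": datetime(2021, 7, 1, 12), "driver_id": 5001, "conv_rate": 0.3},
    {"ts": datetime(2021, 7, 1, 12), "driver_id": 5002, "conv_rate": 0.9},
]
order_ts = datetime(2021, 7, 1, 13)
ttl = timedelta(hours=2)  # driver_stats TTL from feature_views.py above

# Only the 12:00 record for driver 5001 falls inside [order_ts - ttl, order_ts],
# so that is the record joined onto the order row; driver 5002 is filtered out.
record = find_asof_record(
    driver_records,
    ts_key="ts",
    ts_start=order_ts - ttl,
    ts_end=order_ts,
    filter_key="driver_id",
    filter_value=5001,
)
assert record["conv_rate"] == 0.3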