From 5a49470a63865e20e8af939b263466c3eea4e10c Mon Sep 17 00:00:00 2001 From: Matt Delacour Date: Mon, 28 Jun 2021 22:16:16 -0400 Subject: [PATCH] Add to_table() to RetrievalJob object (#1663) * Add notion of OfflineJob Signed-off-by: Matt Delacour * Use RetrievalJob instead of creating a new OfflineJob object Signed-off-by: Matt Delacour * Add to_table() in integration tests Signed-off-by: Matt Delacour Co-authored-by: Tsotne Tabidze --- sdk/python/feast/feature_store.py | 26 ++++--- sdk/python/feast/infra/gcp.py | 5 +- sdk/python/feast/infra/local.py | 5 +- .../feast/infra/offline_stores/bigquery.py | 35 ++++----- sdk/python/feast/infra/offline_stores/file.py | 76 +++++++++++-------- .../infra/offline_stores/offline_store.py | 10 ++- .../feast/infra/offline_stores/redshift.py | 4 +- sdk/python/feast/infra/provider.py | 1 + sdk/python/tests/foo_provider.py | 1 + sdk/python/tests/test_historical_retrieval.py | 9 +++ 10 files changed, 101 insertions(+), 71 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 8440ce9b50..3d5dc9fcc7 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -386,12 +386,13 @@ def tqdm_builder(length): end_date = utils.make_tzaware(end_date) provider.materialize_single_feature_view( - feature_view, - start_date, - end_date, - self._registry, - self.project, - tqdm_builder, + config=self.config, + feature_view=feature_view, + start_date=start_date, + end_date=end_date, + registry=self._registry, + project=self.project, + tqdm_builder=tqdm_builder, ) self._registry.apply_materialization( @@ -464,12 +465,13 @@ def tqdm_builder(length): end_date = utils.make_tzaware(end_date) provider.materialize_single_feature_view( - feature_view, - start_date, - end_date, - self._registry, - self.project, - tqdm_builder, + config=self.config, + feature_view=feature_view, + start_date=start_date, + end_date=end_date, + registry=self._registry, + project=self.project, + tqdm_builder=tqdm_builder, ) self._registry.apply_materialization( diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py index f33b501d62..a75aeb5ddb 100644 --- a/sdk/python/feast/infra/gcp.py +++ b/sdk/python/feast/infra/gcp.py @@ -81,6 +81,7 @@ def online_read( def materialize_single_feature_view( self, + config: RepoConfig, feature_view: FeatureView, start_date: datetime, end_date: datetime, @@ -99,7 +100,8 @@ def materialize_single_feature_view( created_timestamp_column, ) = _get_column_names(feature_view, entities) - table = self.offline_store.pull_latest_from_table_or_query( + offline_job = self.offline_store.pull_latest_from_table_or_query( + config=config, data_source=feature_view.input, join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, @@ -108,6 +110,7 @@ def materialize_single_feature_view( start_date=start_date, end_date=end_date, ) + table = offline_job.to_table() if feature_view.input.field_mapping is not None: table = _run_field_mapping(table, feature_view.input.field_mapping) diff --git a/sdk/python/feast/infra/local.py b/sdk/python/feast/infra/local.py index a76f49b2c4..d04fe8d740 100644 --- a/sdk/python/feast/infra/local.py +++ b/sdk/python/feast/infra/local.py @@ -80,6 +80,7 @@ def online_read( def materialize_single_feature_view( self, + config: RepoConfig, feature_view: FeatureView, start_date: datetime, end_date: datetime, @@ -98,7 +99,7 @@ def materialize_single_feature_view( created_timestamp_column, ) = _get_column_names(feature_view, entities) - table = self.offline_store.pull_latest_from_table_or_query( + offline_job = self.offline_store.pull_latest_from_table_or_query( data_source=feature_view.input, join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, @@ -106,7 +107,9 @@ def materialize_single_feature_view( created_timestamp_column=created_timestamp_column, start_date=start_date, end_date=end_date, + config=config, ) + table = offline_job.to_table() if feature_view.input.field_mapping is not None: table = _run_field_mapping(table, feature_view.input.field_mapping) diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 11da1dff4b..0c961a5048 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -15,10 +15,9 @@ from feast.data_source import BigQuerySource, DataSource from feast.errors import FeastProviderLoginError from feast.feature_view import FeatureView -from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.provider import ( DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, - RetrievalJob, _get_requested_feature_views_to_features_dict, ) from feast.registry import Registry @@ -52,6 +51,7 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel): class BigQueryOfflineStore(OfflineStore): @staticmethod def pull_latest_from_table_or_query( + config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], @@ -59,7 +59,7 @@ def pull_latest_from_table_or_query( created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, - ) -> pyarrow.Table: + ) -> RetrievalJob: assert isinstance(data_source, BigQuerySource) from_expression = data_source.get_table_query_string() @@ -74,6 +74,7 @@ def pull_latest_from_table_or_query( timestamp_desc_string = " DESC, ".join(timestamps) + " DESC" field_string = ", ".join(join_key_columns + feature_name_columns + timestamps) + client = _get_bigquery_client(project=config.offline_store.project_id) query = f""" SELECT {field_string} FROM ( @@ -84,14 +85,7 @@ def pull_latest_from_table_or_query( ) WHERE _feast_row = 1 """ - - return BigQueryOfflineStore._pull_query(query) - - @staticmethod - def _pull_query(query: str) -> pyarrow.Table: - client = _get_bigquery_client() - query_job = client.query(query) - return query_job.to_arrow() + return BigQueryRetrievalJob(query=query, client=client, config=config) @staticmethod def get_historical_features( @@ -103,19 +97,18 @@ def get_historical_features( project: str, ) -> RetrievalJob: # TODO: Add entity_df validation in order to fail before interacting with BigQuery + assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) - client = _get_bigquery_client() - + client = _get_bigquery_client(project=config.offline_store.project_id) expected_join_keys = _get_join_keys(project, feature_views, registry) assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) - dataset_project = config.offline_store.project_id or client.project table = _upload_entity_df_into_bigquery( client=client, project=config.project, dataset_name=config.offline_store.dataset, - dataset_project=dataset_project, + dataset_project=client.project, entity_df=entity_df, ) @@ -263,10 +256,7 @@ def _block_until_done(): if not job_config: today = date.today().strftime("%Y%m%d") rand_id = str(uuid.uuid4())[:7] - dataset_project = ( - self.config.offline_store.project_id or self.client.project - ) - path = f"{dataset_project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}" + path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}" job_config = bigquery.QueryJobConfig(destination=path) bq_job = self.client.query(self.query, job_config=job_config) @@ -285,6 +275,9 @@ def _block_until_done(): print(f"Done writing to '{job_config.destination}'.") return str(job_config.destination) + def to_table(self) -> pyarrow.Table: + return self.client.query(self.query).to_arrow() + @dataclass(frozen=True) class FeatureViewQueryContext: @@ -446,9 +439,9 @@ def build_point_in_time_query( return query -def _get_bigquery_client(): +def _get_bigquery_client(project: Optional[str] = None): try: - client = bigquery.Client() + client = bigquery.Client(project=project) except DefaultCredentialsError as e: raise FeastProviderLoginError( str(e) diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 70b3372665..c61162f81f 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -38,6 +38,11 @@ def to_df(self): df = self.evaluation_function() return df + def to_table(self): + # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment. + df = self.evaluation_function() + return pyarrow.Table.from_pandas(df) + class FileOfflineStore(OfflineStore): @staticmethod @@ -48,7 +53,7 @@ def get_historical_features( entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, - ) -> FileRetrievalJob: + ) -> RetrievalJob: if not isinstance(entity_df, pd.DataFrame): raise ValueError( f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}" @@ -205,6 +210,7 @@ def evaluate_historical_retrieval(): @staticmethod def pull_latest_from_table_or_query( + config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], @@ -212,42 +218,48 @@ def pull_latest_from_table_or_query( created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, - ) -> pyarrow.Table: + ) -> RetrievalJob: assert isinstance(data_source, FileSource) - source_df = pd.read_parquet(data_source.path) - # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC - source_df[event_timestamp_column] = source_df[event_timestamp_column].apply( - lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc) - ) - if created_timestamp_column: - source_df[created_timestamp_column] = source_df[ - created_timestamp_column - ].apply(lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)) - - source_columns = set(source_df.columns) - if not set(join_key_columns).issubset(source_columns): - raise FeastJoinKeysDuringMaterialization( - data_source.path, set(join_key_columns), source_columns + # Create lazy function that is only called from the RetrievalJob object + def evaluate_offline_job(): + source_df = pd.read_parquet(data_source.path) + # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC + source_df[event_timestamp_column] = source_df[event_timestamp_column].apply( + lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc) ) + if created_timestamp_column: + source_df[created_timestamp_column] = source_df[ + created_timestamp_column + ].apply( + lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc) + ) - ts_columns = ( - [event_timestamp_column, created_timestamp_column] - if created_timestamp_column - else [event_timestamp_column] - ) + source_columns = set(source_df.columns) + if not set(join_key_columns).issubset(source_columns): + raise FeastJoinKeysDuringMaterialization( + data_source.path, set(join_key_columns), source_columns + ) - source_df.sort_values(by=ts_columns, inplace=True) + ts_columns = ( + [event_timestamp_column, created_timestamp_column] + if created_timestamp_column + else [event_timestamp_column] + ) - filtered_df = source_df[ - (source_df[event_timestamp_column] >= start_date) - & (source_df[event_timestamp_column] < end_date) - ] - last_values_df = filtered_df.drop_duplicates( - join_key_columns, keep="last", ignore_index=True - ) + source_df.sort_values(by=ts_columns, inplace=True) - columns_to_extract = set(join_key_columns + feature_name_columns + ts_columns) - table = pyarrow.Table.from_pandas(last_values_df[columns_to_extract]) + filtered_df = source_df[ + (source_df[event_timestamp_column] >= start_date) + & (source_df[event_timestamp_column] < end_date) + ] + last_values_df = filtered_df.drop_duplicates( + join_key_columns, keep="last", ignore_index=True + ) + + columns_to_extract = set( + join_key_columns + feature_name_columns + ts_columns + ) + return last_values_df[columns_to_extract] - return table + return FileRetrievalJob(evaluation_function=evaluate_offline_job) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index d31d11aae2..6e2394b44b 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -28,10 +28,15 @@ class RetrievalJob(ABC): """RetrievalJob is used to manage the execution of a historical feature retrieval""" @abstractmethod - def to_df(self): + def to_df(self) -> pd.DataFrame: """Return dataset as Pandas DataFrame synchronously""" pass + @abstractmethod + def to_table(self) -> pyarrow.Table: + """Return dataset as pyarrow Table synchronously""" + pass + class OfflineStore(ABC): """ @@ -42,6 +47,7 @@ class OfflineStore(ABC): @staticmethod @abstractmethod def pull_latest_from_table_or_query( + config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], @@ -49,7 +55,7 @@ def pull_latest_from_table_or_query( created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, - ) -> pyarrow.Table: + ) -> RetrievalJob: """ Note that join_key_columns, feature_name_columns, event_timestamp_column, and created_timestamp_column have all already been mapped to column names of the source table and those column names are the values passed diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 06a437564a..f15b9af451 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union import pandas as pd -import pyarrow from pydantic import StrictStr from pydantic.typing import Literal @@ -38,6 +37,7 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel): class RedshiftOfflineStore(OfflineStore): @staticmethod def pull_latest_from_table_or_query( + config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], @@ -45,7 +45,7 @@ def pull_latest_from_table_or_query( created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, - ) -> pyarrow.Table: + ) -> RetrievalJob: pass @staticmethod diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 905d0fd1dc..8b92374d23 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -97,6 +97,7 @@ def online_write_batch( @abc.abstractmethod def materialize_single_feature_view( self, + config: RepoConfig, feature_view: FeatureView, start_date: datetime, end_date: datetime, diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 38367c3179..22ae294603 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -45,6 +45,7 @@ def online_write_batch( def materialize_single_feature_view( self, + config: RepoConfig, feature_view: FeatureView, start_date: datetime, end_date: datetime, diff --git a/sdk/python/tests/test_historical_retrieval.py b/sdk/python/tests/test_historical_retrieval.py index 4acad10858..72cf67b178 100644 --- a/sdk/python/tests/test_historical_retrieval.py +++ b/sdk/python/tests/test_historical_retrieval.py @@ -461,6 +461,11 @@ def test_historical_features_from_bigquery_sources( check_dtype=False, ) + table_from_sql_entities = job_from_sql.to_table() + assert_frame_equal( + actual_df_from_sql_entities, table_from_sql_entities.to_pandas() + ) + timestamp_column = ( "e_ts" if infer_event_timestamp_col @@ -541,3 +546,7 @@ def test_historical_features_from_bigquery_sources( .reset_index(drop=True), check_dtype=False, ) + table_from_df_entities = job_from_df.to_table() + assert_frame_equal( + actual_df_from_df_entities, table_from_df_entities.to_pandas() + )