diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index 00b4c06a24..acd828bf8f 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -7,6 +7,7 @@
 import pandas
 import pyarrow
 from jinja2 import BaseLoader, Environment
+from pandas import Timestamp
 from pydantic import StrictStr
 from pydantic.typing import Literal
 from tenacity import retry, stop_after_delay, wait_fixed
@@ -129,12 +130,16 @@ def get_historical_features(
             full_feature_names=full_feature_names,
         )
 
-        # TODO: Infer min_timestamp and max_timestamp from entity_df
+        # Infer min and max timestamps from entity_df to limit data read in BigQuery SQL query
+        min_timestamp, max_timestamp = _get_entity_df_timestamp_bounds(
+            client, str(table.reference), entity_df_event_timestamp_col
+        )
+
         # Generate the BigQuery SQL query from the query context
         query = build_point_in_time_query(
             query_context,
-            min_timestamp=datetime.now() - timedelta(days=365),
-            max_timestamp=datetime.now() + timedelta(days=1),
+            min_timestamp=min_timestamp,
+            max_timestamp=max_timestamp,
             left_table_query_string=str(table.reference),
             entity_df_event_timestamp_col=entity_df_event_timestamp_col,
             full_feature_names=full_feature_names,
@@ -374,6 +379,28 @@ def _upload_entity_df_into_bigquery(
     return table
 
 
+def _get_entity_df_timestamp_bounds(
+    client: Client, entity_df_bq_table: str, event_timestamp_col: str,
+):
+
+    boundary_df = (
+        client.query(
+            f"""
+    SELECT
+        MIN({event_timestamp_col}) AS min_timestamp,
+        MAX({event_timestamp_col}) AS max_timestamp
+    FROM {entity_df_bq_table}
+    """
+        )
+        .result()
+        .to_dataframe()
+    )
+
+    min_timestamp = boundary_df.loc[0, "min_timestamp"]
+    max_timestamp = boundary_df.loc[0, "max_timestamp"]
+    return min_timestamp, max_timestamp
+
+
 def get_feature_view_query_context(
     feature_refs: List[str],
     feature_views: List[FeatureView],
@@ -435,8 +462,8 @@ def get_feature_view_query_context(
 
 def build_point_in_time_query(
     feature_view_query_contexts: List[FeatureViewQueryContext],
-    min_timestamp: datetime,
-    max_timestamp: datetime,
+    min_timestamp: Timestamp,
+    max_timestamp: Timestamp,
     left_table_query_string: str,
     entity_df_event_timestamp_col: str,
     full_feature_names: bool = False,
@@ -533,6 +560,10 @@ def _get_bigquery_client(project: Optional[str] = None):
         {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}{% if loop.last %}{% else %}, {% endif %}
     {% endfor %}
     FROM {{ featureview.table_subquery }}
+    WHERE {{ featureview.event_timestamp_column }} <= '{{max_timestamp}}'
+    {% if featureview.ttl == 0 %}{% else %}
+    AND {{ featureview.event_timestamp_column }} >= Timestamp_sub('{{min_timestamp}}', interval {{ featureview.ttl }} second)
+    {% endif %}
 ),
 
 {{ featureview.name }}__base AS (
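The template hunk above is where the scan pruning actually happens: each feature view subquery is now bounded above by the entity dataframe's max timestamp and below by its min timestamp minus the feature view's TTL. A minimal sketch of how that Jinja fragment renders, using an illustrative table name, column name, TTL, and bounds that are not taken from this patch:

from jinja2 import BaseLoader, Environment

# Hypothetical fragment mirroring the WHERE clause added to the point-in-time
# join template; the names and timestamps below are examples, not patch values.
FRAGMENT = (
    "FROM {{ featureview.table_subquery }}\n"
    "WHERE {{ featureview.event_timestamp_column }} <= '{{max_timestamp}}'\n"
    "{% if featureview.ttl == 0 %}{% else %}"
    "AND {{ featureview.event_timestamp_column }} >= "
    "Timestamp_sub('{{min_timestamp}}', interval {{ featureview.ttl }} second)"
    "{% endif %}"
)

rendered = Environment(loader=BaseLoader()).from_string(FRAGMENT).render(
    featureview={
        "table_subquery": "`project.dataset.driver_hourly_stats`",
        "event_timestamp_column": "event_timestamp",
        "ttl": 7200,  # TTL in seconds (two hours)
    },
    min_timestamp="2021-06-28 00:00:00+00:00",
    max_timestamp="2021-07-01 00:00:00+00:00",
)
print(rendered)
# FROM `project.dataset.driver_hourly_stats`
# WHERE event_timestamp <= '2021-07-01 00:00:00+00:00'
# AND event_timestamp >= Timestamp_sub('2021-06-28 00:00:00+00:00', interval 7200 second)

When featureview.ttl is 0, the empty if-branch is taken and no lower bound is emitted, so feature views without a TTL keep their full history eligible for the point-in-time join.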
diff --git a/sdk/python/tests/test_historical_retrieval.py b/sdk/python/tests/test_historical_retrieval.py
index 3a708c7503..24c9177698 100644
--- a/sdk/python/tests/test_historical_retrieval.py
+++ b/sdk/python/tests/test_historical_retrieval.py
@@ -21,7 +21,10 @@
 from feast.feature import Feature
 from feast.feature_store import FeatureStore, _validate_feature_refs
 from feast.feature_view import FeatureView
-from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
+from feast.infra.offline_stores.bigquery import (
+    BigQueryOfflineStoreConfig,
+    _get_entity_df_timestamp_bounds,
+)
 from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
 from feast.infra.provider import DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
 from feast.value_type import ValueType
@@ -595,6 +598,31 @@ def test_historical_features_from_bigquery_sources(
     )
 
 
+@pytest.mark.integration
+def test_timestamp_bound_inference_from_entity_df_using_bigquery():
+    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
+    (_, _, _, entity_df, start_date) = generate_entities(
+        start_date, infer_event_timestamp_col=True
+    )
+
+    table_id = "foo.table_id"
+    stage_orders_bigquery(entity_df, table_id)
+
+    client = bigquery.Client()
+    table = client.get_table(table=table_id)
+
+    # Ensure that the table expires after some time
+    table.expires = datetime.utcnow() + timedelta(minutes=30)
+    client.update_table(table, ["expires"])
+
+    min_timestamp, max_timestamp = _get_entity_df_timestamp_bounds(
+        client, str(table.reference), "e_ts"
+    )
+
+    assert min_timestamp.astimezone("UTC") == min(entity_df["e_ts"]).astimezone("UTC")
+    assert max_timestamp.astimezone("UTC") == max(entity_df["e_ts"]).astimezone("UTC")
+
+
 def test_feature_name_collision_on_historical_retrieval():
     # _validate_feature_refs is the function that checks for colliding feature names
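One detail worth noting in the new test: the assertions normalize both sides with .astimezone("UTC") because BigQuery returns TIMESTAMP columns as UTC-aware pandas Timestamps, while the staged entity dataframe's e_ts column may carry a different timezone. A small self-contained sketch of why that normalization is harmless but makes the comparison explicit (both timestamps below are fabricated and denote the same instant):

import pandas as pd

# The same instant, once as it might be staged locally and once as
# BigQuery would hand it back (UTC-aware); both values are illustrative.
local_ts = pd.Timestamp("2021-06-28 09:00:00", tz="America/Los_Angeles")
bq_ts = pd.Timestamp("2021-06-28 16:00:00", tz="UTC")

# Timezone-aware Timestamps already compare by instant, so equality holds
# either way; converting both sides to UTC simply makes that explicit.
assert bq_ts == local_ts
assert bq_ts.astimezone("UTC") == local_ts.astimezone("UTC")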