Add to_table() to RetrievalJob object (#1663)
* Add notion of OfflineJob

Signed-off-by: Matt Delacour <[email protected]>

* Use RetrievalJob instead of creating a new OfflineJob object

Signed-off-by: Matt Delacour <[email protected]>

* Add to_table() in integration tests

Signed-off-by: Matt Delacour <[email protected]>

Co-authored-by: Tsotne Tabidze <[email protected]>
MattDelac and Tsotne Tabidze authored Jun 29, 2021
1 parent 51fe128 commit 5a49470
Showing 10 changed files with 101 additions and 71 deletions.
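For orientation, a rough sketch of what the change means at the SDK surface: a RetrievalJob handed back by the store (for example from get_historical_features()) now exposes to_table() next to to_df(). The repo path, entity dataframe, and feature reference below are illustrative assumptions, not taken from this commit, and the feature_refs argument assumes the SDK signature of this era:

    # Hypothetical usage sketch; only the to_df()/to_table() pair comes from
    # this commit, everything else is illustrative.
    from datetime import datetime

    import pandas as pd

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")
    entity_df = pd.DataFrame(
        {"driver_id": [1001], "event_timestamp": [datetime(2021, 6, 1)]}
    )
    job = store.get_historical_features(
        entity_df=entity_df, feature_refs=["driver_hourly_stats:conv_rate"]
    )
    df = job.to_df()  # pandas DataFrame, as before
    table = job.to_table()  # pyarrow.Table, new in this commit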
26 changes: 14 additions & 12 deletions sdk/python/feast/feature_store.py
@@ -386,12 +386,13 @@ def tqdm_builder(length):
             end_date = utils.make_tzaware(end_date)
 
             provider.materialize_single_feature_view(
-                feature_view,
-                start_date,
-                end_date,
-                self._registry,
-                self.project,
-                tqdm_builder,
+                config=self.config,
+                feature_view=feature_view,
+                start_date=start_date,
+                end_date=end_date,
+                registry=self._registry,
+                project=self.project,
+                tqdm_builder=tqdm_builder,
             )
 
             self._registry.apply_materialization(
@@ -464,12 +465,13 @@ def tqdm_builder(length):
             end_date = utils.make_tzaware(end_date)
 
             provider.materialize_single_feature_view(
-                feature_view,
-                start_date,
-                end_date,
-                self._registry,
-                self.project,
-                tqdm_builder,
+                config=self.config,
+                feature_view=feature_view,
+                start_date=start_date,
+                end_date=end_date,
+                registry=self._registry,
+                project=self.project,
+                tqdm_builder=tqdm_builder,
             )
 
             self._registry.apply_materialization(
5 changes: 4 additions & 1 deletion sdk/python/feast/infra/gcp.py
@@ -81,6 +81,7 @@ def online_read(

     def materialize_single_feature_view(
         self,
+        config: RepoConfig,
         feature_view: FeatureView,
         start_date: datetime,
         end_date: datetime,
@@ -99,7 +100,8 @@ def materialize_single_feature_view(
             created_timestamp_column,
         ) = _get_column_names(feature_view, entities)
 
-        table = self.offline_store.pull_latest_from_table_or_query(
+        offline_job = self.offline_store.pull_latest_from_table_or_query(
+            config=config,
             data_source=feature_view.input,
             join_key_columns=join_key_columns,
             feature_name_columns=feature_name_columns,
@@ -108,6 +110,7 @@ def materialize_single_feature_view(
             start_date=start_date,
             end_date=end_date,
         )
+        table = offline_job.to_table()
 
         if feature_view.input.field_mapping is not None:
             table = _run_field_mapping(table, feature_view.input.field_mapping)
5 changes: 4 additions & 1 deletion sdk/python/feast/infra/local.py
@@ -80,6 +80,7 @@ def online_read(

     def materialize_single_feature_view(
         self,
+        config: RepoConfig,
         feature_view: FeatureView,
         start_date: datetime,
         end_date: datetime,
@@ -98,15 +99,17 @@ def materialize_single_feature_view(
             created_timestamp_column,
         ) = _get_column_names(feature_view, entities)
 
-        table = self.offline_store.pull_latest_from_table_or_query(
+        offline_job = self.offline_store.pull_latest_from_table_or_query(
             data_source=feature_view.input,
             join_key_columns=join_key_columns,
             feature_name_columns=feature_name_columns,
             event_timestamp_column=event_timestamp_column,
             created_timestamp_column=created_timestamp_column,
             start_date=start_date,
             end_date=end_date,
+            config=config,
         )
+        table = offline_job.to_table()
 
         if feature_view.input.field_mapping is not None:
             table = _run_field_mapping(table, feature_view.input.field_mapping)
35 changes: 14 additions & 21 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -15,10 +15,9 @@
 from feast.data_source import BigQuerySource, DataSource
 from feast.errors import FeastProviderLoginError
 from feast.feature_view import FeatureView
-from feast.infra.offline_stores.offline_store import OfflineStore
+from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
 from feast.infra.provider import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
-    RetrievalJob,
     _get_requested_feature_views_to_features_dict,
 )
 from feast.registry import Registry
@@ -52,14 +51,15 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
 class BigQueryOfflineStore(OfflineStore):
     @staticmethod
     def pull_latest_from_table_or_query(
+        config: RepoConfig,
         data_source: DataSource,
         join_key_columns: List[str],
         feature_name_columns: List[str],
         event_timestamp_column: str,
         created_timestamp_column: Optional[str],
         start_date: datetime,
         end_date: datetime,
-    ) -> pyarrow.Table:
+    ) -> RetrievalJob:
         assert isinstance(data_source, BigQuerySource)
         from_expression = data_source.get_table_query_string()

@@ -74,6 +74,7 @@ def pull_latest_from_table_or_query(
         timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
         field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
 
+        client = _get_bigquery_client(project=config.offline_store.project_id)
         query = f"""
             SELECT {field_string}
             FROM (
@@ -84,14 +85,7 @@ def pull_latest_from_table_or_query(
             )
             WHERE _feast_row = 1
             """
-
-        return BigQueryOfflineStore._pull_query(query)
-
-    @staticmethod
-    def _pull_query(query: str) -> pyarrow.Table:
-        client = _get_bigquery_client()
-        query_job = client.query(query)
-        return query_job.to_arrow()
+        return BigQueryRetrievalJob(query=query, client=client, config=config)
 
     @staticmethod
     def get_historical_features(
@@ -103,19 +97,18 @@
         project: str,
     ) -> RetrievalJob:
         # TODO: Add entity_df validation in order to fail before interacting with BigQuery
+        assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
 
-        client = _get_bigquery_client()
-
+        client = _get_bigquery_client(project=config.offline_store.project_id)
         expected_join_keys = _get_join_keys(project, feature_views, registry)
 
-        assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
-        dataset_project = config.offline_store.project_id or client.project
-
         table = _upload_entity_df_into_bigquery(
             client=client,
             project=config.project,
             dataset_name=config.offline_store.dataset,
-            dataset_project=dataset_project,
+            dataset_project=client.project,
             entity_df=entity_df,
         )

@@ -263,10 +256,7 @@ def _block_until_done():
         if not job_config:
             today = date.today().strftime("%Y%m%d")
             rand_id = str(uuid.uuid4())[:7]
-            dataset_project = (
-                self.config.offline_store.project_id or self.client.project
-            )
-            path = f"{dataset_project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
+            path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
             job_config = bigquery.QueryJobConfig(destination=path)
 
         bq_job = self.client.query(self.query, job_config=job_config)
@@ -285,6 +275,9 @@ def _block_until_done():
print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)

def to_table(self) -> pyarrow.Table:
return self.client.query(self.query).to_arrow()


@dataclass(frozen=True)
class FeatureViewQueryContext:
@@ -446,9 +439,9 @@ def build_point_in_time_query(
     return query
 
 
-def _get_bigquery_client():
+def _get_bigquery_client(project: Optional[str] = None):
     try:
-        client = bigquery.Client()
+        client = bigquery.Client(project=project)
     except DefaultCredentialsError as e:
         raise FeastProviderLoginError(
             str(e)
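One consequence of threading RepoConfig into _get_bigquery_client(project=config.offline_store.project_id): when a project_id is configured, client.project already resolves to it, so the separate dataset_project fallback (config.offline_store.project_id or client.project) becomes redundant, which is why the hunks above collapse it to client.project.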
76 changes: 44 additions & 32 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -38,6 +38,11 @@ def to_df(self):
         df = self.evaluation_function()
         return df
 
+    def to_table(self):
+        # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment.
+        df = self.evaluation_function()
+        return pyarrow.Table.from_pandas(df)
+
 
 class FileOfflineStore(OfflineStore):
     @staticmethod
@@ -48,7 +53,7 @@ def get_historical_features(
         entity_df: Union[pd.DataFrame, str],
         registry: Registry,
         project: str,
-    ) -> FileRetrievalJob:
+    ) -> RetrievalJob:
         if not isinstance(entity_df, pd.DataFrame):
             raise ValueError(
                 f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}"
@@ -205,49 +210,56 @@ def evaluate_historical_retrieval():

     @staticmethod
     def pull_latest_from_table_or_query(
+        config: RepoConfig,
         data_source: DataSource,
         join_key_columns: List[str],
         feature_name_columns: List[str],
         event_timestamp_column: str,
         created_timestamp_column: Optional[str],
         start_date: datetime,
         end_date: datetime,
-    ) -> pyarrow.Table:
+    ) -> RetrievalJob:
         assert isinstance(data_source, FileSource)
 
-        source_df = pd.read_parquet(data_source.path)
-        # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
-        source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
-            lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
-        )
-        if created_timestamp_column:
-            source_df[created_timestamp_column] = source_df[
-                created_timestamp_column
-            ].apply(lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc))
-
-        source_columns = set(source_df.columns)
-        if not set(join_key_columns).issubset(source_columns):
-            raise FeastJoinKeysDuringMaterialization(
-                data_source.path, set(join_key_columns), source_columns
+        # Create lazy function that is only called from the RetrievalJob object
+        def evaluate_offline_job():
+            source_df = pd.read_parquet(data_source.path)
+            # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
+            source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
+                lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
            )
+            if created_timestamp_column:
+                source_df[created_timestamp_column] = source_df[
+                    created_timestamp_column
+                ].apply(
+                    lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
+                )
 
-        ts_columns = (
-            [event_timestamp_column, created_timestamp_column]
-            if created_timestamp_column
-            else [event_timestamp_column]
-        )
+            source_columns = set(source_df.columns)
+            if not set(join_key_columns).issubset(source_columns):
+                raise FeastJoinKeysDuringMaterialization(
+                    data_source.path, set(join_key_columns), source_columns
+                )
 
-        source_df.sort_values(by=ts_columns, inplace=True)
+            ts_columns = (
+                [event_timestamp_column, created_timestamp_column]
+                if created_timestamp_column
+                else [event_timestamp_column]
+            )
 
-        filtered_df = source_df[
-            (source_df[event_timestamp_column] >= start_date)
-            & (source_df[event_timestamp_column] < end_date)
-        ]
-        last_values_df = filtered_df.drop_duplicates(
-            join_key_columns, keep="last", ignore_index=True
-        )
+            source_df.sort_values(by=ts_columns, inplace=True)
 
-        columns_to_extract = set(join_key_columns + feature_name_columns + ts_columns)
-        table = pyarrow.Table.from_pandas(last_values_df[columns_to_extract])
+            filtered_df = source_df[
+                (source_df[event_timestamp_column] >= start_date)
+                & (source_df[event_timestamp_column] < end_date)
+            ]
+            last_values_df = filtered_df.drop_duplicates(
+                join_key_columns, keep="last", ignore_index=True
+            )
 
-        return table
+            columns_to_extract = set(
+                join_key_columns + feature_name_columns + ts_columns
+            )
+            return last_values_df[columns_to_extract]
+
+        return FileRetrievalJob(evaluation_function=evaluate_offline_job)
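The hunk above also changes the evaluation strategy: pull_latest_from_table_or_query() no longer touches the parquet file itself; all of the work moves into the evaluate_offline_job closure, which FileRetrievalJob runs only when to_df() or to_table() is called. Stripped of the Feast specifics, the pattern is roughly this (a standalone sketch, not code from the repository):

    from typing import Callable

    import pandas as pd
    import pyarrow


    class LazyRetrievalJob:
        """Toy stand-in for FileRetrievalJob: hold work, run it on demand."""

        def __init__(self, evaluation_function: Callable[[], pd.DataFrame]):
            # Nothing is read or computed here; the closure is just stored.
            self.evaluation_function = evaluation_function

        def to_df(self) -> pd.DataFrame:
            # The expensive work happens at the last moment, on access.
            return self.evaluation_function()

        def to_table(self) -> pyarrow.Table:
            return pyarrow.Table.from_pandas(self.evaluation_function())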
10 changes: 8 additions & 2 deletions sdk/python/feast/infra/offline_stores/offline_store.py
@@ -28,10 +28,15 @@ class RetrievalJob(ABC):
"""RetrievalJob is used to manage the execution of a historical feature retrieval"""

@abstractmethod
def to_df(self):
def to_df(self) -> pd.DataFrame:
"""Return dataset as Pandas DataFrame synchronously"""
pass

@abstractmethod
def to_table(self) -> pyarrow.Table:
"""Return dataset as pyarrow Table synchronously"""
pass


class OfflineStore(ABC):
"""
@@ -42,14 +47,15 @@ class OfflineStore(ABC):
     @staticmethod
     @abstractmethod
     def pull_latest_from_table_or_query(
+        config: RepoConfig,
         data_source: DataSource,
         join_key_columns: List[str],
         feature_name_columns: List[str],
         event_timestamp_column: str,
         created_timestamp_column: Optional[str],
         start_date: datetime,
         end_date: datetime,
-    ) -> pyarrow.Table:
+    ) -> RetrievalJob:
         """
         Note that join_key_columns, feature_name_columns, event_timestamp_column, and created_timestamp_column
         have all already been mapped to column names of the source table and those column names are the values passed
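Since to_table() is declared @abstractmethod, every concrete RetrievalJob now has to provide both accessors. A minimal conforming subclass might look like this (an in-memory toy written for illustration; it is not part of this commit):

    import pandas as pd
    import pyarrow

    from feast.infra.offline_stores.offline_store import RetrievalJob


    class InMemoryRetrievalJob(RetrievalJob):
        """Illustrative only: wraps an already-materialized DataFrame."""

        def __init__(self, df: pd.DataFrame):
            self._df = df

        def to_df(self) -> pd.DataFrame:
            return self._df

        def to_table(self) -> pyarrow.Table:
            return pyarrow.Table.from_pandas(self._df)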
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -2,7 +2,6 @@
 from typing import List, Optional, Union
 
 import pandas as pd
-import pyarrow
 from pydantic import StrictStr
 from pydantic.typing import Literal

@@ -38,14 +37,15 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
 class RedshiftOfflineStore(OfflineStore):
     @staticmethod
     def pull_latest_from_table_or_query(
+        config: RepoConfig,
         data_source: DataSource,
         join_key_columns: List[str],
         feature_name_columns: List[str],
         event_timestamp_column: str,
         created_timestamp_column: Optional[str],
         start_date: datetime,
         end_date: datetime,
-    ) -> pyarrow.Table:
+    ) -> RetrievalJob:
         pass
 
     @staticmethod
1 change: 1 addition & 0 deletions sdk/python/feast/infra/provider.py
@@ -97,6 +97,7 @@ def online_write_batch(
     @abc.abstractmethod
     def materialize_single_feature_view(
         self,
+        config: RepoConfig,
         feature_view: FeatureView,
         start_date: datetime,
         end_date: datetime,
1 change: 1 addition & 0 deletions sdk/python/tests/foo_provider.py
@@ -45,6 +45,7 @@ def online_write_batch(

     def materialize_single_feature_view(
         self,
+        config: RepoConfig,
         feature_view: FeatureView,
         start_date: datetime,
         end_date: datetime,