chore: Check configs and data sources in all offline store methods (#3107)

Check configs and data sources in all offline store methods

Signed-off-by: Felix Wang <[email protected]>

felixwang9817 authored Aug 19, 2022
1 parent 2b493e0 commit 29f2895
Showing 9 changed files with 53 additions and 78 deletions.
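The whole diff applies a single pattern: every offline store method now opens with isinstance assertions on both the offline store config and the data source, replacing the wordier ValueError-raising checks that only some methods had. A condensed sketch of the resulting shape (BigQuery variant; the signature is abbreviated and the body elided here, so this is illustrative rather than the exact code below):

    from feast.data_source import DataSource
    from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
    from feast.infra.offline_stores.bigquery_source import BigQuerySource
    from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
    from feast.repo_config import RepoConfig

    class BigQueryOfflineStore(OfflineStore):  # sketch, not the full class
        @staticmethod
        def pull_latest_from_table_or_query(
            config: RepoConfig,
            data_source: DataSource,
            # ... remaining parameters elided ...
        ) -> RetrievalJob:
            # Fail fast if the repo is configured for a different offline
            # store, or the data source does not match this store's type.
            assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
            assert isinstance(data_source, BigQuerySource)
            ...  # build and return the retrieval job

One trade-off worth noting: assert statements are stripped when Python runs with -O, so these act as developer-facing type guards rather than user-facing validation.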
18 changes: 6 additions & 12 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -106,6 +106,7 @@ def pull_latest_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
assert isinstance(data_source, BigQuerySource)
from_expression = data_source.get_table_query_string()

@@ -156,6 +157,7 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
assert isinstance(data_source, BigQuerySource)
from_expression = data_source.get_table_query_string()

@@ -191,6 +193,8 @@ def get_historical_features(
) -> RetrievalJob:
# TODO: Add entity_df validation in order to fail before interacting with BigQuery
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, BigQuerySource)

client = _get_bigquery_client(
project=config.offline_store.project_id,
@@ -333,18 +337,8 @@ def offline_write_batch(
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when bigquery type required"
)
if not isinstance(feature_view.batch_source, BigQuerySource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
)
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
assert isinstance(feature_view.batch_source, BigQuerySource)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py
@@ -77,8 +77,8 @@ def pull_latest_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, AthenaSource)
assert isinstance(config.offline_store, AthenaOfflineStoreConfig)
assert isinstance(data_source, AthenaSource)

from_expression = data_source.get_table_query_string(config)

@@ -136,6 +136,7 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, AthenaOfflineStoreConfig)
assert isinstance(data_source, AthenaSource)
from_expression = data_source.get_table_query_string(config)

@@ -175,6 +176,8 @@ def get_historical_features(
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, AthenaOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, AthenaSource)

athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
@@ -67,6 +67,7 @@ def pull_latest_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, PostgreSQLOfflineStoreConfig)
assert isinstance(data_source, PostgreSQLSource)
from_expression = data_source.get_table_query_string()

@@ -117,6 +118,9 @@ def get_historical_features(
project: str,
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, PostgreSQLOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, PostgreSQLSource)

entity_schema = _get_entity_schema(entity_df, config)

@@ -206,6 +210,7 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, PostgreSQLOfflineStoreConfig)
assert isinstance(data_source, PostgreSQLSource)
from_expression = data_source.get_table_query_string()

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -121,6 +121,9 @@ def get_historical_features(
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, SparkSource)

warnings.warn(
"The spark offline store is an experimental feature in alpha development. "
"Some functionality may still be unstable so functionality can change in the future.",
@@ -198,18 +201,8 @@ def offline_write_batch(
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, SparkOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when spark type required"
)
if not isinstance(feature_view.batch_source, SparkSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not spark source"
)
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
assert isinstance(feature_view.batch_source, SparkSource)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
@@ -269,6 +262,7 @@ def pull_all_from_table_or_query(
created_timestamp_column have all already been mapped to column names of the
source table and those column names are the values passed into this function.
"""
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
assert isinstance(data_source, SparkSource)
warnings.warn(
"The spark offline store is an experimental feature in alpha development. "
sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py
@@ -161,14 +161,8 @@ def pull_latest_from_table_or_query(
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> TrinoRetrievalJob:
if not isinstance(data_source, TrinoSource):
raise ValueError(
f"The data_source object is not a TrinoSource but is instead '{type(data_source)}'"
)
if not isinstance(config.offline_store, TrinoOfflineStoreConfig):
raise ValueError(
f"The config.offline_store object is not a TrinoOfflineStoreConfig but is instead '{type(config.offline_store)}'"
)
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
assert isinstance(data_source, TrinoSource)

from_expression = data_source.get_table_query_string()

@@ -222,10 +216,9 @@ def get_historical_features(
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> TrinoRetrievalJob:
if not isinstance(config.offline_store, TrinoOfflineStoreConfig):
raise ValueError(
f"This function should be used with a TrinoOfflineStoreConfig object. Instead we have config.offline_store being '{type(config.offline_store)}'"
)
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, TrinoSource)

client = _get_trino_client(
config=config, user=user, auth=auth, http_scheme=http_scheme
@@ -314,10 +307,8 @@ def pull_all_from_table_or_query(
auth: Optional[Authentication] = None,
http_scheme: Optional[str] = None,
) -> RetrievalJob:
if not isinstance(data_source, TrinoSource):
raise ValueError(
f"The data_source object is not a TrinoSource object but is instead a {type(data_source)}"
)
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
assert isinstance(data_source, TrinoSource)
from_expression = data_source.get_table_query_string()

client = _get_trino_client(
23 changes: 11 additions & 12 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -130,6 +130,10 @@ def get_historical_features(
project: str,
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, FileOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, FileSource)

if not isinstance(entity_df, pd.DataFrame) and not isinstance(
entity_df, dd.DataFrame
):
@@ -298,6 +302,7 @@ def pull_latest_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, FileOfflineStoreConfig)
assert isinstance(data_source, FileSource)

# Create lazy function that is only called from the RetrievalJob object
@@ -378,6 +383,9 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, FileOfflineStoreConfig)
assert isinstance(data_source, FileSource)

return FileOfflineStore.pull_latest_from_table_or_query(
config=config,
data_source=data_source,
@@ -398,6 +406,7 @@ def write_logged_features(
logging_config: LoggingConfig,
registry: BaseRegistry,
):
assert isinstance(config.offline_store, FileOfflineStoreConfig)
destination = logging_config.destination
assert isinstance(destination, FileLoggingDestination)

@@ -428,18 +437,8 @@ def offline_write_batch(
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, FileOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when file type required"
)
if not isinstance(feature_view.batch_source, FileSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not file source"
)
assert isinstance(config.offline_store, FileOfflineStoreConfig)
assert isinstance(feature_view.batch_source, FileSource)

pa_schema, column_names = get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
3 changes: 3 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_store.py
@@ -218,6 +218,9 @@ class OfflineStore(ABC):
"""
An offline store defines the interface that Feast uses to interact with the storage and compute system that
handles offline features.
Each offline store implementation is designed to work only with the corresponding data source. For example,
the SnowflakeOfflineStore can handle SnowflakeSources but not FileSources.
"""

@staticmethod
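As a hypothetical illustration of the contract described in the docstring above (not part of this diff; repo_config is assumed to be a RepoConfig whose offline_store is a SnowflakeOfflineStoreConfig, and the column names are made up): handing a FileSource to the Snowflake store now trips the assertion at the top of the method instead of failing later during query construction.

    from datetime import datetime, timedelta

    from feast import FileSource
    from feast.infra.offline_stores.snowflake import SnowflakeOfflineStore

    file_source = FileSource(
        path="data/driver_stats.parquet",
        timestamp_field="event_timestamp",
    )

    # Mismatched source type: raises AssertionError immediately, because
    # `assert isinstance(data_source, SnowflakeSource)` now guards the method.
    SnowflakeOfflineStore.pull_latest_from_table_or_query(
        config=repo_config,  # assumed Snowflake-configured RepoConfig
        data_source=file_source,  # a FileSource, not a SnowflakeSource
        join_key_columns=["driver_id"],
        feature_name_columns=["conv_rate"],
        timestamp_field="event_timestamp",
        created_timestamp_column=None,
        start_date=datetime.utcnow() - timedelta(days=1),
        end_date=datetime.utcnow(),
    )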
17 changes: 5 additions & 12 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -141,6 +141,7 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)
assert isinstance(data_source, RedshiftSource)
from_expression = data_source.get_table_query_string()

@@ -182,6 +183,8 @@ def get_historical_features(
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, RedshiftSource)

redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
@@ -308,18 +311,8 @@ def offline_write_batch(
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, RedshiftOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when redshift type required"
)
if not isinstance(feature_view.batch_source, RedshiftSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
)
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)
assert isinstance(feature_view.batch_source, RedshiftSource)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
19 changes: 6 additions & 13 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -117,8 +117,8 @@ def pull_latest_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, SnowflakeSource)
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
assert isinstance(data_source, SnowflakeSource)

from_expression = data_source.get_table_query_string()
if not data_source.database and data_source.table:
@@ -183,6 +183,7 @@ def pull_all_from_table_or_query(
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
assert isinstance(data_source, SnowflakeSource)

from_expression = data_source.get_table_query_string()
@@ -228,6 +229,8 @@ def get_historical_features(
full_feature_names: bool = False,
) -> RetrievalJob:
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, SnowflakeSource)

snowflake_conn = get_snowflake_conn(config.offline_store)

@@ -332,18 +335,8 @@ def offline_write_batch(
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when snowflake type required"
)
if not isinstance(feature_view.batch_source, SnowflakeSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
)
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
assert isinstance(feature_view.batch_source, SnowflakeSource)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
