diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py
new file mode 100644
index 0000000000..80074cf509
--- /dev/null
+++ b/sdk/python/feast/infra/online_stores/snowflake.py
@@ -0,0 +1,273 @@
+import itertools
+import os
+from binascii import hexlify
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+
+import pandas as pd
+import pytz
+from pydantic import Field
+from pydantic.schema import Literal
+
+from feast import Entity, FeatureView
+from feast.infra.key_encoding_utils import serialize_entity_key
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas_binary
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+from feast.usage import log_exceptions_and_usage
+
+
+class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
+    """ Online store config for Snowflake """
+
+    type: Literal["snowflake.online"] = "snowflake.online"
+    """ Online store type selector"""
+
+    # expanduser() never raises, so a missing HOME variable (e.g. on Windows)
+    # cannot break importing this module.
+    config_path: Optional[str] = str(Path(os.path.expanduser("~")) / ".snowsql/config")
+    """ Snowflake config path -- absolute path required (Can't use ~)"""
+
+    account: Optional[str] = None
+    """ Snowflake deployment identifier -- drop .snowflakecomputing.com"""
+
+    user: Optional[str] = None
+    """ Snowflake user name """
+
+    password: Optional[str] = None
+    """ Snowflake password """
+
+    role: Optional[str] = None
+    """ Snowflake role name"""
+
+    warehouse: Optional[str] = None
+    """ Snowflake warehouse name """
+
+    database: Optional[str] = None
+    """ Snowflake database name """
+
+    schema_: Optional[str] = Field("PUBLIC", alias="schema")
+    """ Snowflake schema name """
+
+    class Config:
+        allow_population_by_field_name = True
+
+
+class SnowflakeOnlineStore(OnlineStore):
+    """Online store that keeps one Snowflake table per feature view."""
+
+    @log_exceptions_and_usage(online_store="snowflake")
+    def online_write_batch(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[
+            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        """Load feature values into the feature view's table, one row per feature.
+
+        After the load the table is rewritten so that only the newest row per
+        ("entity_key", "feature_name") pair is kept.
+        """
+        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
+
+        if not data:
+            # Nothing to load, so skip the Snowflake round trip entirely.
+            return None
+
+        dfs = []
+        for entity_key, values, timestamp, created_ts in data:
+            # Serialize once per entity rather than twice per feature.
+            entity_key_bin = serialize_entity_key(entity_key)
+            timestamp = _to_naive_utc(timestamp)
+            if created_ts is not None:
+                created_ts = _to_naive_utc(created_ts)
+
+            df = pd.DataFrame(
+                columns=[
+                    "entity_feature_key",
+                    "entity_key",
+                    "feature_name",
+                    "value",
+                    "event_ts",
+                    "created_ts",
+                ],
+                index=range(0, len(values)),
+            )
+            for j, (feature_name, val) in enumerate(values.items()):
+                df.loc[j, "entity_feature_key"] = entity_key_bin + bytes(
+                    feature_name, encoding="utf-8"
+                )
+                df.loc[j, "entity_key"] = entity_key_bin
+                df.loc[j, "feature_name"] = feature_name
+                df.loc[j, "value"] = val.SerializeToString()
+                df.loc[j, "event_ts"] = timestamp
+                df.loc[j, "created_ts"] = created_ts
+
+            dfs.append(df)
+            if progress:
+                progress(1)
+
+        agg_df = pd.concat(dfs)
+        table_path = _table_path(config, table)
+
+        with get_snowflake_conn(config.online_store, autocommit=False) as conn:
+            # Load into the same database/schema that the query below targets.
+            write_pandas_binary(
+                conn,
+                agg_df,
+                f"{config.project}_{table.name}",
+                database=config.online_store.database,
+                schema=config.online_store.schema_,
+            )
+
+            # Deduplicate: keep only the newest row per entity/feature pair.
+            query = f"""
+                INSERT OVERWRITE INTO {table_path}
+                    SELECT
+                        "entity_feature_key",
+                        "entity_key",
+                        "feature_name",
+                        "value",
+                        "event_ts",
+                        "created_ts"
+                    FROM
+                      (SELECT
+                          *,
+                          ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row"
+                      FROM
+                          {table_path})
+                    WHERE
+                        "_feast_row" = 1;
+            """
+            conn.cursor().execute(query)
+
+        return None
+
+    @log_exceptions_and_usage(online_store="snowflake")
+    def online_read(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: List[str],
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        """Read the requested features for each entity key.
+
+        Returns one (event timestamp, features) tuple per entity key, in input
+        order; entity keys without stored rows yield (None, None).
+        """
+        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
+
+        entity_key_bins = [serialize_entity_key(key) for key in entity_keys]
+
+        # One hex literal per (entity, feature) pair, matching how
+        # "entity_feature_key" is built on write.
+        key_filters = []
+        for key_bin, feature in itertools.product(entity_key_bins, requested_features):
+            feature_key = key_bin + bytes(feature, encoding="utf-8")
+            key_hex = hexlify(feature_key).decode("ascii")
+            key_filters.append(f"TO_BINARY('{key_hex}')")
+
+        if not key_filters:
+            # No entity keys or no features were requested; "IN ()" is not valid SQL.
+            return [(None, None) for _ in entity_keys]
+
+        in_list = ",".join(key_filters)
+        with get_snowflake_conn(config.online_store) as conn:
+            df = (
+                conn.cursor()
+                .execute(
+                    f"""
+                    SELECT
+                        "entity_key", "feature_name", "value", "event_ts"
+                    FROM
+                        {_table_path(config, table)}
+                    WHERE
+                        "entity_feature_key" IN ({in_list})
+                    """,
+                )
+                .fetch_pandas_all()
+            )
+
+        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
+        for entity_key_bin in entity_key_bins:
+            res = {}
+            res_ts = None
+            for _, row in df[df["entity_key"] == entity_key_bin].iterrows():
+                val = ValueProto()
+                val.ParseFromString(row["value"])
+                res[row["feature_name"]] = val
+                res_ts = row["event_ts"].to_pydatetime()
+
+            if not res:
+                result.append((None, None))
+            else:
+                result.append((res_ts, res))
+        return result
+
+    @log_exceptions_and_usage(online_store="snowflake")
+    def update(
+        self,
+        config: RepoConfig,
+        tables_to_delete: Sequence[FeatureView],
+        tables_to_keep: Sequence[FeatureView],
+        entities_to_delete: Sequence[Entity],
+        entities_to_keep: Sequence[Entity],
+        partial: bool,
+    ):
+        """Create a table for each feature view to keep and drop the others."""
+        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
+
+        with get_snowflake_conn(config.online_store) as conn:
+            for table in tables_to_keep:
+                conn.cursor().execute(
+                    f"""CREATE TABLE IF NOT EXISTS {_table_path(config, table)} (
+                        "entity_feature_key" BINARY,
+                        "entity_key" BINARY,
+                        "feature_name" VARCHAR,
+                        "value" BINARY,
+                        "event_ts" TIMESTAMP,
+                        "created_ts" TIMESTAMP
+                        )"""
+                )
+
+            for table in tables_to_delete:
+                conn.cursor().execute(
+                    f"DROP TABLE IF EXISTS {_table_path(config, table)}"
+                )
+
+    def teardown(
+        self,
+        config: RepoConfig,
+        tables: Sequence[FeatureView],
+        entities: Sequence[Entity],
+    ):
+        """Drop the table of every given feature view."""
+        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
+
+        with get_snowflake_conn(config.online_store) as conn:
+            for table in tables:
+                query = f"DROP TABLE IF EXISTS {_table_path(config, table)}"
+                conn.cursor().execute(query)
+
+
+def _table_path(config: RepoConfig, table: FeatureView) -> str:
+    """Return the quoted "database"."schema"."project_table" name of a feature view."""
+    online_config = config.online_store
+    return (
+        f'"{online_config.database}"."{online_config.schema_}"'
+        f'."{config.project}_{table.name}"'
+    )
+
+
+def _to_naive_utc(ts: datetime) -> datetime:
+    """Return ts as a naive UTC datetime; naive inputs are returned unchanged."""
+    if ts.tzinfo is None:
+        return ts
+    return ts.astimezone(pytz.utc).replace(tzinfo=None)
diff --git a/sdk/python/feast/infra/utils/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake_utils.py
index 05834ae436..d9be930439 100644
--- a/sdk/python/feast/infra/utils/snowflake_utils.py
+++ b/sdk/python/feast/infra/utils/snowflake_utils.py
@@ -44,8 +44,12 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu
 
 
 def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection:
-    assert config.type == "snowflake.offline"
-    config_header = "connections.feast_offline_store"
+    assert config.type in ["snowflake.offline", "snowflake.online"]
+
+    if config.type == "snowflake.offline":
+        config_header = "connections.feast_offline_store"
+    elif config.type == "snowflake.online":
+        config_header = "connections.feast_online_store"
 
     config_dict = dict(config)
 
@@ -429,3 +433,178 @@ def parse_private_key_path(key_path: str, private_key_passphrase: str) -> bytes:
     )
 
     return pkb
+
+
+def write_pandas_binary(
+    conn: SnowflakeConnection,
+    df: pd.DataFrame,
+    table_name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    chunk_size: Optional[int] = None,
+    compression: str = "gzip",
+    on_error: str = "abort_statement",
+    parallel: int = 4,
+    quote_identifiers: bool = True,
+    auto_create_table: bool = False,
+    create_temp_table: bool = False,
+):
+    """Write a pandas DataFrame to Snowflake, converting Feast's binary columns.
+
+    Modelled on snowflake.connector.pandas_tools.write_pandas: the DataFrame is dumped
+    into Parquet files, which are uploaded to a stage and then copied into the table.
+    While copying, the "entity_feature_key", "entity_key" and "value" columns are
+    wrapped in TO_BINARY().
+
+    Unlike write_pandas, nothing is returned.
+
+    Example usage:
+        write_pandas_binary(conn, df, "my_table")
+
+    Args:
+        conn: Connection to be used to communicate with Snowflake.
+        df: Dataframe we'd like to write back.
+        table_name: Table name where we want to insert into.
+        database: Database schema and table is in, if not provided the default one will be used (Default value = None).
+        schema: Schema table is in, if not provided the default one will be used (Default value = None).
+        chunk_size: Number of elements to be inserted once, if not provided all elements will be dumped once
+            (Default value = None).
+        compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives supposedly a
+            better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip').
+        on_error: Action to take when COPY INTO statements fail, default follows documentation at:
+            https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
+            (Default value = 'abort_statement').
+        parallel: Number of threads to be used when uploading chunks, default follows documentation at:
+            https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4).
+        quote_identifiers: By default, identifiers, specifically database, schema, table and column names
+            (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting.
+            I.e. identifiers will be coerced to uppercase by Snowflake.  (Default value = True)
+        auto_create_table: When true, will automatically create a table with corresponding columns for each column in
+            the passed in DataFrame. The table will not be created if it already exists
+        create_temp_table: Will make the auto-created table as a temporary table
+
+    Raises:
+        ProgrammingError: If database is given without schema or compression is unsupported.
+        SnowflakeQueryUnknownError: If the infer-schema or COPY INTO statement returns no cursor.
+    """
+    if database is not None and schema is None:
+        raise ProgrammingError(
+            "Schema has to be provided to write_pandas when a database is provided"
+        )
+    # This dictionary maps the compression algorithm to Snowflake put copy into command type
+    # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet
+    compression_map = {"gzip": "auto", "snappy": "snappy"}
+    if compression not in compression_map:
+        raise ProgrammingError(
+            "Invalid compression '{}', only acceptable values are: {}".format(
+                compression, compression_map.keys()
+            )
+        )
+    if quote_identifiers:
+        location = (
+            (('"' + database + '".') if database else "")
+            + (('"' + schema + '".') if schema else "")
+            + ('"' + table_name + '"')
+        )
+    else:
+        location = (
+            (database + "." if database else "")
+            + (schema + "." if schema else "")
+            + (table_name)
+        )
+    if chunk_size is None:
+        chunk_size = len(df)
+    cursor: SnowflakeCursor = conn.cursor()
+    stage_name = create_temporary_sfc_stage(cursor)
+
+    with TemporaryDirectory() as tmp_folder:
+        for i, chunk in chunk_helper(df, chunk_size):
+            chunk_path = os.path.join(tmp_folder, "file{}.txt".format(i))
+            # Dump chunk into parquet file
+            chunk.to_parquet(
+                chunk_path,
+                compression=compression,
+                use_deprecated_int96_timestamps=True,
+            )
+            # Upload parquet file
+            upload_sql = (
+                "PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+                "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}"
+            ).format(
+                path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
+                stage_name=stage_name,
+                parallel=parallel,
+            )
+            logger.debug(f"uploading files with '{upload_sql}'")
+            cursor.execute(upload_sql, _is_internal=True)
+            # Remove chunk file
+            os.remove(chunk_path)
+    if quote_identifiers:
+        columns = '"' + '","'.join(list(df.columns)) + '"'
+    else:
+        columns = ",".join(list(df.columns))
+
+    if auto_create_table:
+        file_format_name = create_file_format(compression, compression_map, cursor)
+        infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
+        logger.debug(f"inferring schema with '{infer_schema_sql}'")
+        result_cursor = cursor.execute(infer_schema_sql, _is_internal=True)
+        if result_cursor is None:
+            raise SnowflakeQueryUnknownError(infer_schema_sql)
+        result = cast(List[Tuple[str, str]], result_cursor.fetchall())
+        column_type_mapping: Dict[str, str] = dict(result)
+        # Infer schema can return the columns out of order depending on the chunking we do when uploading
+        # so we have to iterate through the dataframe columns to make sure we create the table with its
+        # columns in order
+        quote = '"' if quote_identifiers else ""
+        create_table_columns = ", ".join(
+            [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns]
+        )
+        create_table_sql = (
+            f"CREATE {'TEMP ' if create_temp_table else ''}TABLE IF NOT EXISTS {location} "
+            f"({create_table_columns})"
+            f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+        )
+        logger.debug(f"auto creating table with '{create_table_sql}'")
+        cursor.execute(create_table_sql, _is_internal=True)
+        drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}"
+        logger.debug(f"dropping file format with '{drop_file_format_sql}'")
+        cursor.execute(drop_file_format_sql, _is_internal=True)
+
+    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
+    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
+    if quote_identifiers:
+        parquet_columns = ",".join(
+            f'TO_BINARY($1:"{c}")'
+            if c in ["entity_feature_key", "entity_key", "value"]
+            else f'$1:"{c}"'
+            for c in df.columns
+        )
+    else:
+        parquet_columns = ",".join(
+            f"TO_BINARY($1:{c})"
+            if c in ["entity_feature_key", "entity_key", "value"]
+            else f"$1:{c}"
+            for c in df.columns
+        )
+
+    copy_into_sql = (
+        "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+        "({columns}) "
+        'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
+        "FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression} BINARY_AS_TEXT = FALSE) "
+        "PURGE=TRUE ON_ERROR={on_error}"
+    ).format(
+        location=location,
+        columns=columns,
+        parquet_columns=parquet_columns,
+        stage_name=stage_name,
+        compression=compression_map[compression],
+        on_error=on_error,
+    )
+    logger.debug("copying into with '{}'".format(copy_into_sql))
+    # Snowflake returns the original cursor if the query execution succeeded.
+    result_cursor = cursor.execute(copy_into_sql, _is_internal=True)
+    if result_cursor is None:
+        raise SnowflakeQueryUnknownError(copy_into_sql)
+    result_cursor.close()
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 6f40d3171b..67e585839d 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -74,11 +74,22 @@
     "connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003",
 }
 
+SNOWFLAKE_CONFIG = {
+    "type": "snowflake.online",
+    "account": os.getenv("SNOWFLAKE_CI_DEPLOYMENT", ""),
+    "user": os.getenv("SNOWFLAKE_CI_USER", ""),
+    "password": os.getenv("SNOWFLAKE_CI_PASSWORD", ""),
+    "role": os.getenv("SNOWFLAKE_CI_ROLE", ""),
+    "warehouse": os.getenv("SNOWFLAKE_CI_WAREHOUSE", ""),
+    "database": "FEAST",
+    "schema": "ONLINE",
+}
+
 OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
     "file": ("local", FileDataSourceCreator),
     "bigquery": ("gcp", BigQueryDataSourceCreator),
     "redshift": ("aws", RedshiftDataSourceCreator),
-    "snowflake": ("aws", RedshiftDataSourceCreator),
+    "snowflake": ("aws", SnowflakeDataSourceCreator),
 }
 
 AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [
@@ -103,6 +114,7 @@
     AVAILABLE_ONLINE_STORES["redis"] = (REDIS_CONFIG, None)
     AVAILABLE_ONLINE_STORES["dynamodb"] = (DYNAMO_CONFIG, None)
     AVAILABLE_ONLINE_STORES["datastore"] = ("datastore", None)
+    AVAILABLE_ONLINE_STORES["snowflake"] = (SNOWFLAKE_CONFIG, None)
 
 
 full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME)