diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py deleted file mode 100644 index 80074cf509..0000000000 --- a/sdk/python/feast/infra/online_stores/snowflake.py +++ /dev/null @@ -1,232 +0,0 @@ -import itertools -import os -from binascii import hexlify -from datetime import datetime -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple - -import pandas as pd -import pytz -from pydantic import Field -from pydantic.schema import Literal - -from feast import Entity, FeatureView -from feast.infra.key_encoding_utils import serialize_entity_key -from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas_binary -from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.usage import log_exceptions_and_usage - - -class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): - """ Online store config for Snowflake """ - - type: Literal["snowflake.online"] = "snowflake.online" - """ Online store type selector""" - - config_path: Optional[str] = ( - Path(os.environ["HOME"]) / ".snowsql/config" - ).__str__() - """ Snowflake config path -- absolute path required (Can't use ~)""" - - account: Optional[str] = None - """ Snowflake deployment identifier -- drop .snowflakecomputing.com""" - - user: Optional[str] = None - """ Snowflake user name """ - - password: Optional[str] = None - """ Snowflake password """ - - role: Optional[str] = None - """ Snowflake role name""" - - warehouse: Optional[str] = None - """ Snowflake warehouse name """ - - database: Optional[str] = None - """ Snowflake database name """ - - schema_: Optional[str] = Field("PUBLIC", alias="schema") - """ Snowflake schema name """ - - class Config: - allow_population_by_field_name = True - - -class SnowflakeOnlineStore(OnlineStore): - @log_exceptions_and_usage(online_store="snowflake") - def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - - dfs = [None] * len(data) - for i, (entity_key, values, timestamp, created_ts) in enumerate(data): - - df = pd.DataFrame( - columns=[ - "entity_feature_key", - "entity_key", - "feature_name", - "value", - "event_ts", - "created_ts", - ], - index=range(0, len(values)), - ) - - timestamp = _to_naive_utc(timestamp) - if created_ts is not None: - created_ts = _to_naive_utc(created_ts) - - for j, (feature_name, val) in enumerate(values.items()): - df.loc[j, "entity_feature_key"] = serialize_entity_key( - entity_key - ) + bytes(feature_name, encoding="utf-8") - df.loc[j, "entity_key"] = serialize_entity_key(entity_key) - df.loc[j, "feature_name"] = feature_name - df.loc[j, "value"] = val.SerializeToString() - df.loc[j, "event_ts"] = timestamp - df.loc[j, "created_ts"] = created_ts - - dfs[i] = df - if progress: - progress(1) - - if dfs: - agg_df = pd.concat(dfs) - - with get_snowflake_conn(config.online_store, autocommit=False) as conn: - - write_pandas_binary(conn, agg_df, f"{config.project}_{table.name}") - - query = f""" - INSERT OVERWRITE INTO "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}" - SELECT - "entity_feature_key", - "entity_key", - "feature_name", - "value", - "event_ts", - "created_ts" - FROM - (SELECT - *, - ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row" - FROM - "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}") - WHERE - "_feast_row" = 1; - """ - - conn.cursor().execute(query) - - return None - - @log_exceptions_and_usage(online_store="snowflake") - def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: List[str], - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - - result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] - - with get_snowflake_conn(config.online_store) as conn: - - df = ( - conn.cursor() - .execute( - f""" - SELECT - "entity_key", "feature_name", "value", "event_ts" - FROM - "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}" - WHERE - "entity_feature_key" IN ({','.join([('TO_BINARY('+hexlify(serialize_entity_key(combo[0])+bytes(combo[1], encoding='utf-8')).__str__()[1:]+")") for combo in itertools.product(entity_keys,requested_features)])}) - """, - ) - .fetch_pandas_all() - ) - - for entity_key in entity_keys: - entity_key_bin = serialize_entity_key(entity_key) - res = {} - res_ts = None - for index, row in df[df["entity_key"] == entity_key_bin].iterrows(): - val = ValueProto() - val.ParseFromString(row["value"]) - res[row["feature_name"]] = val - res_ts = row["event_ts"].to_pydatetime() - - if not res: - result.append((None, None)) - else: - result.append((res_ts, res)) - return result - - @log_exceptions_and_usage(online_store="snowflake") - def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, - ): - assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - - with get_snowflake_conn(config.online_store) as conn: - - for table in tables_to_keep: - - conn.cursor().execute( - f"""CREATE TABLE IF NOT EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}" ( - "entity_feature_key" BINARY, - "entity_key" BINARY, - "feature_name" VARCHAR, - "value" BINARY, - "event_ts" TIMESTAMP, - "created_ts" TIMESTAMP - )""" - ) - - for table in tables_to_delete: - - conn.cursor().execute( - f'DROP TABLE IF EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"' - ) - - def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], - ): - assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - - with get_snowflake_conn(config.online_store) as conn: - - for table in tables: - query = f'DROP TABLE IF EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"' - conn.cursor().execute(query) - - -def _to_naive_utc(ts: datetime): - if ts.tzinfo is None: - return ts - else: - return ts.astimezone(pytz.utc).replace(tzinfo=None) diff --git a/sdk/python/feast/infra/utils/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake_utils.py index d9be930439..05834ae436 100644 --- a/sdk/python/feast/infra/utils/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake_utils.py @@ -44,12 +44,8 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: - assert config.type in ["snowflake.offline", "snowflake.online"] - - if config.type == "snowflake.offline": - config_header = "connections.feast_offline_store" - elif config.type == "snowflake.online": - config_header = "connections.feast_online_store" + assert config.type == "snowflake.offline" + config_header = "connections.feast_offline_store" config_dict = dict(config) @@ -433,176 +429,3 @@ def parse_private_key_path(key_path: str, private_key_passphrase: str) -> bytes: ) return pkb - - -def write_pandas_binary( - conn: SnowflakeConnection, - df: pd.DataFrame, - table_name: str, - database: Optional[str] = None, - schema: Optional[str] = None, - chunk_size: Optional[int] = None, - compression: str = "gzip", - on_error: str = "abort_statement", - parallel: int = 4, - quote_identifiers: bool = True, - auto_create_table: bool = False, - create_temp_table: bool = False, -): - """Allows users to most efficiently write back a pandas DataFrame to Snowflake. - - It works by dumping the DataFrame into Parquet files, uploading them and finally copying their data into the table. - - Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested - with all of the COPY INTO command's output for debugging purposes. - - Example usage: - import pandas - from snowflake.connector.pandas_tools import write_pandas - - df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance']) - success, nchunks, nrows, _ = write_pandas(cnx, df, 'customers') - - Args: - conn: Connection to be used to communicate with Snowflake. - df: Dataframe we'd like to write back. - table_name: Table name where we want to insert into. - database: Database schema and table is in, if not provided the default one will be used (Default value = None). - schema: Schema table is in, if not provided the default one will be used (Default value = None). - chunk_size: Number of elements to be inserted once, if not provided all elements will be dumped once - (Default value = None). - compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives supposedly a - better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). - on_error: Action to take when COPY INTO statements fail, default follows documentation at: - https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions - (Default value = 'abort_statement'). - parallel: Number of threads to be used when uploading chunks, default follows documentation at: - https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). - quote_identifiers: By default, identifiers, specifically database, schema, table and column names - (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. - I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) - auto_create_table: When true, will automatically create a table with corresponding columns for each column in - the passed in DataFrame. The table will not be created if it already exists - create_temp_table: Will make the auto-created table as a temporary table - """ - if database is not None and schema is None: - raise ProgrammingError( - "Schema has to be provided to write_pandas when a database is provided" - ) - # This dictionary maps the compression algorithm to Snowflake put copy into command type - # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet - compression_map = {"gzip": "auto", "snappy": "snappy"} - if compression not in compression_map.keys(): - raise ProgrammingError( - "Invalid compression '{}', only acceptable values are: {}".format( - compression, compression_map.keys() - ) - ) - if quote_identifiers: - location = ( - (('"' + database + '".') if database else "") - + (('"' + schema + '".') if schema else "") - + ('"' + table_name + '"') - ) - else: - location = ( - (database + "." if database else "") - + (schema + "." if schema else "") - + (table_name) - ) - if chunk_size is None: - chunk_size = len(df) - cursor: SnowflakeCursor = conn.cursor() - stage_name = create_temporary_sfc_stage(cursor) - - with TemporaryDirectory() as tmp_folder: - for i, chunk in chunk_helper(df, chunk_size): - chunk_path = os.path.join(tmp_folder, "file{}.txt".format(i)) - # Dump chunk into parquet file - chunk.to_parquet( - chunk_path, - compression=compression, - use_deprecated_int96_timestamps=True, - ) - # Upload parquet file - upload_sql = ( - "PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ " - "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}" - ).format( - path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), - stage_name=stage_name, - parallel=parallel, - ) - logger.debug(f"uploading files with '{upload_sql}'") - cursor.execute(upload_sql, _is_internal=True) - # Remove chunk file - os.remove(chunk_path) - if quote_identifiers: - columns = '"' + '","'.join(list(df.columns)) + '"' - else: - columns = ",".join(list(df.columns)) - - if auto_create_table: - file_format_name = create_file_format(compression, compression_map, cursor) - infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))" - logger.debug(f"inferring schema with '{infer_schema_sql}'") - result_cursor = cursor.execute(infer_schema_sql, _is_internal=True) - if result_cursor is None: - raise SnowflakeQueryUnknownError(infer_schema_sql) - result = cast(List[Tuple[str, str]], result_cursor.fetchall()) - column_type_mapping: Dict[str, str] = dict(result) - # Infer schema can return the columns out of order depending on the chunking we do when uploading - # so we have to iterate through the dataframe columns to make sure we create the table with its - # columns in order - quote = '"' if quote_identifiers else "" - create_table_columns = ", ".join( - [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns] - ) - create_table_sql = ( - f"CREATE {'TEMP ' if create_temp_table else ''}TABLE IF NOT EXISTS {location} " - f"({create_table_columns})" - f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ " - ) - logger.debug(f"auto creating table with '{create_table_sql}'") - cursor.execute(create_table_sql, _is_internal=True) - drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}" - logger.debug(f"dropping file format with '{drop_file_format_sql}'") - cursor.execute(drop_file_format_sql, _is_internal=True) - - # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly - # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html) - if quote_identifiers: - parquet_columns = ",".join( - f'TO_BINARY($1:"{c}")' - if c in ["entity_feature_key", "entity_key", "value"] - else f'$1:"{c}"' - for c in df.columns - ) - else: - parquet_columns = ",".join( - f"TO_BINARY($1:{c})" - if c in ["entity_feature_key", "entity_key", "value"] - else f"$1:{c}" - for c in df.columns - ) - - copy_into_sql = ( - "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ " - "({columns}) " - 'FROM (SELECT {parquet_columns} FROM @"{stage_name}") ' - "FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression} BINARY_AS_TEXT = FALSE) " - "PURGE=TRUE ON_ERROR={on_error}" - ).format( - location=location, - columns=columns, - parquet_columns=parquet_columns, - stage_name=stage_name, - compression=compression_map[compression], - on_error=on_error, - ) - logger.debug("copying into with '{}'".format(copy_into_sql)) - # Snowflake returns the original cursor if the query execution succeeded. - result_cursor = cursor.execute(copy_into_sql, _is_internal=True) - if result_cursor is None: - raise SnowflakeQueryUnknownError(copy_into_sql) - result_cursor.close() diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 67e585839d..a168f4f028 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -74,17 +74,6 @@ "connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003", } -SNOWFLAKE_CONFIG = { - "type": "snowflake.online", - "account": os.environ["SNOWFLAKE_CI_DEPLOYMENT"], - "user": os.environ["SNOWFLAKE_CI_USER"], - "password": os.environ["SNOWFLAKE_CI_PASSWORD"], - "role": os.environ["SNOWFLAKE_CI_ROLE"], - "warehouse": os.environ["SNOWFLAKE_CI_WAREHOUSE"], - "database": "FEAST", - "schema": "ONLINE", -} - OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = { "file": ("local", FileDataSourceCreator), "bigquery": ("gcp", BigQueryDataSourceCreator), @@ -114,7 +103,6 @@ AVAILABLE_ONLINE_STORES["redis"] = (REDIS_CONFIG, None) AVAILABLE_ONLINE_STORES["dynamodb"] = (DYNAMO_CONFIG, None) AVAILABLE_ONLINE_STORES["datastore"] = ("datastore", None) - AVAILABLE_ONLINE_STORES["snowflake"] = (SNOWFLAKE_CONFIG, None) full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME)