diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 0000000000..a656ac1bba Binary files /dev/null and b/dump.rdb differ diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 7480d7fd4f..30b192f6ed 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -17,15 +17,10 @@ from typing import Callable, Dict, Iterable, Optional, Tuple from pyarrow.parquet import ParquetFile -from tenacity import retry, retry_unless_exception_type, wait_exponential from feast import type_map from feast.data_format import FileFormat, StreamFormat -from feast.errors import ( - DataSourceNotFoundException, - RedshiftCredentialsError, - RedshiftQueryError, -) +from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.repo_config import RepoConfig from feast.value_type import ValueType @@ -1062,7 +1057,7 @@ def validate(self, config: RepoConfig): def get_table_query_string(self) -> str: """Returns a string that can directly be used to reference this table in SQL""" if self.table: - return f"`{self.table}`" + return f'"{self.table}"' else: return f"({self.query})" @@ -1073,62 +1068,43 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: def get_table_column_names_and_types( self, config: RepoConfig ) -> Iterable[Tuple[str, str]]: - import boto3 - from botocore.config import Config from botocore.exceptions import ClientError from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig + from feast.infra.utils import aws_utils assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) - client = boto3.client( - "redshift-data", config=Config(region_name=config.offline_store.region) - ) + client = aws_utils.get_redshift_data_client(config.offline_store.region) - try: - if self.table is not None: + if self.table is not None: + try: table = client.describe_table( ClusterIdentifier=config.offline_store.cluster_id, Database=config.offline_store.database, DbUser=config.offline_store.user, Table=self.table, ) - # The API returns valid JSON with empty column list when the table doesn't exist - if len(table["ColumnList"]) == 0: - raise DataSourceNotFoundException(self.table) + except ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + raise RedshiftCredentialsError() from e + raise - columns = table["ColumnList"] - else: - statement = client.execute_statement( - ClusterIdentifier=config.offline_store.cluster_id, - Database=config.offline_store.database, - DbUser=config.offline_store.user, - Sql=f"SELECT * FROM ({self.query}) LIMIT 1", - ) + # The API returns valid JSON with empty column list when the table doesn't exist + if len(table["ColumnList"]) == 0: + raise DataSourceNotFoundException(self.table) - # Need to retry client.describe_statement(...) until the task is finished. We don't want to bombard - # Redshift with queries, and neither do we want to wait for a long time on the initial call. - # The solution is exponential backoff. The backoff starts with 0.1 seconds and doubles exponentially - # until reaching 30 seconds, at which point the backoff is fixed. 
- @retry( - wait=wait_exponential(multiplier=0.1, max=30), - retry=retry_unless_exception_type(RedshiftQueryError), - ) - def wait_for_statement(): - desc = client.describe_statement(Id=statement["Id"]) - if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"): - raise Exception # Retry - if desc["Status"] != "FINISHED": - raise RedshiftQueryError(desc) # Don't retry. Raise exception. - - wait_for_statement() - - result = client.get_statement_result(Id=statement["Id"]) - - columns = result["ColumnMetadata"] - except ClientError as e: - if e.response["Error"]["Code"] == "ValidationException": - raise RedshiftCredentialsError() from e - raise + columns = table["ColumnList"] + else: + statement_id = aws_utils.execute_redshift_statement( + client, + config.offline_store.cluster_id, + config.offline_store.database, + config.offline_store.user, + f"SELECT * FROM ({self.query}) LIMIT 1", + ) + columns = aws_utils.get_redshift_statement_result(client, statement_id)[ + "ColumnMetadata" + ] return [(column["name"], column["typeName"].upper()) for column in columns] diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 1850958111..c97389ceaa 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -1,13 +1,16 @@ +import uuid from datetime import datetime from typing import List, Optional, Union import pandas as pd +import pyarrow as pa from pydantic import StrictStr from pydantic.typing import Literal -from feast.data_source import DataSource +from feast.data_source import DataSource, RedshiftSource from feast.feature_view import FeatureView from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob +from feast.infra.utils import aws_utils from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -30,9 +33,12 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel): database: StrictStr """ Redshift database name """ - s3_path: StrictStr + s3_staging_location: StrictStr """ S3 path for importing & exporting data to Redshift """ + iam_role: StrictStr + """ IAM Role for Redshift, granting it access to S3 """ + class RedshiftOfflineStore(OfflineStore): @staticmethod @@ -46,7 +52,45 @@ def pull_latest_from_table_or_query( start_date: datetime, end_date: datetime, ) -> RetrievalJob: - pass + assert isinstance(data_source, RedshiftSource) + assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) + + from_expression = data_source.get_table_query_string() + + partition_by_join_key_string = ", ".join(join_key_columns) + if partition_by_join_key_string != "": + partition_by_join_key_string = ( + "PARTITION BY " + partition_by_join_key_string + ) + timestamp_columns = [event_timestamp_column] + if created_timestamp_column: + timestamp_columns.append(created_timestamp_column) + timestamp_desc_string = " DESC, ".join(timestamp_columns) + " DESC" + field_string = ", ".join( + join_key_columns + feature_name_columns + timestamp_columns + ) + + redshift_client = aws_utils.get_redshift_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + query = f""" + SELECT {field_string} + FROM ( + SELECT {field_string}, + ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row + FROM {from_expression} + WHERE {event_timestamp_column} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + ) + WHERE _feast_row = 
1 + """ + return RedshiftRetrievalJob( + query=query, + redshift_client=redshift_client, + s3_resource=s3_resource, + config=config, + ) @staticmethod def get_historical_features( @@ -59,3 +103,71 @@ def get_historical_features( full_feature_names: bool = False, ) -> RetrievalJob: pass + + +class RedshiftRetrievalJob(RetrievalJob): + def __init__(self, query: str, redshift_client, s3_resource, config: RepoConfig): + """Initialize RedshiftRetrievalJob object. + + Args: + query: Redshift SQL query to execute. + redshift_client: boto3 redshift-data client + s3_resource: boto3 s3 resource object + config: Feast repo config + """ + self.query = query + self._redshift_client = redshift_client + self._s3_resource = s3_resource + self._config = config + self._s3_path = ( + self._config.offline_store.s3_staging_location + + "/unload/" + + str(uuid.uuid4()) + ) + + def to_df(self) -> pd.DataFrame: + return aws_utils.unload_redshift_query_to_df( + self._redshift_client, + self._config.offline_store.cluster_id, + self._config.offline_store.database, + self._config.offline_store.user, + self._s3_resource, + self._s3_path, + self._config.offline_store.iam_role, + self.query, + ) + + def to_arrow(self) -> pa.Table: + return aws_utils.unload_redshift_query_to_pa( + self._redshift_client, + self._config.offline_store.cluster_id, + self._config.offline_store.database, + self._config.offline_store.user, + self._s3_resource, + self._s3_path, + self._config.offline_store.iam_role, + self.query, + ) + + def to_s3(self) -> str: + """ Export dataset to S3 in Parquet format and return path """ + aws_utils.execute_redshift_query_and_unload_to_s3( + self._redshift_client, + self._config.offline_store.cluster_id, + self._config.offline_store.database, + self._config.offline_store.user, + self._s3_path, + self._config.offline_store.iam_role, + self.query, + ) + return self._s3_path + + def to_redshift(self, table_name: str) -> None: + """ Save dataset as a new Redshift table """ + aws_utils.execute_redshift_statement( + self._redshift_client, + self._config.offline_store.cluster_id, + self._config.offline_store.database, + self._config.offline_store.user, + f'CREATE TABLE "{table_name}" AS ({self.query})', + ) diff --git a/sdk/python/feast/infra/utils/__init__.py b/sdk/python/feast/infra/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py new file mode 100644 index 0000000000..235f427b76 --- /dev/null +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -0,0 +1,296 @@ +import os +import tempfile +import uuid +from typing import Tuple + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from tenacity import retry, retry_if_exception_type, wait_exponential + +from feast.errors import RedshiftCredentialsError, RedshiftQueryError +from feast.type_map import pa_to_redshift_value_type + +try: + import boto3 + from botocore.config import Config + from botocore.exceptions import ClientError +except ImportError as e: + from feast.errors import FeastExtrasDependencyImportError + + raise FeastExtrasDependencyImportError("aws", str(e)) + + +def get_redshift_data_client(aws_region: str): + """ + Get the Redshift Data API Service client for the given AWS region. + """ + return boto3.client("redshift-data", config=Config(region_name=aws_region)) + + +def get_s3_resource(aws_region: str): + """ + Get the S3 resource for the given AWS region. 
+ """ + return boto3.resource("s3", config=Config(region_name=aws_region)) + + +def get_bucket_and_key(s3_path: str) -> Tuple[str, str]: + """ + Get the S3 bucket and key given the full path. + + For example get_bucket_and_key("s3://foo/bar/test.file") returns ("foo", "bar/test.file") + + If the s3_path doesn't start with "s3://", it throws ValueError. + """ + assert s3_path.startswith("s3://") + s3_path = s3_path.replace("s3://", "") + bucket, key = s3_path.split("/", 1) + return bucket, key + + +def execute_redshift_statement_async( + redshift_data_client, cluster_id: str, database: str, user: str, query: str +) -> dict: + """Execute Redshift statement asynchronously. Does not wait for the query to finish. + + Raises RedshiftCredentialsError if the statement couldn't be executed due to the validation error. + + Args: + redshift_data_client: Redshift Data API Service client + cluster_id: Redshift Cluster Identifier + database: Redshift Database Name + user: Redshift username + query: The SQL query to execute + + Returns: JSON response + + """ + try: + return redshift_data_client.execute_statement( + ClusterIdentifier=cluster_id, Database=database, DbUser=user, Sql=query, + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + raise RedshiftCredentialsError() from e + raise + + +class RedshiftStatementNotFinishedError(Exception): + pass + + +@retry( + wait=wait_exponential(multiplier=0.1, max=30), + retry=retry_if_exception_type(RedshiftStatementNotFinishedError), +) +def wait_for_redshift_statement(redshift_data_client, statement: dict) -> None: + """Waits for the Redshift statement to finish. Raises RedshiftQueryError if the statement didn't succeed. + + We use exponential backoff for checking the query state until it's not running. The backoff starts with + 0.1 seconds and doubles exponentially until reaching 30 seconds, at which point the backoff is fixed. + + Args: + redshift_data_client: Redshift Data API Service client + statement: The redshift statement to wait for (result of execute_redshift_statement) + + Returns: None + + """ + desc = redshift_data_client.describe_statement(Id=statement["Id"]) + if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"): + raise RedshiftStatementNotFinishedError # Retry + if desc["Status"] != "FINISHED": + raise RedshiftQueryError(desc) # Don't retry. Raise exception. + + +def execute_redshift_statement( + redshift_data_client, cluster_id: str, database: str, user: str, query: str +) -> str: + """Execute Redshift statement synchronously. Waits for the query to finish. + + Raises RedshiftCredentialsError if the statement couldn't be executed due to the validation error. + Raises RedshiftQueryError if the query runs but finishes with errors. 
+ + + Args: + redshift_data_client: Redshift Data API Service client + cluster_id: Redshift Cluster Identifier + database: Redshift Database Name + user: Redshift username + query: The SQL query to execute + + Returns: Statement ID + + """ + statement = execute_redshift_statement_async( + redshift_data_client, cluster_id, database, user, query + ) + wait_for_redshift_statement(redshift_data_client, statement) + return statement["Id"] + + +def get_redshift_statement_result(redshift_data_client, statement_id: str) -> dict: + """ Get the Redshift statement result """ + return redshift_data_client.get_statement_result(Id=statement_id) + + +def upload_df_to_redshift( + redshift_data_client, + cluster_id: str, + database: str, + user: str, + s3_resource, + s3_path: str, + iam_role: str, + table_name: str, + df: pd.DataFrame, +) -> None: + """Uploads a Pandas DataFrame to Redshift as a new table. + + The caller is responsible for deleting the table when no longer necessary. + + Here's how the upload process works: + 1. Pandas DataFrame is converted to PyArrow Table + 2. PyArrow Table is serialized into a Parquet format on local disk + 3. The Parquet file is uploaded to S3 + 4. The S3 file is uploaded to Redshift as a new table through COPY command + 5. The local disk & s3 paths are cleaned up + + Args: + redshift_data_client: Redshift Data API Service client + cluster_id: Redshift Cluster Identifier + database: Redshift Database Name + user: Redshift username + s3_resource: S3 Resource object + s3_path: S3 path where the Parquet file is temporarily uploaded + iam_role: IAM Role for Redshift to assume during the COPY command. + The role must grant permission to read the S3 location. + table_name: The name of the new Redshift table where we copy the dataframe + df: The Pandas DataFrame to upload + + Returns: None + + """ + bucket, key = get_bucket_and_key(s3_path) + + # Convert Pandas DataFrame into PyArrow table and compile the Redshift table schema + table = pa.Table.from_pandas(df) + column_names, column_types = [], [] + for field in table.schema: + column_names.append(field.name) + column_types.append(pa_to_redshift_value_type(str(field.type))) + column_query_list = ", ".join( + [ + f"{column_name} {column_type}" + for column_name, column_type in zip(column_names, column_types) + ] + ) + + # Write the PyArrow Table on disk in Parquet format and upload it to S3 + with tempfile.TemporaryDirectory() as temp_dir: + file_path = f"{temp_dir}/{uuid.uuid4()}.parquet" + pq.write_table(table, file_path) + s3_resource.Object(bucket, key).put(Body=open(file_path, "rb")) + + # Create the table with the desired schema and + # copy the Parquet file contents to the Redshift table + create_and_copy_query = ( + f"CREATE TABLE {table_name}({column_query_list}); " + + f"COPY {table_name} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET" + ) + execute_redshift_statement( + redshift_data_client, cluster_id, database, user, create_and_copy_query + ) + + # Clean up S3 temporary data + s3_resource.Object(bucket, key).delete() + + +def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str): + """ Download the S3 directory to a local disk """ + bucket_obj = s3_resource.Bucket(bucket) + if key != "" and not key.endswith("/"): + key = key + "/" + for obj in bucket_obj.objects.filter(Prefix=key): + local_file_path = local_dir + "/" + obj.key[len(key) :] + local_file_dir = os.path.dirname(local_file_path) + os.makedirs(local_file_dir, exist_ok=True) + bucket_obj.download_file(obj.key, 
local_file_path) + + +def delete_s3_directory(s3_resource, bucket: str, key: str): + """ Delete S3 directory recursively """ + bucket_obj = s3_resource.Bucket(bucket) + if key != "" and not key.endswith("/"): + key = key + "/" + for obj in bucket_obj.objects.filter(Prefix=key): + obj.delete() + + +def execute_redshift_query_and_unload_to_s3( + redshift_data_client, + cluster_id: str, + database: str, + user: str, + s3_path: str, + iam_role: str, + query: str, +) -> None: + """ Unload Redshift Query results to S3 """ + # Run the query, unload the results to S3 + unique_table_name = "_" + str(uuid.uuid4()).replace("-", "") + unload_query = f""" + CREATE TEMPORARY TABLE {unique_table_name} AS ({query}); + UNLOAD ('SELECT * FROM {unique_table_name}') TO '{s3_path}/' IAM_ROLE '{iam_role}' PARQUET + """ + execute_redshift_statement( + redshift_data_client, cluster_id, database, user, unload_query + ) + + +def unload_redshift_query_to_pa( + redshift_data_client, + cluster_id: str, + database: str, + user: str, + s3_resource, + s3_path: str, + iam_role: str, + query: str, +) -> pa.Table: + """ Unload Redshift Query results to S3 and get the results in PyArrow Table format """ + bucket, key = get_bucket_and_key(s3_path) + + execute_redshift_query_and_unload_to_s3( + redshift_data_client, cluster_id, database, user, s3_path, iam_role, query + ) + + with tempfile.TemporaryDirectory() as temp_dir: + download_s3_directory(s3_resource, bucket, key, temp_dir) + delete_s3_directory(s3_resource, bucket, key) + return pq.read_table(temp_dir) + + +def unload_redshift_query_to_df( + redshift_data_client, + cluster_id: str, + database: str, + user: str, + s3_resource, + s3_path: str, + iam_role: str, + query: str, +) -> pd.DataFrame: + """ Unload Redshift Query results to S3 and get the results in Pandas DataFrame format """ + table = unload_redshift_query_to_pa( + redshift_data_client, + cluster_id, + database, + user, + s3_resource, + s3_path, + iam_role, + query, + ) + return table.to_pandas() diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 3514a0e7f5..53ab1183c2 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -13,15 +13,11 @@ # limitations under the License. 
import re -from datetime import datetime, timezone -from typing import Any, Dict, List, Union +from typing import Any, Dict, Union import numpy as np import pandas as pd -import pyarrow as pa from google.protobuf.json_format import MessageToDict -from google.protobuf.timestamp_pb2 import Timestamp -from pyarrow.lib import TimestampType from feast.protos.feast.types.Value_pb2 import ( BoolList, @@ -33,7 +29,6 @@ StringList, ) from feast.protos.feast.types.Value_pb2 import Value as ProtoValue -from feast.protos.feast.types.Value_pb2 import ValueType as ProtoValueType from feast.value_type import ValueType @@ -161,32 +156,6 @@ def python_type_to_feast_value_type( return type_map[value.dtype.__str__()] -def _pd_datetime_to_timestamp_proto(dtype, value) -> Timestamp: - """ - Converts a Pandas datetime to a Timestamp Proto - - Args: - dtype: Pandas datatype - value: Value of datetime - - Returns: - Timestamp protobuf value - """ - - if type(value) in [np.float64, np.float32, np.int32, np.int64]: - return Timestamp(seconds=int(value)) - if dtype.__str__() == "datetime64[ns]": - # If timestamp does not contain timezone, we assume it is of local - # timezone and adjust it to UTC - local_timezone = datetime.now(timezone.utc).astimezone().tzinfo - value = value.tz_localize(local_timezone).tz_convert("UTC").tz_localize(None) - return Timestamp(seconds=int(value.timestamp())) - if dtype.__str__() == "datetime64[ns, UTC]": - return Timestamp(seconds=int(value.timestamp())) - else: - return Timestamp(seconds=np.datetime64(value).astype("int64") // 1000000) # type: ignore - - def _type_err(item, dtype): raise ValueError(f'Value "{item}" is of type {type(item)} not of type {dtype}') @@ -369,81 +338,8 @@ def _proto_str_to_value_type(proto_str: str) -> ValueType: return type_map[proto_str] -def pa_to_feast_value_attr(pa_type: object): - """ - Returns the equivalent Feast ValueType string for the given pa.lib type. - - Args: - pa_type (object): - PyArrow type. - - Returns: - str: - Feast attribute name in Feast ValueType string-ed representation. - """ - # Mapping of PyArrow type to attribute name in Feast ValueType strings - type_map = { - "timestamp[ms]": "int64_val", - "int32": "int32_val", - "int64": "int64_val", - "double": "double_val", - "float": "float_val", - "string": "string_val", - "binary": "bytes_val", - "bool": "bool_val", - "list": "int32_list_val", - "list": "int64_list_val", - "list": "double_list_val", - "list": "float_list_val", - "list": "string_list_val", - "list": "bytes_list_val", - "list": "bool_list_val", - } - - return type_map[pa_type.__str__()] - - -def pa_to_value_type(pa_type: object): - """ - Returns the equivalent Feast ValueType for the given pa.lib type. - - Args: - pa_type (object): - PyArrow type. - - Returns: - feast.types.Value_pb2.ValueType: - Feast ValueType. 
- - """ - - # Mapping of PyArrow to attribute name in Feast ValueType - type_map = { - "timestamp[ms]": ProtoValueType.INT64, - "int32": ProtoValueType.INT32, - "int64": ProtoValueType.INT64, - "double": ProtoValueType.DOUBLE, - "float": ProtoValueType.FLOAT, - "string": ProtoValueType.STRING, - "binary": ProtoValueType.BYTES, - "bool": ProtoValueType.BOOL, - "list": ProtoValueType.INT32_LIST, - "list": ProtoValueType.INT64_LIST, - "list": ProtoValueType.DOUBLE_LIST, - "list": ProtoValueType.FLOAT_LIST, - "list": ProtoValueType.STRING_LIST, - "list": ProtoValueType.BYTES_LIST, - "list": ProtoValueType.BOOL_LIST, - } - return type_map[pa_type.__str__()] - - -def pa_to_feast_value_type(value: Union[pa.lib.ChunkedArray, str]) -> ValueType: - value_type = ( - value.type.__str__() if isinstance(value, pa.lib.ChunkedArray) else value - ) - - if re.match(r"^timestamp", value_type): +def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType: + if re.match(r"^timestamp", pa_type_as_str): return ValueType.INT64 type_map = { @@ -462,51 +358,7 @@ def pa_to_feast_value_type(value: Union[pa.lib.ChunkedArray, str]) -> ValueType: "list": ValueType.BYTES_LIST, "list": ValueType.BOOL_LIST, } - return type_map[value_type] - - -def pa_column_to_timestamp_proto_column(column: pa.lib.ChunkedArray) -> List[Timestamp]: - if not isinstance(column.type, TimestampType): - raise Exception("Only TimestampType columns are allowed") - - proto_column = [] - for val in column: - timestamp = Timestamp() - timestamp.FromMicroseconds(micros=int(val.as_py().timestamp() * 1_000_000)) - proto_column.append(timestamp) - return proto_column - - -def pa_column_to_proto_column( - feast_value_type: ValueType, column: pa.lib.ChunkedArray -) -> List[ProtoValue]: - type_map: Dict[ValueType, Union[str, Dict[str, Any]]] = { - ValueType.INT32: "int32_val", - ValueType.INT64: "int64_val", - ValueType.FLOAT: "float_val", - ValueType.DOUBLE: "double_val", - ValueType.STRING: "string_val", - ValueType.BYTES: "bytes_val", - ValueType.BOOL: "bool_val", - ValueType.BOOL_LIST: {"bool_list_val": BoolList}, - ValueType.BYTES_LIST: {"bytes_list_val": BytesList}, - ValueType.STRING_LIST: {"string_list_val": StringList}, - ValueType.FLOAT_LIST: {"float_list_val": FloatList}, - ValueType.DOUBLE_LIST: {"double_list_val": DoubleList}, - ValueType.INT32_LIST: {"int32_list_val": Int32List}, - ValueType.INT64_LIST: {"int64_list_val": Int64List}, - } - - value: Union[str, Dict[str, Any]] = type_map[feast_value_type] - # Process list types - if isinstance(value, dict): - list_param_name = list(value.keys())[0] - return [ - ProtoValue(**{list_param_name: value[list_param_name](val=x.as_py())}) - for x in column - ] - else: - return [ProtoValue(**{value: x.as_py()}) for x in column] + return type_map[pa_type_as_str] def bq_to_feast_value_type(bq_type_as_str): @@ -530,25 +382,61 @@ def bq_to_feast_value_type(bq_type_as_str): return type_map[bq_type_as_str] -def redshift_to_feast_value_type(redshift_type_as_str): +def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType: # Type names from https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html - type_map: Dict[ValueType, Union[str, Dict[str, Any]]] = { - "INT": ValueType.INT32, - "INT4": ValueType.INT32, - "INT8": ValueType.INT64, - "FLOAT4": ValueType.FLOAT, - "FLOAT8": ValueType.DOUBLE, - "FLOAT": ValueType.DOUBLE, - "NUMERIC": ValueType.DOUBLE, - "BOOL": ValueType.BOOL, - "CHARACTER": ValueType.STRING, - "NCHAR": ValueType.STRING, - "BPCHAR": ValueType.STRING, - "CHARACTER 
VARYING": ValueType.STRING, - "NVARCHAR": ValueType.STRING, - "TEXT": ValueType.STRING, - "TIMESTAMP WITHOUT TIME ZONE": ValueType.UNIX_TIMESTAMP, - "TIMESTAMP WITH TIME ZONE": ValueType.UNIX_TIMESTAMP, + type_map = { + "int2": ValueType.INT32, + "int4": ValueType.INT32, + "int8": ValueType.INT64, + "numeric": ValueType.DOUBLE, + "float4": ValueType.FLOAT, + "float8": ValueType.DOUBLE, + "bool": ValueType.BOOL, + "character": ValueType.STRING, + "varchar": ValueType.STRING, + "timestamp": ValueType.UNIX_TIMESTAMP, + "timestamptz": ValueType.UNIX_TIMESTAMP, + # skip date, geometry, hllsketch, time, timetz + } + + return type_map[redshift_type_as_str.lower()] + + +def pa_to_redshift_value_type(pa_type_as_str: str) -> str: + # PyArrow types: https://arrow.apache.org/docs/python/api/datatypes.html + # Redshift type: https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html + pa_type_as_str = pa_type_as_str.lower() + if pa_type_as_str.startswith("timestamp"): + if "tz=" in pa_type_as_str: + return "timestamptz" + else: + return "timestamp" + + if pa_type_as_str.startswith("date"): + return "date" + + if pa_type_as_str.startswith("decimal"): + # PyArrow decimal types (e.g. "decimal(38,37)") luckily directly map to the Redshift type. + return pa_type_as_str + + # We have to take into account how arrow types map to parquet types as well. + # For example, null type maps to int32 in parquet, so we have to use int4 in Redshift. + # Other mappings have also been adjusted accordingly. + type_map = { + "null": "int4", + "bool": "bool", + "int8": "int4", + "int16": "int4", + "int32": "int4", + "int64": "int8", + "uint8": "int4", + "uint16": "int4", + "uint32": "int8", + "uint64": "int8", + "float": "float4", + "double": "float8", + "binary": "varchar", + "string": "varchar", } - return type_map[redshift_type_as_str.upper()] + return type_map[pa_type_as_str] diff --git a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py index d732119ead..4b6dec828c 100644 --- a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py +++ b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: tensorflow_metadata/proto/v0/path.proto - +"""Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection diff --git a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py index 78fda8003d..d3bfc50616 100644 --- a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py +++ b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: tensorflow_metadata/proto/v0/schema.proto - +"""Generated protocol buffer code.""" from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message diff --git a/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py index d8e12bd120..21473adc75 100644 --- a/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py +++ b/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: tensorflow_metadata/proto/v0/statistics.proto - +"""Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection diff --git a/sdk/python/tests/test_entity.py b/sdk/python/tests/test_entity.py index 334ab6e5d4..b8381451fd 100644 --- a/sdk/python/tests/test_entity.py +++ b/sdk/python/tests/test_entity.py @@ -12,61 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import socket -from concurrent import futures -from contextlib import closing - -import grpc -import pytest - -from feast.client import Client from feast.entity import Entity -from feast.protos.feast.core import CoreService_pb2_grpc as Core from feast.value_type import ValueType -from feast_core_server import CoreServicer - - -def find_free_port(): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -free_port = find_free_port() - - -class TestEntity: - @pytest.fixture(scope="function") - def server(self): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - Core.add_CoreServiceServicer_to_server(CoreServicer(), server) - server.add_insecure_port(f"[::]:{free_port}") - server.start() - yield server - server.stop(0) - - @pytest.fixture - def client(self, server): - return Client(core_url=f"localhost:{free_port}") - - def test_entity_import_export_yaml(self): - - test_entity = Entity( - name="car_driver_entity", - description="Driver entity for car rides", - value_type=ValueType.STRING, - labels={"team": "matchmaking"}, - ) - - # Create a string YAML representation of the entity - string_yaml = test_entity.to_yaml() - - # Create a new entity object from the YAML string - actual_entity_from_string = Entity.from_yaml(string_yaml) - - # Ensure equality is upheld to original entity - assert test_entity == actual_entity_from_string def test_join_key_default(): diff --git a/sdk/python/tests/test_offline_online_store_consistency.py b/sdk/python/tests/test_offline_online_store_consistency.py index 3e1cb1a9c7..803b7dc261 100644 --- a/sdk/python/tests/test_offline_online_store_consistency.py +++ b/sdk/python/tests/test_offline_online_store_consistency.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Iterator, Optional, Tuple, Union +from typing import Iterator, Optional, Tuple import pandas as pd import pytest @@ -13,16 +13,18 @@ from pytz import timezone, utc from feast.data_format import ParquetFormat -from feast.data_source import BigQuerySource, FileSource +from feast.data_source import BigQuerySource, DataSource, FileSource, RedshiftSource from feast.entity import Entity from feast.feature import Feature from feast.feature_store import FeatureStore from feast.feature_view import FeatureView from feast.infra.offline_stores.file import FileOfflineStoreConfig +from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig from feast.infra.online_stores.datastore import DatastoreOnlineStoreConfig from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig from feast.infra.online_stores.redis import RedisOnlineStoreConfig, RedisType from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.infra.utils import aws_utils from feast.repo_config import RepoConfig from feast.value_type import 
ValueType @@ -50,7 +52,7 @@ def create_dataset() -> pd.DataFrame: return pd.DataFrame.from_dict(data) -def get_feature_view(data_source: Union[FileSource, BigQuerySource]) -> FeatureView: +def get_feature_view(data_source: DataSource) -> FeatureView: return FeatureView( name="test_bq_correctness", entities=["driver"], @@ -112,6 +114,79 @@ def prep_bq_fs_and_fv( yield fs, fv +@contextlib.contextmanager +def prep_redshift_fs_and_fv( + source_type: str, +) -> Iterator[Tuple[FeatureStore, FeatureView]]: + client = aws_utils.get_redshift_data_client("us-west-2") + s3 = aws_utils.get_s3_resource("us-west-2") + + df = create_dataset() + + table_name = f"test_ingestion_{source_type}_correctness_{int(time.time())}" + + offline_store = RedshiftOfflineStoreConfig( + cluster_id="feast-integration-tests", + region="us-west-2", + user="admin", + database="feast", + s3_staging_location="s3://feast-integration-tests/redshift/tests/ingestion", + iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role", + ) + + aws_utils.upload_df_to_redshift( + client, + offline_store.cluster_id, + offline_store.database, + offline_store.user, + s3, + f"{offline_store.s3_staging_location}/copy/{table_name}.parquet", + offline_store.iam_role, + table_name, + df, + ) + + redshift_source = RedshiftSource( + table=table_name if source_type == "table" else None, + query=f"SELECT * FROM {table_name}" if source_type == "query" else None, + event_timestamp_column="ts", + created_timestamp_column="created_ts", + date_partition_column="", + field_mapping={"ts_1": "ts", "id": "driver_id"}, + ) + + fv = get_feature_view(redshift_source) + e = Entity( + name="driver", + description="id for driver", + join_key="driver_id", + value_type=ValueType.INT32, + ) + with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name: + config = RepoConfig( + registry=str(Path(repo_dir_name) / "registry.db"), + project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}", + provider="local", + online_store=SqliteOnlineStoreConfig( + path=str(Path(data_dir_name) / "online_store.db") + ), + offline_store=offline_store, + ) + fs = FeatureStore(config=config) + fs.apply([fv, e]) + + yield fs, fv + + # Clean up the uploaded Redshift table + aws_utils.execute_redshift_statement( + client, + offline_store.cluster_id, + offline_store.database, + offline_store.user, + f"DROP TABLE {table_name}", + ) + + @contextlib.contextmanager def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: with tempfile.NamedTemporaryFile(suffix=".parquet") as f: @@ -229,6 +304,7 @@ def check_offline_and_online_features( event_timestamp: datetime, expected_value: Optional[float], full_feature_names: bool, + check_offline_store: bool = True, ) -> None: # Check online store response_dict = fs.get_online_features( @@ -249,28 +325,32 @@ def check_offline_and_online_features( assert response_dict["value"][0] is None # Check offline store - df = fs.get_historical_features( - entity_df=pd.DataFrame.from_dict( - {"driver_id": [driver_id], "event_timestamp": [event_timestamp]} - ), - feature_refs=[f"{fv.name}:value"], - full_feature_names=full_feature_names, - ).to_df() - - if full_feature_names: - if expected_value: - assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6 + if check_offline_store: + df = fs.get_historical_features( + entity_df=pd.DataFrame.from_dict( + {"driver_id": [driver_id], "event_timestamp": [event_timestamp]} + ), + feature_refs=[f"{fv.name}:value"], + 
full_feature_names=full_feature_names, + ).to_df() + + if full_feature_names: + if expected_value: + assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6 + else: + assert math.isnan(df.to_dict()[f"{fv.name}__value"][0]) else: - assert math.isnan(df.to_dict()[f"{fv.name}__value"][0]) - else: - if expected_value: - assert abs(df.to_dict()["value"][0] - expected_value) < 1e-6 - else: - assert math.isnan(df.to_dict()["value"][0]) + if expected_value: + assert abs(df.to_dict()["value"][0] - expected_value) < 1e-6 + else: + assert math.isnan(df.to_dict()["value"][0]) def run_offline_online_store_consistency_test( - fs: FeatureStore, fv: FeatureView, full_feature_names: bool + fs: FeatureStore, + fv: FeatureView, + full_feature_names: bool, + check_offline_store: bool = True, ) -> None: now = datetime.utcnow() # Run materialize() @@ -287,6 +367,7 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=0.3, full_feature_names=full_feature_names, + check_offline_store=check_offline_store, ) check_offline_and_online_features( @@ -296,6 +377,7 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=None, full_feature_names=full_feature_names, + check_offline_store=check_offline_store, ) # check prior value for materialize_incremental() @@ -306,6 +388,7 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=4, full_feature_names=full_feature_names, + check_offline_store=check_offline_store, ) # run materialize_incremental() @@ -319,6 +402,7 @@ def run_offline_online_store_consistency_test( event_timestamp=now, expected_value=5, full_feature_names=full_feature_names, + check_offline_store=check_offline_store, ) @@ -348,6 +432,19 @@ def test_dynamodb_offline_online_store_consistency(full_feature_names: bool): run_offline_online_store_consistency_test(fs, fv, full_feature_names) +@pytest.mark.integration +@pytest.mark.parametrize( + "source_type", ["query", "table"], +) +@pytest.mark.parametrize("full_feature_names", [True, False]) +def test_redshift_offline_online_store_consistency( + source_type: str, full_feature_names: bool +): + with prep_redshift_fs_and_fv(source_type) as (fs, fv): + # TODO: remove check_offline_store parameter once Redshift's get_historical_features is implemented + run_offline_online_store_consistency_test(fs, fv, full_feature_names, False) + + @pytest.mark.parametrize("full_feature_names", [True, False]) def test_local_offline_online_store_consistency(full_feature_names: bool): with prep_local_fs_and_fv() as (fs, fv):
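
Reviewer note: below is a minimal usage sketch (not part of this patch) showing how the new `feast.infra.utils.aws_utils` helpers introduced above fit together, assuming a reachable Redshift cluster and the Data API enabled. The cluster id, database, user, IAM role, S3 paths, and table name are placeholders, not values used by Feast itself; only the helper signatures come from this change.

    import pandas as pd

    from feast.infra.utils import aws_utils

    # Placeholder connection details -- replace with a real cluster, role, and bucket.
    CLUSTER_ID = "my-redshift-cluster"  # hypothetical
    DATABASE = "dev"                    # hypothetical
    USER = "admin"                      # hypothetical
    IAM_ROLE = "arn:aws:iam::123456789012:role/redshift_s3_access_role"  # hypothetical
    S3_STAGING = "s3://my-bucket/feast-staging"                          # hypothetical

    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    # Upload a small DataFrame as a new Redshift table (Parquet on disk -> S3 -> COPY).
    df = pd.DataFrame({"driver_id": [1001, 1002], "value": [0.3, 0.4]})
    aws_utils.upload_df_to_redshift(
        client, CLUSTER_ID, DATABASE, USER, s3,
        f"{S3_STAGING}/copy/example.parquet", IAM_ROLE, "example_table", df,
    )

    # Read the table back into Pandas via UNLOAD to S3 and a local PyArrow read.
    result = aws_utils.unload_redshift_query_to_df(
        client, CLUSTER_ID, DATABASE, USER, s3,
        f"{S3_STAGING}/unload/example", IAM_ROLE, "SELECT * FROM example_table",
    )
    print(result)

    # The helpers leave table cleanup to the caller, so drop the staging table when done.
    aws_utils.execute_redshift_statement(
        client, CLUSTER_ID, DATABASE, USER, "DROP TABLE example_table"
    )

As the patch documents, the unload path stages Parquet in S3 and reads it back with PyArrow rather than paging rows through `get_statement_result`; the staged S3 objects are deleted by `unload_redshift_query_to_pa`, while the uploaded table itself remains the caller's responsibility.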