Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedshiftDataSource #1669

Merged
merged 2 commits into from
Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions protos/feast/core/DataSource.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ message DataSource {
BATCH_BIGQUERY = 2;
STREAM_KAFKA = 3;
STREAM_KINESIS = 4;
BATCH_REDSHIFT = 5;
}
SourceType type = 1;

Expand Down Expand Up @@ -100,11 +101,22 @@ message DataSource {
StreamFormat record_format = 3;
}

// Defines options for DataSource that sources features from a Redshift Query
message RedshiftOptions {
// Redshift table name
string table = 1;

// SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective
// entity columns
string query = 2;
}

// DataSource options.
oneof options {
FileOptions file_options = 11;
BigQueryOptions bigquery_options = 12;
KafkaOptions kafka_options = 13;
KinesisOptions kinesis_options = 14;
RedshiftOptions redshift_options = 15;
}
}
2 changes: 2 additions & 0 deletions sdk/python/feast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FileSource,
KafkaSource,
KinesisSource,
RedshiftSource,
SourceType,
)
from .entity import Entity
Expand Down Expand Up @@ -37,6 +38,7 @@
"FileSource",
"KafkaSource",
"KinesisSource",
"RedshiftSource",
"Feature",
"FeatureStore",
"FeatureTable",
Expand Down
266 changes: 260 additions & 6 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@
from typing import Callable, Dict, Iterable, Optional, Tuple

from pyarrow.parquet import ParquetFile
from tenacity import retry, retry_unless_exception_type, wait_exponential

from feast import type_map
from feast.data_format import FileFormat, StreamFormat
from feast.errors import DataSourceNotFoundException
from feast.errors import (
DataSourceNotFoundException,
RedshiftCredentialsError,
RedshiftQueryError,
)
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.repo_config import RepoConfig
from feast.value_type import ValueType


Expand Down Expand Up @@ -477,6 +483,15 @@ def from_proto(data_source):
date_partition_column=data_source.date_partition_column,
query=data_source.bigquery_options.query,
)
elif data_source.redshift_options.table or data_source.redshift_options.query:
data_source_obj = RedshiftSource(
field_mapping=data_source.field_mapping,
table=data_source.redshift_options.table,
event_timestamp_column=data_source.event_timestamp_column,
created_timestamp_column=data_source.created_timestamp_column,
date_partition_column=data_source.date_partition_column,
query=data_source.redshift_options.query,
)
elif (
data_source.kafka_options.bootstrap_servers
and data_source.kafka_options.topic
Expand Down Expand Up @@ -520,12 +535,27 @@ def to_proto(self) -> DataSourceProto:
"""
raise NotImplementedError

def validate(self):
def validate(self, config: RepoConfig):
"""
Validates the underlying data source.
"""
raise NotImplementedError

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
"""
Get the callable method that returns Feast type given the raw column type
"""
raise NotImplementedError

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
"""
Get the list of column names and raw column types
"""
raise NotImplementedError


class FileSource(DataSource):
def __init__(
Expand Down Expand Up @@ -622,15 +652,17 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def validate(self):
def validate(self, config: RepoConfig):
# TODO: validate a FileSource
pass

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.pa_to_feast_value_type

def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
schema = ParquetFile(self.path).schema_arrow
return zip(schema.names, map(str, schema.types))

Expand Down Expand Up @@ -703,7 +735,7 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def validate(self):
def validate(self, config: RepoConfig):
if not self.query:
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
Expand All @@ -725,7 +757,9 @@ def get_table_query_string(self) -> str:
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.bq_to_feast_value_type

def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
from google.cloud import bigquery

client = bigquery.Client()
Expand Down Expand Up @@ -875,3 +909,223 @@ def to_proto(self) -> DataSourceProto:
data_source_proto.date_partition_column = self.date_partition_column

return data_source_proto


class RedshiftOptions:
"""
DataSource Redshift options used to source features from Redshift query
"""

def __init__(self, table: Optional[str], query: Optional[str]):
self._table = table
self._query = query

@property
def query(self):
woop marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the Redshift SQL query referenced by this source
"""
return self._query

@query.setter
def query(self, query):
"""
Sets the Redshift SQL query referenced by this source
"""
self._query = query

@property
def table(self):
"""
Returns the table name of this Redshift table
"""
return self._table

@table.setter
def table(self, table_name):
"""
Sets the table ref of this Redshift table
"""
self._table = table_name

@classmethod
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
"""
Creates a RedshiftOptions from a protobuf representation of a Redshift option

Args:
redshift_options_proto: A protobuf representation of a DataSource

Returns:
Returns a RedshiftOptions object based on the redshift_options protobuf
"""

redshift_options = cls(
table=redshift_options_proto.table, query=redshift_options_proto.query,
)

return redshift_options

def to_proto(self) -> DataSourceProto.RedshiftOptions:
"""
Converts an RedshiftOptionsProto object to its protobuf representation.

Returns:
RedshiftOptionsProto protobuf
"""

redshift_options_proto = DataSourceProto.RedshiftOptions(
table=self.table, query=self.query,
)

return redshift_options_proto


class RedshiftSource(DataSource):
def __init__(
self,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
query: Optional[str] = None,
):
super().__init__(
tsotnet marked this conversation as resolved.
Show resolved Hide resolved
event_timestamp_column,
created_timestamp_column,
field_mapping,
date_partition_column,
)

self._redshift_options = RedshiftOptions(table=table, query=query)

def __eq__(self, other):
if not isinstance(other, RedshiftSource):
raise TypeError(
"Comparisons should only involve RedshiftSource class objects."
)

return (
self.redshift_options.table == other.redshift_options.table
and self.redshift_options.query == other.redshift_options.query
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
)

@property
def table(self):
return self._redshift_options.table

@property
def query(self):
return self._redshift_options.query

@property
def redshift_options(self):
"""
Returns the Redshift options of this data source
"""
return self._redshift_options

@redshift_options.setter
def redshift_options(self, _redshift_options):
"""
Sets the Redshift options of this data source
"""
self._redshift_options = _redshift_options

def to_proto(self) -> DataSourceProto:
data_source_proto = DataSourceProto(
type=DataSourceProto.BATCH_REDSHIFT,
field_mapping=self.field_mapping,
redshift_options=self.redshift_options.to_proto(),
)

data_source_proto.event_timestamp_column = self.event_timestamp_column
data_source_proto.created_timestamp_column = self.created_timestamp_column
data_source_proto.date_partition_column = self.date_partition_column

return data_source_proto

def validate(self, config: RepoConfig):
# As long as the query gets successfully executed, or the table exists,
# the data source is validated. We don't need the results though.
# TODO: uncomment this
# self.get_table_column_names_and_types(config)
print("Validate", self.get_table_column_names_and_types(config))

def get_table_query_string(self) -> str:
"""Returns a string that can directly be used to reference this table in SQL"""
if self.table:
return f"`{self.table}`"
else:
return f"({self.query})"

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.redshift_to_feast_value_type

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

client = boto3.client(
"redshift-data", config=Config(region_name=config.offline_store.region)
)

try:
if self.table is not None:
table = client.describe_table(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
DbUser=config.offline_store.user,
Table=self.table,
)
# The API returns valid JSON with empty column list when the table doesn't exist
if len(table["ColumnList"]) == 0:
raise DataSourceNotFoundException(self.table)

columns = table["ColumnList"]
else:
statement = client.execute_statement(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
DbUser=config.offline_store.user,
Sql=f"SELECT * FROM ({self.query}) LIMIT 1",
tsotnet marked this conversation as resolved.
Show resolved Hide resolved
)

# Need to retry client.describe_statement(...) until the task is finished. We don't want to bombard
# Redshift with queries, and neither do we want to wait for a long time on the initial call.
# The solution is exponential backoff. The backoff starts with 0.1 seconds and doubles exponentially
# until reaching 30 seconds, at which point the backoff is fixed.
@retry(
wait=wait_exponential(multiplier=0.1, max=30),
retry=retry_unless_exception_type(RedshiftQueryError),
)
def wait_for_statement():
desc = client.describe_statement(Id=statement["Id"])
if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
raise Exception # Retry
if desc["Status"] != "FINISHED":
raise RedshiftQueryError(desc) # Don't retry. Raise exception.

wait_for_statement()

result = client.get_statement_result(Id=statement["Id"])

columns = result["ColumnMetadata"]
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
raise RedshiftCredentialsError() from e
raise

return [(column["name"], column["typeName"].upper()) for column in columns]
10 changes: 10 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,13 @@ def __init__(self, repo_obj_type: str, specific_issue: str):
f"Inference to fill in missing information for {repo_obj_type} failed. {specific_issue}. "
"Try filling the information explicitly."
)


class RedshiftCredentialsError(Exception):
def __init__(self):
super().__init__("Redshift API failed due to incorrect credentials")


class RedshiftQueryError(Exception):
def __init__(self, details):
super().__init__(f"Redshift SQL Query failed to finish. Details: {details}")
Loading