Skip to content

Commit

Permalink
fix: Enforce kw args in datasources (#2567)
Browse files (browse the repository at this point in the history)
* Update

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Update to keyword args

Signed-off-by: Kevin Zhang <[email protected]>

* Fix lint

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Change kinesis to optional

Signed-off-by: Kevin Zhang <[email protected]>

* Fix review issues

Signed-off-by: Kevin Zhang <[email protected]>

* Fix lint

Signed-off-by: Kevin Zhang <[email protected]>

* Add unit tests

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>

* Fix imports

Signed-off-by: Kevin Zhang <[email protected]>

* Fix lint

Signed-off-by: Kevin Zhang <[email protected]>

* Fix

Signed-off-by: Kevin Zhang <[email protected]>
  • Loading branch information
kevjumba authored Apr 19, 2022
1 parent ebb7dfe commit 0b7ec53
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 42 deletions.
203 changes: 168 additions & 35 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class DataSource(ABC):

def __init__(
self,
*,
event_timestamp_column: Optional[str] = None,
created_timestamp_column: Optional[str] = None,
field_mapping: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -354,11 +355,12 @@ def get_table_column_names_and_types(

def __init__(
self,
name: str,
event_timestamp_column: str,
bootstrap_servers: str,
message_format: StreamFormat,
topic: str,
*args,
name: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
bootstrap_servers: Optional[str] = None,
message_format: Optional[StreamFormat] = None,
topic: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
Expand All @@ -368,22 +370,62 @@ def __init__(
timestamp_field: Optional[str] = "",
batch_source: Optional[DataSource] = None,
):
positional_attributes = [
"name",
"event_timestamp_column",
"bootstrap_servers",
"message_format",
"topic",
]
_name = name
_event_timestamp_column = event_timestamp_column
_bootstrap_servers = bootstrap_servers or ""
_message_format = message_format
_topic = topic or ""

if args:
warnings.warn(
(
"Kafka parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct Kafka sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"Kafka sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_event_timestamp_column = args[1]
if len(args) >= 3:
_bootstrap_servers = args[2]
if len(args) >= 4:
_message_format = args[3]
if len(args) >= 5:
_topic = args[4]

if _message_format is None:
raise ValueError("Message format must be specified for Kafka source")
print("Asdfasdf")
super().__init__(
event_timestamp_column=event_timestamp_column,
event_timestamp_column=_event_timestamp_column,
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping,
date_partition_column=date_partition_column,
description=description,
tags=tags,
owner=owner,
name=name,
name=_name,
timestamp_field=timestamp_field,
)
self.batch_source = batch_source
self.kafka_options = KafkaOptions(
bootstrap_servers=bootstrap_servers,
message_format=message_format,
topic=topic,
bootstrap_servers=_bootstrap_servers,
message_format=_message_format,
topic=_topic,
)

def __eq__(self, other):
Expand Down Expand Up @@ -472,32 +514,56 @@ class RequestSource(DataSource):

def __init__(
self,
name: str,
schema: Union[Dict[str, ValueType], List[Field]],
*args,
name: Optional[str] = None,
schema: Optional[Union[Dict[str, ValueType], List[Field]]] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
):
"""Creates a RequestSource object."""
super().__init__(name=name, description=description, tags=tags, owner=owner)
if isinstance(schema, Dict):
positional_attributes = ["name", "schema"]
_name = name
_schema = schema
if args:
warnings.warn(
(
"Request source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct request sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"feature views, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_schema = args[1]

super().__init__(name=_name, description=description, tags=tags, owner=owner)
if not _schema:
raise ValueError("Schema needs to be provided for Request Source")
if isinstance(_schema, Dict):
warnings.warn(
"Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.23. "
"Please use List[Field] instead for the schema",
DeprecationWarning,
)
schemaList = []
for key, valueType in schema.items():
for key, valueType in _schema.items():
schemaList.append(
Field(name=key, dtype=VALUE_TYPES_TO_FEAST_TYPES[valueType])
)
self.schema = schemaList
elif isinstance(schema, List):
self.schema = schema
elif isinstance(_schema, List):
self.schema = _schema
else:
raise Exception(
"Schema type must be either dictionary or list, not "
+ str(type(schema))
+ str(type(_schema))
)

def validate(self, config: RepoConfig):
Expand Down Expand Up @@ -643,12 +709,13 @@ def get_table_query_string(self) -> str:

def __init__(
self,
name: str,
event_timestamp_column: str,
created_timestamp_column: str,
record_format: StreamFormat,
region: str,
stream_name: str,
*args,
name: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
created_timestamp_column: Optional[str] = "",
record_format: Optional[StreamFormat] = None,
region: Optional[str] = "",
stream_name: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
description: Optional[str] = "",
Expand All @@ -657,10 +724,53 @@ def __init__(
timestamp_field: Optional[str] = "",
batch_source: Optional[DataSource] = None,
):
positional_attributes = [
"name",
"event_timestamp_column",
"created_timestamp_column",
"record_format",
"region",
"stream_name",
]
_name = name
_event_timestamp_column = event_timestamp_column
_created_timestamp_column = created_timestamp_column
_record_format = record_format
_region = region or ""
_stream_name = stream_name or ""
if args:
warnings.warn(
(
"Kinesis parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct kinesis sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"kinesis sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_event_timestamp_column = args[1]
if len(args) >= 3:
_created_timestamp_column = args[2]
if len(args) >= 4:
_record_format = args[3]
if len(args) >= 5:
_region = args[4]
if len(args) >= 6:
_stream_name = args[5]

if _record_format is None:
raise ValueError("Record format must be specified for kinesis source")

super().__init__(
name=name,
event_timestamp_column=event_timestamp_column,
created_timestamp_column=created_timestamp_column,
name=_name,
event_timestamp_column=_event_timestamp_column,
created_timestamp_column=_created_timestamp_column,
field_mapping=field_mapping,
date_partition_column=date_partition_column,
description=description,
Expand All @@ -670,7 +780,7 @@ def __init__(
)
self.batch_source = batch_source
self.kinesis_options = KinesisOptions(
record_format=record_format, region=region, stream_name=stream_name
record_format=_record_format, region=_region, stream_name=_stream_name
)

def __eq__(self, other):
Expand Down Expand Up @@ -725,9 +835,9 @@ class PushSource(DataSource):

def __init__(
self,
*,
name: str,
batch_source: DataSource,
*args,
name: Optional[str] = None,
batch_source: Optional[DataSource] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
Expand All @@ -744,10 +854,33 @@ def __init__(
maintainer.
"""
super().__init__(name=name, description=description, tags=tags, owner=owner)
self.batch_source = batch_source
if not self.batch_source:
raise ValueError(f"batch_source is needed for push source {self.name}")
positional_attributes = ["name", "batch_source"]
_name = name
_batch_source = batch_source
if args:
warnings.warn(
(
"Push source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct push sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"push sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_batch_source = args[1]

super().__init__(name=_name, description=description, tags=tags, owner=owner)
if not _batch_source:
raise ValueError(
f"batch_source parameter is needed for push source {self.name}"
)
self.batch_source = _batch_source

def __eq__(self, other):
if not isinstance(other, PushSource):
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
ValueError: A field mapping conflicts with an Entity or a Feature.
"""

positional_attributes = ["name, entities, ttl"]
positional_attributes = ["name", "entities", "ttl"]

_name = name
_entities = entities
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class BigQuerySource(DataSource):
def __init__(
self,
*,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SparkSourceFormat(Enum):
class SparkSource(DataSource):
def __init__(
self,
*,
name: Optional[str] = None,
table: Optional[str] = None,
query: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def __init__(
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = None,
query: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = "",
Expand Down
25 changes: 22 additions & 3 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
class FileSource(DataSource):
def __init__(
self,
path: str,
*args,
path: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
file_format: Optional[FileFormat] = None,
created_timestamp_column: Optional[str] = "",
Expand Down Expand Up @@ -58,13 +59,31 @@ def __init__(
>>> from feast import FileSource
>>> file_source = FileSource(path="my_features.parquet", timestamp_field="event_timestamp")
"""
if path is None:
positional_attributes = ["path"]
_path = path
if args:
if args:
warnings.warn(
(
"File Source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct File sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"File sources, for backwards compatibility."
)
if len(args) >= 1:
_path = args[0]
if _path is None:
raise ValueError(
'No "path" argument provided. Please set "path" to the location of your file source.'
)
self.file_options = FileOptions(
file_format=file_format,
uri=path,
uri=_path,
s3_endpoint_override=s3_endpoint_override,
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class RedshiftSource(DataSource):
def __init__(
self,
*,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
schema: Optional[str] = None,
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class SnowflakeSource(DataSource):
def __init__(
self,
*,
database: Optional[str] = None,
warehouse: Optional[str] = None,
schema: Optional[str] = None,
Expand Down
Loading

0 comments on commit 0b7ec53

Please sign in to comment.