Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Udf in stream feature view UI shows pickled data #3268

Merged
merged 5 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions sdk/python/feast/infra/registry/base_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,9 +638,12 @@ def to_dict(self, project: str) -> Dict[str, List[Any]]:
self.list_stream_feature_views(project=project),
key=lambda stream_feature_view: stream_feature_view.name,
):
registry_dict["streamFeatureViews"].append(
self._message_to_sorted_dict(stream_feature_view.to_proto())
)
sfv_dict = self._message_to_sorted_dict(stream_feature_view.to_proto())

sfv_dict["spec"]["userDefinedFunction"][
"body"
] = stream_feature_view.udf_string
registry_dict["streamFeatureViews"].append(sfv_dict)
for saved_dataset in sorted(
self.list_saved_datasets(project=project), key=lambda item: item.name
):
Expand Down
13 changes: 13 additions & 0 deletions sdk/python/feast/stream_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class StreamFeatureView(FeatureView):
timestamp_field: str
materialization_intervals: List[Tuple[datetime, datetime]]
udf: Optional[FunctionType]
udf_string: Optional[str]

def __init__(
self,
Expand All @@ -88,6 +89,7 @@ def __init__(
mode: Optional[str] = "spark",
timestamp_field: Optional[str] = "",
udf: Optional[FunctionType] = None,
udf_string: Optional[str] = "",
):
if not flags_helper.is_test():
warnings.warn(
Expand All @@ -114,6 +116,7 @@ def __init__(
self.mode = mode or ""
self.timestamp_field = timestamp_field or ""
self.udf = udf
self.udf_string = udf_string

super().__init__(
name=name,
Expand Down Expand Up @@ -143,6 +146,7 @@ def __eq__(self, other):
self.mode != other.mode
or self.timestamp_field != other.timestamp_field
or self.udf.__code__.co_code != other.udf.__code__.co_code
or self.udf_string != other.udf_string
or self.aggregations != other.aggregations
):
return False
Expand Down Expand Up @@ -171,6 +175,7 @@ def to_proto(self):
udf_proto = UserDefinedFunctionProto(
name=self.udf.__name__,
body=dill.dumps(self.udf, recurse=True),
body_text=self.udf_string,
)
spec = StreamFeatureViewSpecProto(
name=self.name,
Expand Down Expand Up @@ -209,6 +214,11 @@ def from_proto(cls, sfv_proto):
if sfv_proto.spec.HasField("user_defined_function")
else None
)
udf_string = (
sfv_proto.spec.user_defined_function.body_text
if sfv_proto.spec.HasField("user_defined_function")
else None
)
stream_feature_view = cls(
name=sfv_proto.spec.name,
description=sfv_proto.spec.description,
Expand All @@ -226,6 +236,7 @@ def from_proto(cls, sfv_proto):
source=stream_source,
mode=sfv_proto.spec.mode,
udf=udf,
udf_string=udf_string,
aggregations=[
Aggregation.from_proto(agg_proto)
for agg_proto in sfv_proto.spec.aggregations
Expand Down Expand Up @@ -315,6 +326,7 @@ def mainify(obj):
obj.__module__ = "__main__"

def decorator(user_function):
udf_string = dill.source.getsource(user_function)
mainify(user_function)
stream_feature_view_obj = StreamFeatureView(
name=user_function.__name__,
Expand All @@ -323,6 +335,7 @@ def decorator(user_function):
source=source,
schema=schema,
udf=user_function,
udf_string=udf_string,
description=description,
tags=tags,
online=online,
Expand Down