Skip to content

Commit

Permalink
chore: clean up UDF field construction
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 12, 2023
1 parent 99e531d commit 78714d0
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,26 +293,24 @@ def _make_node(
if (return_annotation := annotations.pop("return", None)) is None:
raise exc.MissingReturnAnnotationError(fn)

fields = {}

func_name = name or fn.__name__

for arg_name, param in inspect.signature(fn).parameters.items():
if (raw_dtype := annotations.get(arg_name)) is not None:
dtype = dt.dtype(raw_dtype)
else:
dtype = raw_dtype
arg = rlz.ValueOf(dtype)
fields[arg_name] = Argument(pattern=arg, default=param.default)

fields["dtype"] = dt.dtype(return_annotation)
fields["__input_type__"] = input_type
# can't be just `fn` otherwise `fn` is assumed to be a method
fields["__func__"] = property(fget=lambda _, fn=fn: fn)
fields["__config__"] = FrozenDict(args=args, kwargs=FrozenDict(**kwargs))
fields["__udf_namespace__"] = schema
fields["__module__"] = fn.__module__
fields["__func_name__"] = func_name
func_name = name if name is not None else fn.__name__

fields = {
arg_name: Argument(
pattern=rlz.ValueOf(annotations.get(arg_name)), default=param.default
)
for arg_name, param in inspect.signature(fn).parameters.items()
} | {
"dtype": dt.dtype(return_annotation),
"__input_type__": input_type,
# must wrap `fn` in a `property` otherwise `fn` is assumed to be a
# method
"__func__": property(fget=lambda _, fn=fn: fn),
"__config__": FrozenDict(args=args, kwargs=FrozenDict(**kwargs)),
"__udf_namespace__": schema,
"__module__": fn.__module__,
"__func_name__": func_name,
}

return type(func_name, (ScalarUDF,), fields)

Expand Down

0 comments on commit 78714d0

Please sign in to comment.