Skip to content

Commit

Permalink
fix(python): Support duplicate expression names when calling ufuncs (#…
Browse files Browse the repository at this point in the history
…17641)

Co-authored-by: Itamar Turner-Trauring <[email protected]>
  • Loading branch information
itamarst and pythonspeed authored Jul 16, 2024
1 parent 3897a37 commit 4b7e2bd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
29 changes: 20 additions & 9 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,33 @@ def __array_ufunc__(
is_custom_ufunc = getattr(ufunc, "signature") is not None # noqa: B009
num_expr = sum(isinstance(inp, Expr) for inp in inputs)
exprs = [
(inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i)
(inp, True, i) if isinstance(inp, Expr) else (inp, False, i)
for i, inp in enumerate(inputs)
]

if num_expr == 1:
root_expr = next(expr[0] for expr in exprs if expr[1] == Expr)
root_expr = next(expr[0] for expr in exprs if expr[1])
else:
root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr)
# We rename all but the first expression in case someone did e.g.
# np.divide(pl.col("a"), pl.col("a")); we'll be creating a struct
# below, and structs can't have duplicate names.
first_renamable_expr = True
actual_exprs = []
for inp, is_actual_expr, index in exprs:
if is_actual_expr:
if first_renamable_expr:
first_renamable_expr = False
else:
inp = inp.alias(f"argument_{index}")
actual_exprs.append(inp)
root_expr = F.struct(actual_exprs)

def function(s: Series) -> Series: # pragma: no cover
args = []
for i, expr in enumerate(exprs):
if expr[1] == Expr and num_expr > 1:
if expr[1] and num_expr > 1:
args.append(s.struct[i])
elif expr[1] == Expr:
elif expr[1]:
args.append(s)
else:
args.append(expr[0])
Expand All @@ -323,10 +336,8 @@ def function(s: Series) -> Series: # pragma: no cover
CustomUFuncWarning,
stacklevel=find_stacklevel(),
)
return root_expr.map_batches(
function, is_elementwise=False
).meta.undo_aliases()
return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases()
return root_expr.map_batches(function, is_elementwise=False)
return root_expr.map_batches(function, is_elementwise=True)

@classmethod
def deserialize(
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def test_ufunc_multiple_expressions() -> None:
assert_series_equal(expected, result) # type: ignore[arg-type]


def test_repeated_name_ufunc_17472() -> None:
"""If a ufunc takes multiple inputs has a repeating name, this works."""
df = pl.DataFrame({"a": [6.0]})
result = df.select(np.divide(pl.col("a"), pl.col("a"))) # type: ignore[call-overload]
expected = pl.DataFrame({"a": [1.0]})
assert_frame_equal(expected, result)


def test_grouped_ufunc() -> None:
df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]})
df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))
Expand Down

0 comments on commit 4b7e2bd

Please sign in to comment.