From 4b7e2bd87827f941211d8aae4185904a2873689b Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 16 Jul 2024 09:57:54 -0400 Subject: [PATCH] fix(python): Support duplicate expression names when calling ufuncs (#17641) Co-authored-by: Itamar Turner-Trauring --- py-polars/polars/expr/expr.py | 29 +++++++++++++------ .../unit/interop/numpy/test_ufunc_expr.py | 8 +++++ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 9a7a6c1a68e2..410774d1973a 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -292,20 +292,33 @@ def __array_ufunc__( is_custom_ufunc = getattr(ufunc, "signature") is not None # noqa: B009 num_expr = sum(isinstance(inp, Expr) for inp in inputs) exprs = [ - (inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i) + (inp, True, i) if isinstance(inp, Expr) else (inp, False, i) for i, inp in enumerate(inputs) ] + if num_expr == 1: - root_expr = next(expr[0] for expr in exprs if expr[1] == Expr) + root_expr = next(expr[0] for expr in exprs if expr[1]) else: - root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr) + # We rename all but the first expression in case someone did e.g. + # np.divide(pl.col("a"), pl.col("a")); we'll be creating a struct + # below, and structs can't have duplicate names. 
+ first_renamable_expr = True + actual_exprs = [] + for inp, is_actual_expr, index in exprs: + if is_actual_expr: + if first_renamable_expr: + first_renamable_expr = False + else: + inp = inp.alias(f"argument_{index}") + actual_exprs.append(inp) + root_expr = F.struct(actual_exprs) def function(s: Series) -> Series: # pragma: no cover args = [] for i, expr in enumerate(exprs): - if expr[1] == Expr and num_expr > 1: + if expr[1] and num_expr > 1: args.append(s.struct[i]) - elif expr[1] == Expr: + elif expr[1]: args.append(s) else: args.append(expr[0]) @@ -323,10 +336,8 @@ def function(s: Series) -> Series: # pragma: no cover CustomUFuncWarning, stacklevel=find_stacklevel(), ) - return root_expr.map_batches( - function, is_elementwise=False - ).meta.undo_aliases() - return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases() + return root_expr.map_batches(function, is_elementwise=False) + return root_expr.map_batches(function, is_elementwise=True) @classmethod def deserialize( diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py index e3516bb58cc1..bba5c72b6fc7 100644 --- a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py +++ b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py @@ -120,6 +120,14 @@ def test_ufunc_multiple_expressions() -> None: assert_series_equal(expected, result) # type: ignore[arg-type] +def test_repeated_name_ufunc_17472() -> None: + """If a ufunc takes multiple inputs that share a repeated name, this works.""" + df = pl.DataFrame({"a": [6.0]}) + result = df.select(np.divide(pl.col("a"), pl.col("a"))) # type: ignore[call-overload] + expected = pl.DataFrame({"a": [1.0]}) + assert_frame_equal(expected, result) + + def test_grouped_ufunc() -> None: df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]}) df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))