diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 31d7a6ca12c7b..f11b7a6420a04 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1170,7 +1170,16 @@ def execute_count_distinct_star(op, **kw): @translate.register(ops.ScalarUDF) def execute_scalar_udf(op, **kw): if op.__input_type__ == InputType.BUILTIN: - first, *rest = map(translate, op.args) + first, *rest = map(partial(translate, **kw), op.args) return getattr(first, op.__func_name__)(*rest) else: raise NotImplementedError("Only builtin scalar UDFs are supported for polars") + + +@translate.register(ops.AggUDF) +def execute_agg_udf(op, **kw): + args = (arg for name, arg in zip(op.argnames, op.args) if name != "where") + first, *rest = map(partial(translate, **kw), args) + if (where := op.where) is not None: + first = first.filter(translate(where, **kw)) + return getattr(first, op.__func_name__)(*rest) diff --git a/ibis/backends/polars/tests/test_udf.py b/ibis/backends/polars/tests/test_udf.py index 8693c2a147977..9c1c0324f893c 100644 --- a/ibis/backends/polars/tests/test_udf.py +++ b/ibis/backends/polars/tests/test_udf.py @@ -54,7 +54,7 @@ def test_multiple_argument_udf(alltypes): @pytest.mark.parametrize( ("value", "expected"), [(8, 2), (27, 3), (7, 7 ** (1.0 / 3.0))] ) -def test_builtin(con, value, expected): +def test_builtin_scalar_udf(con, value, expected): @udf.scalar.builtin def cbrt(a: float) -> float: ... @@ -62,3 +62,18 @@ def cbrt(a: float) -> float: expr = cbrt(value) result = con.execute(expr) assert pytest.approx(result) == expected + + +def test_builtin_agg_udf(con): + @udf.agg.builtin + def approx_n_unique(a, where: bool = True) -> int: + ... + + ft = con.tables.functional_alltypes + expr = approx_n_unique(ft.string_col) + result = con.execute(expr) + assert result == 10 + + expr = approx_n_unique(ft.string_col, where=ft.string_col == "1") + result = con.execute(expr) + assert result == 1