Skip to content

Commit

Permalink
feat(polars): implement support for builtin aggregate udfs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 13, 2023
1 parent 7dcb444 commit 1398acd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
11 changes: 10 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,16 @@ def execute_count_distinct_star(op, **kw):
@translate.register(ops.ScalarUDF)
def execute_scalar_udf(op, **kw):
if op.__input_type__ == InputType.BUILTIN:
first, *rest = map(translate, op.args)
first, *rest = map(partial(translate, **kw), op.args)
return getattr(first, op.__func_name__)(*rest)
else:
raise NotImplementedError("Only builtin scalar UDFs are supported for polars")


@translate.register(ops.AggUDF)
def execute_agg_udf(op, **kw):
args = (arg for name, arg in zip(op.argnames, op.args) if name != "where")
first, *rest = map(partial(translate, **kw), args)
if (where := op.where) is not None:
first = first.filter(translate(where, **kw))
return getattr(first, op.__func_name__)(*rest)
17 changes: 16 additions & 1 deletion ibis/backends/polars/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,26 @@ def test_multiple_argument_udf(alltypes):
@pytest.mark.parametrize(
("value", "expected"), [(8, 2), (27, 3), (7, 7 ** (1.0 / 3.0))]
)
def test_builtin(con, value, expected):
def test_builtin_scalar_udf(con, value, expected):
@udf.scalar.builtin
def cbrt(a: float) -> float:
...

expr = cbrt(value)
result = con.execute(expr)
assert pytest.approx(result) == expected


def test_builtin_agg_udf(con):
@udf.agg.builtin
def approx_n_unique(a, where: bool = True) -> int:
...

ft = con.tables.functional_alltypes
expr = approx_n_unique(ft.string_col)
result = con.execute(expr)
assert result == 10

expr = approx_n_unique(ft.string_col, where=ft.string_col == "1")
result = con.execute(expr)
assert result == 1

0 comments on commit 1398acd

Please sign in to comment.