Skip to content

Commit

Permalink
fix: drop nulls in .collect() aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Jun 7, 2024
1 parent 10112bd commit b6e0c31
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 33 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
Expand All @@ -25,6 +26,8 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)

agg = AggGen(supports_filter=True)

UNSUPPORTED_OPS = (
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect

_INTERVAL_SUFFIXES = {
"ms": "milliseconds",
Expand All @@ -35,6 +36,11 @@ class DuckDBCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True)

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def arbitrary(arg):
ops.Arbitrary: arbitrary,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
ops.ArrayCollect: lambda x: x.tolist(),
ops.ArrayCollect: lambda x: x.dropna().tolist(),
}


Expand Down
22 changes: 15 additions & 7 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,21 +275,24 @@ def aggregation(op, **kw):

if op.groups:
# project first to handle computed group by columns
lf = (
func = (
lf.with_columns(
[translate(arg, **kw).alias(name) for name, arg in op.groups.items()]
)
.group_by(list(op.groups.keys()))
.agg
)
else:
lf = lf.select
func = lf.select

if op.metrics:
metrics = [translate(arg, **kw).alias(name) for name, arg in op.metrics.items()]
lf = lf(metrics)
metrics = [
translate(arg, in_group_by=bool(op.groups), **kw).alias(name)
for name, arg in op.metrics.items()
]
return func(metrics)

return lf
return func()


@translate.register(PandasRename)
Expand Down Expand Up @@ -988,11 +991,16 @@ def array_column(op, **kw):


@translate.register(ops.ArrayCollect)
def array_collect(op, **kw):
def array_collect(op, in_group_by=False, **kw):
arg = translate(op.arg, **kw)
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
return arg
out = arg.drop_nulls()
if not in_group_by:
# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
out = out.implode()
return out


@translate.register(ops.ArrayFlatten)
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect


class PostgresUDFNode(ops.Value):
Expand All @@ -27,6 +28,8 @@ class PostgresCompiler(SQLGlotCompiler):
dialect = Postgres
type_mapper = PostgresType

rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)

agg = AggGen(supports_filter=True)

NAN = sge.Literal.number("'NaN'::double precision")
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,6 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayContains(self, op, *, arg, other):
return self.f.array_contains(arg, self.f.to_variant(other))

def visit_ArrayCollect(self, op, *, arg, where):
return self.agg.array_agg(
self.f.ifnull(arg, self.f.parse_json("null")), where=where
)

def visit_ArrayConcat(self, op, *, arg):
# array_cat only accepts two arguments
return self.f.array_flatten(self.f.array(*arg))
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,14 @@ def exclude_unsupported_window_frame_from_ops(_, **kwargs):
return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,))


@replace(p.ArrayCollect)
def exclude_nulls_from_array_collect(_, **kwargs):
where = ops.NotNull(_.arg)
if _.where is not None:
where = ops.And(where, _.where)
return _.copy(where=where)


# Rewrite rules for lowering a high-level operation into one composed of more
# primitive operations.

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,8 @@ def mean_and_std(v):
id="count_star",
),
param(
lambda t, where: t.string_col.collect(where=where),
lambda t, where: t.string_col[where].tolist(),
lambda t, where: t.string_col.nullif("3").collect(where=where),
lambda t, where: t.string_col[t.string_col != "3"][where].tolist(),
id="collect",
marks=[
pytest.mark.notimpl(
Expand Down
23 changes: 6 additions & 17 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,7 @@ def test_unnest_complex(backend):


@builtin_array
@pytest.mark.never(
"pyspark", reason="pyspark throws away nulls in collect_list", raises=AssertionError
)
@pytest.mark.never(
"clickhouse",
reason="clickhouse throws away nulls in groupArray",
raises=AssertionError,
)
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
"dask", reason="DataFrame.index are different", raises=AssertionError
)
def test_unnest_idempotent(backend):
array_types = backend.array_types
df = array_types.execute()
Expand All @@ -326,18 +315,18 @@ def test_unnest_idempotent(backend):
.aggregate(x=lambda t: t.x.collect())
.order_by("scalar_column")
)
result = expr.execute()
result = expr.execute().reset_index(drop=True)
expected = (
df[["scalar_column", "x"]].sort_values("scalar_column").reset_index(drop=True)
df[["scalar_column", "x"]]
.assign(x=df.x.map(lambda arr: [i for i in arr if not pd.isna(i)]))
.sort_values("scalar_column")
.reset_index(drop=True)
)
tm.assert_frame_equal(result, expected)


@builtin_array
@pytest.mark.notimpl(["datafusion", "flink"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
"dask", reason="DataFrame.index are different", raises=AssertionError
)
def test_unnest_no_nulls(backend):
array_types = backend.array_types
df = array_types.execute()
Expand All @@ -350,7 +339,7 @@ def test_unnest_no_nulls(backend):
.aggregate(x=lambda t: t.y.collect())
.order_by("scalar_column")
)
result = expr.execute()
result = expr.execute().reset_index(drop=True)
expected = (
df[["scalar_column", "x"]]
.explode("x")
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import TrinoType
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops
from ibis.backends.sql.rewrites import (
exclude_nulls_from_array_collect,
exclude_unsupported_window_frame_from_ops,
)


class TrinoCompiler(SQLGlotCompiler):
Expand All @@ -25,6 +28,7 @@ class TrinoCompiler(SQLGlotCompiler):
agg = AggGen(supports_filter=True)

rewrites = (
exclude_nulls_from_array_collect,
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
Expand Down

0 comments on commit b6e0c31

Please sign in to comment.