Skip to content

Commit

Permalink
fix(sql): standardize NULL handling of argmin/argmax (#10227)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <[email protected]>
  • Loading branch information
jcrist and cpcloud authored Sep 26, 2024
1 parent 428d1a3 commit 51335ed
Show file tree
Hide file tree
Showing 20 changed files with 89 additions and 49 deletions.
19 changes: 7 additions & 12 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,20 +1256,15 @@ def execute_hash(op, **kw):


def _arg_min_max(op, func, **kw):
key = op.key
arg = op.arg

if (op_where := op.where) is not None:
key = ops.IfElse(op_where, key, None)
arg = ops.IfElse(op_where, arg, None)
key = translate(op.key, **kw)
arg = translate(op.arg, **kw)

translate_arg = translate(arg, **kw)
translate_key = translate(key, **kw)
if op.where is not None:
where = translate(op.where, **kw)
arg = arg.filter(where)
key = key.filter(where)

not_null_mask = translate_arg.is_not_null() & translate_key.is_not_null()
return translate_arg.filter(not_null_mask).get(
func(translate_key.filter(not_null_mask))
)
return arg.get(func(key))


@translate.register(ops.ArgMax)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ class SQLGlotCompiler(abc.ABC):
ops.All: "bool_and",
ops.Any: "bool_or",
ops.ApproxCountDistinct: "approx_distinct",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayContains: "array_contains",
ops.ArrayFlatten: "flatten",
ops.ArrayLength: "array_size",
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "time_from_parts",
ops.TimestampNow: "current_timestamp",
ops.ExtractHost: "net.host",
ops.ArgMin: "min_by",
ops.ArgMax: "max_by",
}

def to_sqlglot(
Expand Down
14 changes: 12 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ApproxCountDistinct: "uniqHLL12",
ops.ApproxMedian: "median",
ops.Arbitrary: "any",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
Expand Down Expand Up @@ -673,6 +671,18 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
)
return self.agg.anyLast(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
return sge.Dot(
this=self.agg.argMin(self.f.tuple(arg), key, where=where),
expression=sge.convert(1),
)

def visit_ArgMax(self, op, *, arg, key, where):
return sge.Dot(
this=self.agg.argMax(self.f.tuple(arg), key, where=where),
expression=sge.convert(1),
)

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler):
post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayMap,
Expand Down Expand Up @@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
return self.agg.first_value(arg, where=where, order_by=[sge.Ordered(this=key)])

def visit_ArgMax(self, op, *, arg, key, where):
return self.agg.first_value(
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
)

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class DruidCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
return self.agg.first(arg, where=where, order_by=[sge.Ordered(this=key)])

def visit_ArgMax(self, op, *, arg, key, where):
return self.agg.first(
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
)

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class ExasolCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ class FlinkCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayFlatten,
ops.ArrayStringJoin,
ops.Correlation,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class ImpalaCompiler(SQLGlotCompiler):
}

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ class MSSQLCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def POS_INF(self):
NEG_INF = POS_INF
UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ class OracleCompiler(SQLGlotCompiler):
}

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
14 changes: 6 additions & 8 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,21 @@ def visit_Mode(self, op, *, arg, where):
expr = sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_ArgMinMax(self, op, *, arg, key, where, desc: bool):
conditions = [arg.is_(sg.not_(NULL)), key.is_(sg.not_(NULL))]

if where is not None:
conditions.append(where)
def _argminmax(self, op, *, arg, key, where, desc: bool):
cond = key.is_(sg.not_(NULL))
where = cond if where is None else sge.And(this=cond, expression=where)

agg = self.agg.array_agg(
sge.Ordered(this=sge.Order(this=arg, expressions=[key]), desc=desc),
where=sg.and_(*conditions),
where=where,
)
return sge.paren(agg, copy=False)[0]

def visit_ArgMin(self, op, *, arg, key, where):
return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=False)
return self._argminmax(op, arg=arg, key=key, where=where, desc=False)

def visit_ArgMax(self, op, *, arg, key, where):
return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=True)
return self._argminmax(op, arg=arg, key=key, where=where, desc=True)

def visit_Sum(self, op, *, arg, where):
arg = (
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class PySparkCompiler(SQLGlotCompiler):
}

SIMPLE_OPS = {
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayDistinct: "array_distinct",
ops.ArrayFlatten: "flatten",
ops.ArrayIntersect: "array_intersect",
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class SnowflakeCompiler(SQLGlotCompiler):
SIMPLE_OPS = {
ops.All: "min",
ops.Any: "max",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayDistinct: "array_distinct",
ops.ArrayFlatten: "array_flatten",
ops.ArrayIndex: "get",
Expand Down
7 changes: 1 addition & 6 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,7 @@ def visit_ArgMax(self, *args, **kwargs):
return self._visit_arg_reduction("max", *args, **kwargs)

def _visit_arg_reduction(self, func, op, *, arg, key, where):
cond = arg.is_(sg.not_(NULL))

if op.where is not None:
cond = sg.and_(cond, where)

agg = self.agg[func](key, where=cond)
agg = self.agg[func](key, where=where)
return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]")

def visit_UnwrapJSONString(self, op, *, arg):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class TrinoCompiler(SQLGlotCompiler):

SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.Pi: "pi",
ops.E: "e",
ops.RegexReplace: "regexp_replace",
Expand Down
36 changes: 33 additions & 3 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def mean_udf(s):
]

argidx_not_grouped_marks = [
"datafusion",
"impala",
"mysql",
"mssql",
Expand Down Expand Up @@ -411,7 +410,6 @@ def mean_and_std(v):
[
"impala",
"mysql",
"datafusion",
"mssql",
"druid",
"oracle",
Expand All @@ -431,7 +429,6 @@ def mean_and_std(v):
[
"impala",
"mysql",
"datafusion",
"mssql",
"druid",
"oracle",
Expand Down Expand Up @@ -689,6 +686,39 @@ def test_first_last_ordered(alltypes, method, filtered, include_null):
assert res == sol


@pytest.mark.notimpl(
[
"druid",
"exasol",
"flink",
"impala",
"mssql",
"mysql",
"oracle",
],
raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize("method", ["argmin", "argmax"])
@pytest.mark.parametrize("filtered", [True, False], ids=["filtered", "unfiltered"])
@pytest.mark.parametrize("null_result", [True, False], ids=["null", "non-null"])
def test_argmin_argmax(alltypes, method, filtered, null_result):
t = alltypes.mutate(by_col=_.int_col.nullif(0).nullif(9), val_col=10 * _.int_col)

if filtered:
where = _.int_col != (1 if method == "argmin" else 8)
sol = 20 if method == "argmin" else 70
else:
where = None
sol = 10 if method == "argmin" else 80

if null_result:
t = t.mutate(val_col=_.val_col.nullif(sol))

expr = getattr(t.val_col, method)("by_col", where=where)
res = expr.execute()
assert pd.isna(res) if null_result else res == sol


@pytest.mark.notimpl(
[
"impala",
Expand Down
6 changes: 6 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,9 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar:
def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that maximizes `key`.
If more than one value maximizes `key`, the returned value is backend
specific. The result may be `NULL`.
Parameters
----------
key
Expand Down Expand Up @@ -1801,6 +1804,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that minimizes `key`.
If more than one value minimizes `key`, the returned value is backend
specific. The result may be `NULL`.
Parameters
----------
key
Expand Down

0 comments on commit 51335ed

Please sign in to comment.