Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql): standardize NULL handling of argmin/argmax #10227

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,20 +1256,15 @@ def execute_hash(op, **kw):


def _arg_min_max(op, func, **kw):
    """Translate an ``ArgMin``/``ArgMax`` operation into a backend expression.

    Parameters
    ----------
    op
        The operation node; its ``key``, ``arg`` and optional ``where``
        operands are each passed through ``translate``.
    func
        Callable applied to the translated key expression. Its result is
        handed to ``arg.get(...)`` to pick the returned element.
    **kw
        Forwarded unchanged to every ``translate`` call.

    Returns
    -------
    The expression ``arg.get(func(key))`` built from the (optionally
    filtered) translated operands.
    """
    key = translate(op.key, **kw)
    arg = translate(op.arg, **kw)

    # Apply the same predicate to both expressions so that the position
    # computed from `key` refers to the matching row of `arg`.
    if op.where is not None:
        where = translate(op.where, **kw)
        arg = arg.filter(where)
        key = key.filter(where)

    # No not-null mask is applied to `arg`: if the selected row holds a null
    # value, that null is what gets returned.
    return arg.get(func(key))


@translate.register(ops.ArgMax)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ class SQLGlotCompiler(abc.ABC):
ops.All: "bool_and",
ops.Any: "bool_or",
ops.ApproxCountDistinct: "approx_distinct",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayContains: "array_contains",
ops.ArrayFlatten: "flatten",
ops.ArrayLength: "array_size",
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "time_from_parts",
ops.TimestampNow: "current_timestamp",
ops.ExtractHost: "net.host",
ops.ArgMin: "min_by",
ops.ArgMax: "max_by",
}

def to_sqlglot(
Expand Down
14 changes: 12 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ApproxCountDistinct: "uniqHLL12",
ops.ApproxMedian: "median",
ops.Arbitrary: "any",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
Expand Down Expand Up @@ -673,6 +671,18 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
)
return self.agg.anyLast(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
    """Compile ``ArgMin`` as ``argMin(tuple(arg), key)`` followed by ``.1``.

    ``arg`` is wrapped in a one-element tuple before aggregation and the
    tuple's element is read back out of the aggregate result with a dot
    access on the literal ``1``.
    """
    # NOTE(review): presumably the tuple wrapper keeps NULL values of `arg`
    # from being skipped by argMin -- confirm against the ClickHouse docs.
    wrapped = self.f.tuple(arg)
    picked = self.agg.argMin(wrapped, key, where=where)
    return sge.Dot(this=picked, expression=sge.convert(1))

def visit_ArgMax(self, op, *, arg, key, where):
    """Compile ``ArgMax`` as ``argMax(tuple(arg), key)`` followed by ``.1``.

    ``arg`` is wrapped in a one-element tuple before aggregation and the
    tuple's element is read back out of the aggregate result with a dot
    access on the literal ``1``.
    """
    # NOTE(review): presumably the tuple wrapper keeps NULL values of `arg`
    # from being skipped by argMax -- confirm against the ClickHouse docs.
    wrapped = self.f.tuple(arg)
    picked = self.agg.argMax(wrapped, key, where=where)
    return sge.Dot(this=picked, expression=sge.convert(1))

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class DataFusionCompiler(SQLGlotCompiler):
post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayMap,
Expand Down Expand Up @@ -457,6 +455,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
    """Compile ``ArgMin`` as ``first_value(arg)`` ordered by ``key``.

    No sort direction is passed to ``sge.Ordered``, so its default applies.
    """
    by_key = sge.Ordered(this=key)
    return self.agg.first_value(arg, where=where, order_by=[by_key])

def visit_ArgMax(self, op, *, arg, key, where):
    """Compile ``ArgMax`` as ``first_value(arg)`` ordered by ``key`` descending."""
    by_key_desc = sge.Ordered(this=key, desc=True)
    return self.agg.first_value(arg, where=where, order_by=[by_key_desc])

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class DruidCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_ArgMin(self, op, *, arg, key, where):
    """Compile ``ArgMin`` as ``first(arg)`` ordered by ``key``.

    No sort direction is passed to ``sge.Ordered``, so its default applies.
    """
    by_key = sge.Ordered(this=key)
    return self.agg.first(arg, where=where, order_by=[by_key])

def visit_ArgMax(self, op, *, arg, key, where):
    """Compile ``ArgMax`` as ``first(arg)`` ordered by ``key`` descending."""
    by_key_desc = sge.Ordered(this=key, desc=True)
    return self.agg.first(arg, where=where, order_by=[by_key_desc])

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class ExasolCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ class FlinkCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayFlatten,
ops.ArrayStringJoin,
ops.Correlation,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ class ImpalaCompiler(SQLGlotCompiler):
}

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ class MSSQLCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def POS_INF(self):
NEG_INF = POS_INF
UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ class OracleCompiler(SQLGlotCompiler):
}

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
Expand Down
14 changes: 6 additions & 8 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,21 @@ def visit_Mode(self, op, *, arg, where):
expr = sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def _argminmax(self, op, *, arg, key, where, desc: bool):
    """Build the aggregate shared by ``visit_ArgMin`` and ``visit_ArgMax``.

    ``arg`` is collected with ``array_agg`` ordered by ``key`` (descending
    when ``desc`` is true), and the parenthesized aggregate is indexed with
    ``[0]``.

    Only rows whose ``key`` is not NULL take part; ``arg`` itself is not
    filtered, so the selected value may be NULL. A caller-supplied ``where``
    is AND-ed with the key-not-null condition.

    The leading underscore keeps this helper from looking like a ``visit_*``
    handler for an operation.
    """
    cond = key.is_(sg.not_(NULL))
    where = cond if where is None else sge.And(this=cond, expression=where)

    agg = self.agg.array_agg(
        sge.Ordered(this=sge.Order(this=arg, expressions=[key]), desc=desc),
        where=where,
    )
    # Parenthesize before indexing so the subscript applies to the whole
    # aggregate expression.
    return sge.paren(agg, copy=False)[0]

def visit_ArgMin(self, op, *, arg, key, where):
    """Compile ``ArgMin``: the value of ``arg`` at the smallest ``key``."""
    return self._argminmax(op, arg=arg, key=key, where=where, desc=False)

def visit_ArgMax(self, op, *, arg, key, where):
    """Compile ``ArgMax``: the value of ``arg`` at the largest ``key``."""
    return self._argminmax(op, arg=arg, key=key, where=where, desc=True)

def visit_Sum(self, op, *, arg, where):
arg = (
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class PySparkCompiler(SQLGlotCompiler):
}

SIMPLE_OPS = {
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayDistinct: "array_distinct",
ops.ArrayFlatten: "flatten",
ops.ArrayIntersect: "array_intersect",
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class SnowflakeCompiler(SQLGlotCompiler):
SIMPLE_OPS = {
ops.All: "min",
ops.Any: "max",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayDistinct: "array_distinct",
ops.ArrayFlatten: "array_flatten",
ops.ArrayIndex: "get",
Expand Down
7 changes: 1 addition & 6 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,7 @@ def visit_ArgMax(self, *args, **kwargs):
return self._visit_arg_reduction("max", *args, **kwargs)

def _visit_arg_reduction(self, func, op, *, arg, key, where):
    """Build the expression shared by the arg-min/arg-max visitors.

    Parameters
    ----------
    func
        Name of the aggregate looked up on ``self.agg`` and applied to
        ``key`` (``"max"`` in the visible ``visit_ArgMax`` caller).
    op
        The operation node; not read here.
    arg, key, where
        Compiled operands. ``where`` is passed straight to the aggregate;
        no extra not-NULL condition on ``arg`` is added, so the result may
        be NULL.

    Returns
    -------
    ``json_extract(json_array(arg, <func>(key)), '$[0]')``.
    """
    agg = self.agg[func](key, where=where)
    # NOTE(review): presumably relies on SQLite taking bare columns from the
    # row that supplies the min()/max() value -- confirm against the SQLite
    # docs. Element 0 of the JSON array is then `arg` from that row.
    return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]")

def visit_UnwrapJSONString(self, op, *, arg):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class TrinoCompiler(SQLGlotCompiler):

SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.Pi: "pi",
ops.E: "e",
ops.RegexReplace: "regexp_replace",
Expand Down
36 changes: 33 additions & 3 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def mean_udf(s):
]

argidx_not_grouped_marks = [
"datafusion",
"impala",
"mysql",
"mssql",
Expand Down Expand Up @@ -411,7 +410,6 @@ def mean_and_std(v):
[
"impala",
"mysql",
"datafusion",
"mssql",
"druid",
"oracle",
Expand All @@ -431,7 +429,6 @@ def mean_and_std(v):
[
"impala",
"mysql",
"datafusion",
"mssql",
"druid",
"oracle",
Expand Down Expand Up @@ -689,6 +686,39 @@ def test_first_last_ordered(alltypes, method, filtered, include_null):
assert res == sol


@pytest.mark.notimpl(
    ["druid", "exasol", "flink", "impala", "mssql", "mysql", "oracle"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize("method", ["argmin", "argmax"])
@pytest.mark.parametrize("filtered", [True, False], ids=["filtered", "unfiltered"])
@pytest.mark.parametrize("null_result", [True, False], ids=["null", "non-null"])
def test_argmin_argmax(alltypes, method, filtered, null_result):
    """Check ``argmin``/``argmax`` with NULL keys, filters and NULL results.

    ``by_col`` is ``int_col`` with 0 and 9 replaced by NULL, and ``val_col``
    is ``10 * int_col``. The expected values (10/80 unfiltered, 20/70 when
    the row with key 1/8 is filtered out) assume ``int_col`` covers 0..9 in
    the ``alltypes`` fixture -- TODO confirm against the fixture data.
    """
    is_min = method == "argmin"
    t = alltypes.mutate(by_col=_.int_col.nullif(0).nullif(9), val_col=10 * _.int_col)

    if not filtered:
        where = None
        sol = 10 if is_min else 80
    else:
        # Drop the row holding the extreme non-NULL key so the next one wins.
        where = _.int_col != (1 if is_min else 8)
        sol = 20 if is_min else 70

    if null_result:
        # NULL out the value at the winning key: the result itself must be NULL.
        t = t.mutate(val_col=_.val_col.nullif(sol))

    reduction = getattr(t.val_col, method)
    res = reduction("by_col", where=where).execute()

    if null_result:
        assert pd.isna(res)
    else:
        assert res == sol


@pytest.mark.notimpl(
[
"impala",
Expand Down
6 changes: 6 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,9 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar:
def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that maximizes `key`.

If more than one value maximizes `key`, the returned value is backend
specific. The result may be `NULL`.

Parameters
----------
key
Expand Down Expand Up @@ -1801,6 +1804,9 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar:
"""Return the value of `self` that minimizes `key`.

If more than one value minimizes `key`, the returned value is backend
specific. The result may be `NULL`.

Parameters
----------
key
Expand Down