From 66fd69c9d64b6cd31b26eff2b62dafa88b6cc893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 17 Dec 2023 23:16:37 +0100 Subject: [PATCH] refactor(analysis): always merge frames during windowization --- .../test_cumulative_functions/max/out1.sql | 2 +- .../test_cumulative_functions/mean/out1.sql | 2 +- .../test_cumulative_functions/min/out1.sql | 2 +- .../test_cumulative_functions/sum/out1.sql | 2 +- ibis/backends/tests/test_window.py | 19 ++++++++++++-- ibis/expr/analysis.py | 25 ++++++++++++++----- ibis/expr/types/core.py | 3 ++- ibis/expr/types/generic.py | 2 +- ibis/expr/types/groupby.py | 4 +-- ibis/expr/types/relations.py | 3 ++- ibis/tests/expr/test_window_functions.py | 18 ++++++++++--- 11 files changed, 61 insertions(+), 21 deletions(-) diff --git a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/max/out1.sql b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/max/out1.sql index d480705d545b..e4d6a93a43cb 100644 --- a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/max/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/max/out1.sql @@ -1,2 +1,2 @@ -SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo` +SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo` FROM `alltypes` t0 \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/mean/out1.sql b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/mean/out1.sql index b557f2cd0e52..8df3c5d8d98f 100644 --- a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/mean/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/mean/out1.sql @@ -1,2 +1,2 @@ -SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo` +SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo` FROM `alltypes` t0 \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/min/out1.sql b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/min/out1.sql index af549112f077..debaa216f1c1 100644 --- a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/min/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/min/out1.sql @@ -1,2 +1,2 @@ -SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo` +SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo` FROM `alltypes` t0 \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/sum/out1.sql b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/sum/out1.sql index f03e5db98c3c..4ac5a6cc4bd8 100644 --- a/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/sum/out1.sql +++ b/ibis/backends/impala/tests/snapshots/test_window/test_cumulative_functions/sum/out1.sql @@ -1,2 +1,2 @@ -SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo` +SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo` FROM `alltypes` t0 \ No newline at end of file diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index f9278807e5d4..c0ecb83a4cd4 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -207,8 +207,8 @@ def calc_zscore(s): ), pytest.mark.broken( ["flink"], - raises=Py4JJavaError, - reason="CalciteContextException: Argument to function 'NTILE' must be a literal", + raises=com.UnsupportedOperationError, + reason="Windows in Flink can only be ordered by a single time column", ), ], ), @@ -1250,6 +1250,21 @@ def test_range_expression_bounds(backend): assert len(result) == con.execute(t.count()) +@pytest.mark.notimpl(["polars", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notyet( + ["clickhouse"], + reason="clickhouse doesn't implement percent_rank", + raises=com.OperationNotDefinedError, +) +@pytest.mark.broken( + ["pandas"], reason="missing column during execution", raises=KeyError +) +@pytest.mark.broken( + ["mssql"], reason="lack of support for booleans", raises=sa.exc.OperationalError +) +@pytest.mark.broken( + ["pyspark"], reason="pyspark requires CURRENT ROW", raises=AnalysisException +) def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): # GH #7631 t = alltypes diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 17b9beaa5506..d086a7c827e3 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -11,7 +11,7 @@ import ibis.expr.types as ir from ibis import util from ibis.common.deferred import deferred, var -from ibis.common.exceptions import IbisTypeError, IntegrityError +from ibis.common.exceptions import ExpressionError, IbisTypeError, IntegrityError from ibis.common.patterns import Eq, In, pattern, replace from ibis.util import Namespace @@ -170,17 +170,30 @@ def wrap_analytic(_, default_frame): @replace(p.WindowFunction) def merge_windows(_, default_frame): + if _.frame.start and default_frame.start and _.frame.start != default_frame.start: + raise ExpressionError( + "Unable to merge windows with conflicting `start` boundary" + ) + if _.frame.end and default_frame.end and _.frame.end != default_frame.end: + raise ExpressionError("Unable to merge windows with conflicting `end` boundary") + + start = _.frame.start or default_frame.start + end = _.frame.end or default_frame.end group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by)) - order_by = tuple(toolz.unique(_.frame.order_by + default_frame.order_by)) - frame = _.frame.copy(group_by=group_by, order_by=order_by) + + order_by = {} + for sort_key in _.frame.order_by + default_frame.order_by: + order_by[sort_key.expr] = sort_key.ascending + order_by = tuple(ops.SortKey(k, v) for k, v in order_by.items()) + + frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by) return ops.WindowFunction(_.func, frame) -def windowize_function(expr, default_frame, merge_frames=False): +def windowize_function(expr, default_frame): ctx = {"default_frame": default_frame} node = expr.op() - if merge_frames: - node = node.replace(merge_windows, filter=p.Value, context=ctx) + node = node.replace(merge_windows, filter=p.Value, context=ctx) node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx) return node.to_expr() diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index 5e8246080de6..3167cbee373a 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -589,7 +589,8 @@ def to_torch( def unbind(self) -> ir.Table: """Return an expression built on `UnboundTable` instead of backend-specific objects.""" - from ibis.expr.analysis import p, c, _ + from ibis.expr.analysis import p, c + from ibis.common.deferred import _ rule = p.DatabaseTable >> c.UnboundTable(name=_.name, schema=_.schema) return self.op().replace(rule).to_expr() diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index f0d573f323f8..ea54a4d5c116 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -758,7 +758,7 @@ def over( def bind(table): frame = window.bind(table) - expr = an.windowize_function(self, frame, merge_frames=True) + expr = an.windowize_function(self, frame) if expr.equals(self): raise com.IbisTypeError( "No reduction or analytic function found to construct a window expression" diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 8505fbf8bc04..123b1ca8b71f 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -254,12 +254,12 @@ def _selectables(self, *exprs, **kwexprs): order_by=bind_expr(self.table, self._order_by), ) return [ - an.windowize_function(e2, default_frame, merge_frames=True) + an.windowize_function(e2, default_frame) for expr in exprs for e1 in util.promote_list(expr) for e2 in util.promote_list(table._ensure_expr(e1)) ] + [ - an.windowize_function(e, default_frame, merge_frames=True).name(k) + an.windowize_function(e, default_frame).name(k) for k, expr in kwexprs.items() for e in util.promote_list(table._ensure_expr(expr)) ] diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index a41210403eca..030442361f0a 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -4338,7 +4338,8 @@ def _resolve_predicates( table: Table, predicates ) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]: import ibis.expr.types as ir - from ibis.expr.analysis import _, flatten_predicate, p + from ibis.common.deferred import _ + from ibis.expr.analysis import flatten_predicate, p # TODO(kszucs): clean this up, too much flattening and resolving happens here predicates = [ diff --git a/ibis/tests/expr/test_window_functions.py b/ibis/tests/expr/test_window_functions.py index 8ee6c5617ed3..1c2fd6110468 100644 --- a/ibis/tests/expr/test_window_functions.py +++ b/ibis/tests/expr/test_window_functions.py @@ -1,7 +1,10 @@ from __future__ import annotations +import pytest + import ibis import ibis.expr.operations as ops +from ibis.common.exceptions import ExpressionError def test_mutate_with_analytic_functions(alltypes): @@ -48,15 +51,22 @@ def test_value_over_api(alltypes): w1 = ibis.window(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h]) w2 = ibis.window(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f]) - expr = t.f.cumsum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h]) - expected = t.f.cumsum().over(w1) + expr = t.f.sum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h]) + expected = t.f.sum().over(w1) assert expr.equals(expected) - expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f]) - expected = t.f.cumsum().over(w2) + expr = t.f.sum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f]) + expected = t.f.sum().over(w2) assert expr.equals(expected) +def test_conflicting_window_boundaries(alltypes): + t = alltypes + + with pytest.raises(ExpressionError, match="Unable to merge windows"): + t.f.cumsum().over(rows=(0, 1)) + + def test_rank_followed_by_over_call_merge_frames(alltypes): t = alltypes expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))