Skip to content

Commit

Permalink
refactor(analysis): always merge frames during windowization
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Dec 18, 2023
1 parent e12ce8d commit 66fd69c
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT max(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT avg(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT min(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC) AS `foo`
SELECT sum(t0.`f`) OVER (ORDER BY t0.`d` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `foo`
FROM `alltypes` t0
19 changes: 17 additions & 2 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def calc_zscore(s):
),
pytest.mark.broken(
["flink"],
raises=Py4JJavaError,
reason="CalciteContextException: Argument to function 'NTILE' must be a literal",
raises=com.UnsupportedOperationError,
reason="Windows in Flink can only be ordered by a single time column",
),
],
),
Expand Down Expand Up @@ -1250,6 +1250,21 @@ def test_range_expression_bounds(backend):
assert len(result) == con.execute(t.count())


@pytest.mark.notimpl(["polars", "dask"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
["clickhouse"],
reason="clickhouse doesn't implement percent_rank",
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
["pandas"], reason="missing column during execution", raises=KeyError
)
@pytest.mark.broken(
["mssql"], reason="lack of support for booleans", raises=sa.exc.OperationalError
)
@pytest.mark.broken(
["pyspark"], reason="pyspark requires CURRENT ROW", raises=AnalysisException
)
def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df):
# GH #7631
t = alltypes
Expand Down
25 changes: 19 additions & 6 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.common.deferred import deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.exceptions import ExpressionError, IbisTypeError, IntegrityError
from ibis.common.patterns import Eq, In, pattern, replace
from ibis.util import Namespace

Expand Down Expand Up @@ -170,17 +170,30 @@ def wrap_analytic(_, default_frame):

@replace(p.WindowFunction)
def merge_windows(_, default_frame):
    """Merge the frame of a matched window function with `default_frame`.

    Parameters
    ----------
    _
        The matched `WindowFunction` node; its `func` and `frame` are read.
    default_frame
        Frame whose boundaries, grouping keys and sort keys are merged in.

    Returns
    -------
    A new `ops.WindowFunction` wrapping the same `func` with the merged frame.

    Raises
    ------
    ExpressionError
        If both frames define a `start` boundary (or both define an `end`
        boundary) and the two values differ.
    """
    own_frame = _.frame

    # A boundary set on only one side is adopted as-is below; a boundary set on
    # both sides has to agree, otherwise there is no unambiguous merged frame.
    if (
        own_frame.start
        and default_frame.start
        and own_frame.start != default_frame.start
    ):
        raise ExpressionError(
            "Unable to merge windows with conflicting `start` boundary"
        )
    if own_frame.end and default_frame.end and own_frame.end != default_frame.end:
        raise ExpressionError("Unable to merge windows with conflicting `end` boundary")

    start = own_frame.start or default_frame.start
    end = own_frame.end or default_frame.end
    group_by = tuple(toolz.unique(own_frame.group_by + default_frame.group_by))

    # Deduplicate sort keys by the expression being sorted rather than by the
    # whole key: the first occurrence fixes the position (dicts keep insertion
    # order) and the last occurrence decides the direction.
    ascending_by_expr = {}
    for sort_key in own_frame.order_by + default_frame.order_by:
        ascending_by_expr[sort_key.expr] = sort_key.ascending
    order_by = tuple(
        ops.SortKey(expr, ascending) for expr, ascending in ascending_by_expr.items()
    )

    frame = own_frame.copy(start=start, end=end, group_by=group_by, order_by=order_by)
    return ops.WindowFunction(_.func, frame)


def windowize_function(expr, default_frame):
    """Rewrite `expr` so that its window functions take `default_frame` into account.

    Two rewrite passes run over the operation behind `expr`:

    1. `merge_windows` is applied to value nodes, merging the frame of every
       existing window function with `default_frame`. This pass is
       unconditional; there is no opt-in flag.
    2. `wrap_analytic` is applied to value nodes that are not window functions.

    Parameters
    ----------
    expr
        Expression to rewrite; only `expr.op()` is used.
    default_frame
        Frame handed to both rewrite rules through the replacement context.

    Returns
    -------
    The rewritten node converted back to an expression with `to_expr()`.
    """
    ctx = {"default_frame": default_frame}
    node = expr.op()
    node = node.replace(merge_windows, filter=p.Value, context=ctx)
    node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx)
    return node.to_expr()

Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def to_torch(

def unbind(self) -> ir.Table:
    """Return an expression built on `UnboundTable` instead of backend-specific objects."""
    # `_` is imported from its defining module; `ibis.expr.analysis` is only
    # relied on for the `p` and `c` namespaces.
    from ibis.common.deferred import _
    from ibis.expr.analysis import p, c

    # Swap every `DatabaseTable` node for an `UnboundTable` that carries the
    # same name and schema.
    rule = p.DatabaseTable >> c.UnboundTable(name=_.name, schema=_.schema)
    return self.op().replace(rule).to_expr()
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def over(

def bind(table):
frame = window.bind(table)
expr = an.windowize_function(self, frame, merge_frames=True)
expr = an.windowize_function(self, frame)
if expr.equals(self):
raise com.IbisTypeError(
"No reduction or analytic function found to construct a window expression"
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def _selectables(self, *exprs, **kwexprs):
order_by=bind_expr(self.table, self._order_by),
)
return [
an.windowize_function(e2, default_frame, merge_frames=True)
an.windowize_function(e2, default_frame)
for expr in exprs
for e1 in util.promote_list(expr)
for e2 in util.promote_list(table._ensure_expr(e1))
] + [
an.windowize_function(e, default_frame, merge_frames=True).name(k)
an.windowize_function(e, default_frame).name(k)
for k, expr in kwexprs.items()
for e in util.promote_list(table._ensure_expr(expr))
]
Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,7 +4338,8 @@ def _resolve_predicates(
table: Table, predicates
) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]:
import ibis.expr.types as ir
from ibis.expr.analysis import _, flatten_predicate, p
from ibis.common.deferred import _
from ibis.expr.analysis import flatten_predicate, p

# TODO(kszucs): clean this up, too much flattening and resolving happens here
predicates = [
Expand Down
18 changes: 14 additions & 4 deletions ibis/tests/expr/test_window_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.operations as ops
from ibis.common.exceptions import ExpressionError


def test_mutate_with_analytic_functions(alltypes):
Expand Down Expand Up @@ -48,15 +51,22 @@ def test_value_over_api(alltypes):
w1 = ibis.window(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
w2 = ibis.window(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])

expr = t.f.cumsum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
expected = t.f.cumsum().over(w1)
expr = t.f.sum().over(rows=(0, 1), group_by=t.g, order_by=[t.f, t.h])
expected = t.f.sum().over(w1)
assert expr.equals(expected)

expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.cumsum().over(w2)
expr = t.f.sum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.sum().over(w2)
assert expr.equals(expected)


def test_conflicting_window_boundaries(alltypes):
    """Calling `.over(rows=(0, 1))` on `cumsum()` must raise `ExpressionError`."""
    expected_message = "Unable to merge windows"

    # The whole expression stays inside the context manager so that the check
    # covers exactly the same calls as before.
    with pytest.raises(ExpressionError, match=expected_message):
        alltypes.f.cumsum().over(rows=(0, 1))


def test_rank_followed_by_over_call_merge_frames(alltypes):
t = alltypes
expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))
Expand Down

0 comments on commit 66fd69c

Please sign in to comment.