Skip to content

Commit

Permalink
fix(ir): merge window frames for bound analytic window functions with…
Browse files Browse the repository at this point in the history
… a subsequent over call
  • Loading branch information
kszucs committed Dec 18, 2023
1 parent f20e34e commit 5dcd218
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 22 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,3 +1248,21 @@ def test_range_expression_bounds(backend):

assert not result.empty
assert len(result) == con.execute(t.count())


def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df):
# GH #7631
t = alltypes
expr = t.int_col.percent_rank().over(ibis.window(group_by=t.int_col.notnull()))
result = expr.execute()

expected = (
df.sort_values("int_col")
.groupby(df["int_col"].notnull())
.apply(lambda df: (df.int_col.rank(method="min").sub(1).div(len(df) - 1)))
.T.reset_index(drop=True)
.iloc[:, 0]
.rename(expr.get_name())
)

backend.assert_series_equal(result, expected)
36 changes: 18 additions & 18 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import ibis.expr.operations.relations as rels
import ibis.expr.types as ir
from ibis import util
from ibis.common.deferred import _, deferred, var
from ibis.common.deferred import deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import Eq, In, pattern
from ibis.common.patterns import Eq, In, pattern, replace
from ibis.util import Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -163,25 +163,25 @@ def pushdown_selection_filters(parent, predicates):
return parent.copy(predicates=parent.predicates + tuple(simplified))


def windowize_function(expr, default_frame, merge_frames=False):
func, frame = var("func"), var("frame")

wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame)
merge_windows = p.WindowFunction(func, frame) >> c.WindowFunction(
func,
frame.copy(
order_by=frame.order_by + default_frame.order_by,
group_by=frame.group_by + default_frame.group_by,
),
)
@replace(p.Analytic | p.Reduction)
def wrap_analytic(_, default_frame):
return ops.WindowFunction(_, default_frame)


@replace(p.WindowFunction)
def merge_windows(_, default_frame):
group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by))
order_by = tuple(toolz.unique(_.frame.order_by + default_frame.order_by))
frame = _.frame.copy(group_by=group_by, order_by=order_by)
return ops.WindowFunction(_.func, frame)


def windowize_function(expr, default_frame, merge_frames=False):
ctx = {"default_frame": default_frame}
node = expr.op()
if merge_frames:
# it only happens in ibis.expr.groupby.GroupedTable, but the projector
# changes the windowization order to put everything here
node = node.replace(merge_windows, filter=p.Value)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction)

node = node.replace(merge_windows, filter=p.Value, context=ctx)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx)
return node.to_expr()


Expand Down
6 changes: 2 additions & 4 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,17 +758,15 @@ def over(

def bind(table):
frame = window.bind(table)
expr = an.windowize_function(self, frame)
expr = an.windowize_function(self, frame, merge_frames=True)
if expr.equals(self):
raise com.IbisTypeError(
"No reduction or analytic function found to construct a window expression"
)
return expr

op = self.op()
if isinstance(op, ops.WindowFunction):
return op.func.to_expr().over(window)
elif isinstance(window, bl.WindowBuilder):
if isinstance(window, bl.WindowBuilder):
if table := an.find_first_base_table(self.op()):
return bind(table)
else:
Expand Down
7 changes: 7 additions & 0 deletions ibis/tests/expr/test_window_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ def test_value_over_api(alltypes):
expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
expected = t.f.cumsum().over(w2)
assert expr.equals(expected)


def test_rank_followed_by_over_call_merge_frames(alltypes):
t = alltypes
expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))
expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f).resolve(t)
assert expr1.equals(expr2)

0 comments on commit 5dcd218

Please sign in to comment.