Skip to content

Commit

Permalink
refactor(analysis): vastly simplify windowize_function
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Oct 12, 2023
1 parent 3edd8f7 commit 998bbaa
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 93 deletions.
3 changes: 1 addition & 2 deletions ibis/backends/dask/tests/execution/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def test_batting_avg_change_in_games_per_year(players, players_df):


@pytest.mark.xfail(
raises=AssertionError,
reason="Dask doesn't support the `rank` method on SeriesGroupBy",
raises=AttributeError, reason="'Series' object has no attribute 'rank'"
)
def test_batting_most_hits(players, players_df):
expr = players.mutate(
Expand Down
35 changes: 14 additions & 21 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ def map(self, fn: Callable, filter: Optional[Any] = None) -> dict[Node, Any]:
the results as the second and the results of the children as keyword
arguments.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand All @@ -171,9 +170,8 @@ def find(self, type: type | tuple[type], filter: Optional[Any] = None) -> set[No
type
Type or tuple of types to find.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand All @@ -197,9 +195,8 @@ def match(
Pattern to match. `ibis.common.pattern()` function is used to coerce the
input value into a pattern. See the pattern module for more details.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
context
Optional context to use for the pattern matching.
Expand Down Expand Up @@ -288,9 +285,8 @@ def from_bfs(cls, root: Node, filter: Optional[Any] = None) -> Self:
root
Root node of the graph.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand Down Expand Up @@ -338,9 +334,8 @@ def from_dfs(cls, root: Node, filter: Optional[Any] = None) -> Self:
root
Root node of the graph.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand Down Expand Up @@ -441,9 +436,8 @@ def bfs(node: Node, filter: Optional[Any] = None) -> Graph:
node
Root node of the graph.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand All @@ -460,9 +454,8 @@ def dfs(node: Node, filter: Optional[Any] = None) -> Graph:
node
Root node of the graph.
filter
Pattern-like object to filter out nodes from the traversal. Essentially
the traversal will only visit nodes that match the given pattern and
stop otherwise.
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
Expand Down
11 changes: 6 additions & 5 deletions ibis/common/tests/test_graph_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Optional

import pytest
from typing_extensions import Self # noqa: TCH002

from ibis.common.collections import frozendict
Expand Down Expand Up @@ -32,11 +33,11 @@ def generate_node(depth):
)


def test_generate_node():
for depth in [0, 1, 2, 10, 100]:
n = generate_node(depth)
assert isinstance(n, MyNode)
assert len(Graph.from_bfs(n).nodes()) == depth + 1
@pytest.mark.parametrize("depth", [0, 1, 10])
def test_generate_node(depth):
n = generate_node(depth)
assert isinstance(n, MyNode)
assert len(Graph.from_bfs(n).nodes()) == depth + 1


def test_bfs(benchmark):
Expand Down
59 changes: 17 additions & 42 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import ValidationError
from ibis.common.deferred import deferred, var
from ibis.common.deferred import _, deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import pattern
from ibis.util import Namespace
Expand Down Expand Up @@ -338,48 +338,24 @@ def propagate_down_window(func: ops.Value, frame: ops.WindowFrame):
return type(func)(*clean_args)


# TODO(kszucs): rewrite to receive and return an ops.Node
def windowize_function(expr, frame):
assert isinstance(expr, ir.Expr), type(expr)
assert isinstance(frame, ops.WindowFrame)

def _windowize(op, frame):
if isinstance(op, ops.WindowFunction):
walked_child = _walk(op.func, frame)
walked = walked_child.to_expr().over(op.frame).op()
elif isinstance(op, ops.Value):
walked = _walk(op, frame)
else:
walked = op

if isinstance(walked, (ops.Analytic, ops.Reduction)):
return op.to_expr().over(frame).op()
elif isinstance(walked, ops.WindowFunction):
if frame is not None:
frame = walked.frame.copy(
group_by=frame.group_by + walked.frame.group_by,
order_by=frame.order_by + walked.frame.order_by,
)
return walked.to_expr().over(frame).op()
else:
return walked
else:
return walked

def _walk(op, frame):
# TODO(kszucs): rewrite to use the substitute utility
windowed_args = []
for arg in op.args:
if isinstance(arg, ops.Value):
arg = _windowize(arg, frame)
elif isinstance(arg, tuple):
arg = tuple(_windowize(x, frame) for x in arg)
def windowize_function(expr, default_frame):
func = var("func")
frame = var("frame")

windowed_args.append(arg)
wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame)
merge_frames = p.WindowFunction(func, frame) >> c.WindowFunction(
func,
frame.copy(
order_by=frame.order_by + default_frame.order_by,
group_by=frame.group_by + default_frame.group_by,
),
)

return type(op)(*windowed_args)
node = expr.op()
node = node.replace(merge_frames, filter=p.Value)
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction)

return _windowize(expr.op(), frame).to_expr()
return node.to_expr()


def contains_first_or_last_agg(exprs):
Expand Down Expand Up @@ -458,8 +434,7 @@ def __init__(self, parent, proj_exprs):

default_frame = ops.RowsWindowFrame(table=parent)
self.clean_exprs = [
windowize_function(expr, frame=default_frame)
for expr in self.resolved_exprs
windowize_function(expr, default_frame) for expr in self.resolved_exprs
]

def get_result(self):
Expand Down
35 changes: 12 additions & 23 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ibis.selectors import Selector
from ibis.expr.types.relations import bind_expr
import ibis.common.exceptions as com
from public import public

_function_types = tuple(
filter(
Expand Down Expand Up @@ -60,13 +61,11 @@ def _get_group_by_key(table, value):
yield value


# TODO(kszucs): make a builder class for this
@public
class GroupedTable:
"""An intermediate table expression to hold grouping information."""

def __init__(
self, table, by, having=None, order_by=None, window=None, **expressions
):
def __init__(self, table, by, having=None, order_by=None, **expressions):
self.table = table
self.by = list(
itertools.chain(
Expand All @@ -86,7 +85,6 @@ def __init__(

self._order_by = order_by or []
self._having = having or []
self._window = window

def __getitem__(self, args):
# Shortcut for projection with window functions
Expand Down Expand Up @@ -133,7 +131,6 @@ def having(self, expr: ir.BooleanScalar) -> GroupedTable:
self.by,
having=self._having + util.promote_list(expr),
order_by=self._order_by,
window=self._window,
)

def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
Expand All @@ -158,7 +155,6 @@ def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
self.by,
having=self._having,
order_by=self._order_by + util.promote_list(expr),
window=self._window,
)

def mutate(
Expand Down Expand Up @@ -250,33 +246,24 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table
default_frame = self._get_window()
default_frame = ops.RowsWindowFrame(
table=self.table,
group_by=bind_expr(self.table, self.by),
order_by=bind_expr(self.table, self._order_by),
)
return [
an.windowize_function(e2, frame=default_frame)
an.windowize_function(e2, default_frame)
for expr in exprs
for e1 in util.promote_list(expr)
for e2 in util.promote_list(table._ensure_expr(e1))
] + [
an.windowize_function(e, frame=default_frame).name(k)
an.windowize_function(e, default_frame).name(k)
for k, expr in kwexprs.items()
for e in util.promote_list(table._ensure_expr(expr))
]

projection = select

def _get_window(self):
if self._window is None:
return ops.RowsWindowFrame(
table=self.table,
group_by=self.by,
order_by=bind_expr(self.table, self._order_by),
)
else:
return self._window.copy(
groupy_by=bind_expr(self.table, self._window.group_by + self.by),
order_by=bind_expr(self.table, self._window.order_by + self._order_by),
)

def over(
self,
window=None,
Expand Down Expand Up @@ -347,6 +334,7 @@ def wrapper(self, *args, **kwargs):
return wrapper


@public
class GroupedArray:
def __init__(self, arr, parent):
self.arr = arr
Expand All @@ -361,6 +349,7 @@ def __init__(self, arr, parent):
group_concat = _group_agg_dispatch("group_concat")


@public
class GroupedNumbers(GroupedArray):
mean = _group_agg_dispatch("mean")
sum = _group_agg_dispatch("sum")
1 change: 1 addition & 0 deletions ibis/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_bug_duplicated_where(airlines, snapshot):
expr = t.group_by("dest").mutate(
dest_avg=t.arrdelay.mean(), dev=t.arrdelay - t.arrdelay.mean()
)

tmp1 = expr[expr.dev.notnull()]
tmp2 = tmp1.order_by(ibis.desc("dev"))
expr = tmp2.limit(10)
Expand Down

0 comments on commit 998bbaa

Please sign in to comment.