Skip to content

Commit

Permalink
refactor(ir): merge ops.WindowFrame node into ops.WindowFunction (#…
Browse files Browse the repository at this point in the history
…8779)

`WindowFrame` was only used by `WindowFunction`, adding an unnecessary
indirection to the window abstraction. While merging them is useful from
the `IR` perspective, backends like `pandas` must extract the frame part
so that the executor can reuse the window frame computation.
  • Loading branch information
kszucs authored Apr 4, 2024
1 parent bffaaa5 commit 3cd5a1a
Show file tree
Hide file tree
Showing 20 changed files with 270 additions and 299 deletions.
6 changes: 4 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
PandasLimit,
PandasResetIndex,
PandasScalarSubquery,
PandasWindowFrame,
PandasWindowFunction,
plan,
)
from ibis.common.exceptions import UnboundExpressionError
Expand Down Expand Up @@ -271,7 +273,7 @@ def agg(df, order_keys):
############################ Window functions #############################

@classmethod
def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs):
def visit(cls, op: PandasWindowFrame, table, start, end, **kwargs):
table = table.compute()
if isinstance(start, dd.Series):
start = start.compute()
Expand All @@ -280,7 +282,7 @@ def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs):
return super().visit(op, table=table, start=start, end=end, **kwargs)

@classmethod
def visit(cls, op: ops.WindowFunction, func, frame):
def visit(cls, op: PandasWindowFunction, func, frame):
result = super().visit(op, func=func, frame=frame)
return cls.asseries(result)

Expand Down
10 changes: 3 additions & 7 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,10 @@
# * Boolean expressions MUST be used in a WHERE clause, i.e., SELECT * FROM t WHERE 1 is not allowed


@replace(
p.WindowFunction(
p.Reduction & ~p.ReductionVectorizedUDF, frame=y @ p.WindowFrame(order_by=())
)
)
def rewrite_rows_range_order_by_window(_, y, **kwargs):
@replace(p.WindowFunction(p.Reduction & ~p.ReductionVectorizedUDF, order_by=()))
def rewrite_rows_range_order_by_window(_, **kwargs):
# MSSQL requires an order by in a window frame that has either ROWS or RANGE
return _.copy(frame=y.copy(order_by=(_.func.arg,)))
return _.copy(order_by=(_.func.arg,))


@public
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def visit_DateTruncate(self, op, *, arg, unit):

visit_TimestampTruncate = visit_DateTruncate

def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
# Oracle has two (more?) types of analytic functions you can use inside OVER.
#
# The first group accepts an "analytic clause" which is decomposed into the
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
PandasRename,
PandasResetIndex,
PandasScalarSubquery,
PandasWindowFrame,
PandasWindowFunction,
plan,
)
from ibis.common.dispatch import Dispatched
Expand Down Expand Up @@ -468,9 +470,7 @@ def visit(cls, op: ops.WindowBoundary, value, preceding):
return value

@classmethod
def visit(
cls, op: ops.WindowFrame, table, start, end, group_by, order_by, **kwargs
):
def visit(cls, op: PandasWindowFrame, table, how, start, end, group_by, order_by):
if start is not None and op.start.preceding:
start = -start
if end is not None and op.end.preceding:
Expand All @@ -494,19 +494,19 @@ def visit(

if start is None and end is None:
return frame
elif op.how == "rows":
elif how == "rows":
return RowsFrame(parent=frame)
elif op.how == "range":
elif how == "range":
if len(order_keys) != 1:
raise NotImplementedError(
"Only single column order by is supported for range window frames"
)
return RangeFrame(parent=frame, order_key=order_keys[0])
else:
raise NotImplementedError(f"Unsupported window frame type: {op.how}")
raise NotImplementedError(f"Unsupported window frame type: {how}")

@classmethod
def visit(cls, op: ops.WindowFunction, func, frame):
def visit(cls, op: PandasWindowFunction, func, frame):
if isinstance(op.func, ops.Analytic):
order_keys = [key.name for key in op.frame.order_by]
return frame.apply_analytic(func, order_keys=order_keys)
Expand Down
82 changes: 53 additions & 29 deletions ibis/backends/pandas/rewrites.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional

from public import public

import ibis
Expand All @@ -8,9 +10,9 @@
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict
from ibis.common.patterns import replace
from ibis.common.patterns import InstanceOf, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import replace_parameter, rewrite_stringslice
from ibis.expr.rewrites import p, replace_parameter, rewrite_stringslice
from ibis.expr.schema import Schema
from ibis.util import gen_name

Expand All @@ -19,10 +21,6 @@ class PandasRelation(ops.Relation):
pass


class PandasValue(ops.Value):
pass


@public
class PandasRename(PandasRelation):
parent: ops.Relation
Expand Down Expand Up @@ -114,7 +112,7 @@ def schema(self):


@public
class PandasScalarSubquery(PandasValue):
class PandasScalarSubquery(ops.Value):
# variant with no integrity checks
rel: ops.Relation

Expand All @@ -125,10 +123,43 @@ def dtype(self):
return self.rel.schema.types[0]


@public
class PandasWindowFrame(ops.Node):
    """Window frame node used by the pandas backend's rewritten plan.

    Holds the frame part of a window function (frame type, bounds, grouping
    and ordering) together with the relation it is evaluated over.
    `rewrite_project` builds it from the fields of a window function node so
    that the frame is a separate node from the function applied over it.
    """

    # relation the frame is evaluated over; `rewrite_project` passes its
    # pre-projection here
    table: ops.Relation
    # frame type; the pandas executor handles "rows" and "range" and raises
    # NotImplementedError for anything else when bounds are given
    how: str
    # frame bounds; may be None (presumably meaning "no bound on that side"
    # -- TODO confirm against the executor)
    start: Optional[ops.Value]
    end: Optional[ops.Value]
    # partitioning columns
    group_by: VarTuple[ops.Column]
    # ordering keys
    order_by: VarTuple[ops.SortKey]


@public
class PandasWindowFunction(ops.Value):
    """Window function node pairing a function with a `PandasWindowFrame`.

    `rewrite_project` creates these in place of the original window function
    nodes, after substituting the inputs with fields of the pre-projection.
    """

    # value computed over the frame
    func: ops.Value
    # frame the function is evaluated over
    frame: PandasWindowFrame

    # the node always has columnar shape
    shape = ds.columnar

    @property
    def dtype(self):
        # the result type is the type of the wrapped function
        return self.func.dtype


def is_columnar(node):
    """Tell whether `node` is an `ops.Value` whose shape is columnar."""
    # non-value nodes have no shape to inspect
    if not isinstance(node, ops.Value):
        return False
    return node.shape.is_columnar()


# Pattern matching columnar values, excluding the node types listed below.
# `rewrite_project` and `rewrite_aggregate` search with it to collect the
# values they register in `selects` under generated names.
computable_column = p.Value(shape=ds.columnar) & ~InstanceOf(
    (
        ops.Reduction,
        ops.Analytic,
        ops.SortKey,
        ops.WindowFunction,
        ops.WindowBoundary,
    )
)


@replace(ops.Project)
def rewrite_project(_, **kwargs):
unnests = []
Expand All @@ -143,17 +174,9 @@ def rewrite_project(_, **kwargs):
selects = {ops.Field(_.parent, k): k for k in _.parent.schema}
for node in winfuncs:
# add computed values from the window function
values = list(node.func.__args__)
# add computed values from the window frame
values += node.frame.group_by
values += [key.expr for key in node.frame.order_by]
if node.frame.start is not None:
values.append(node.frame.start.value)
if node.frame.end is not None:
values.append(node.frame.end.value)

for v in values:
if is_columnar(v) and v not in selects:
columns = node.find(computable_column, filter=ops.Value)
for v in columns:
if v not in selects:
selects[v] = gen_name("value")

# STEP 1: construct the pre-projection
Expand All @@ -163,15 +186,16 @@ def rewrite_project(_, **kwargs):
# STEP 2: construct new window function nodes
metrics = {}
for node in winfuncs:
frame = node.frame
start = None if frame.start is None else frame.start.replace(subs)
end = None if frame.end is None else frame.end.replace(subs)
order_by = [key.replace(subs) for key in frame.order_by]
group_by = [key.replace(subs) for key in frame.group_by]
frame = frame.__class__(
proj, start=start, end=end, group_by=group_by, order_by=order_by
subbed = node.replace(subs, filter=ops.Value)
frame = PandasWindowFrame(
table=proj,
how=subbed.how,
start=subbed.start,
end=subbed.end,
group_by=subbed.group_by,
order_by=subbed.order_by,
)
metrics[node] = ops.WindowFunction(node.func.replace(subs), frame)
metrics[node] = PandasWindowFunction(subbed.func, frame)

# STEP 3: reconstruct the current projection with the window functions
subs.update(metrics)
Expand All @@ -190,9 +214,9 @@ def rewrite_aggregate(_, **kwargs):

reductions = {}
for v in _.metrics.values():
for reduction in v.find_topmost(ops.Reduction):
for arg in reduction.__args__:
if is_columnar(arg) and arg not in selects:
for reduction in v.find(ops.Reduction, filter=ops.Value):
for arg in reduction.find(computable_column, filter=ops.Value):
if arg not in selects:
selects[arg] = gen_name("value")
if reduction not in reductions:
reductions[reduction] = gen_name("reduction")
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def visit_JSONGetItem(self, op, *, arg, index):
path = self.f.format_string(fmt, index)
return self.f.get_json_object(arg, path)

def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
def visit_WindowFunction(self, op, *, func, group_by, order_by, **kwargs):
if isinstance(op.func, ops.Analytic) and not isinstance(
op.func, (FirstValue, LastValue)
):
Expand All @@ -417,7 +417,7 @@ def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
order = sge.Order(expressions=[NULL])
return sge.Window(this=func, partition_by=group_by, order=order)
else:
return super().visit_Window(
return super().visit_WindowFunction(
op, func=func, group_by=group_by, order_by=order_by, **kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def visit_Xor(self, op, *, left, right):
# boolxor accepts numerics ... and returns a boolean? wtf?
return self.f.boolxor(self.cast(left, dt.int8), self.cast(right, dt.int8))

def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
if start is None:
start = {}
if end is None:
Expand Down
11 changes: 2 additions & 9 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,7 @@ class SQLGlotCompiler(abc.ABC):
"""A sequence of rewrites to apply to the expression tree before compilation."""

extra_supported_ops: frozenset = frozenset(
(
ops.Project,
ops.Filter,
ops.Sort,
ops.WindowFunction,
ops.RowsWindowFrame,
ops.RangeWindowFrame,
)
(ops.Project, ops.Filter, ops.Sort, ops.WindowFunction)
)
"""A frozenset of ops classes that are supported, but don't have explicit
`visit_*` methods (usually due to being handled by rewrite rules). Used by
Expand Down Expand Up @@ -975,7 +968,7 @@ def visit_WindowBoundary(self, op, *, value, preceding):
# that corresponds to _only_ this information
return {"value": value, "side": "preceding" if preceding else "following"}

def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
if start is None:
start = {}
if end is None:
Expand Down
Loading

0 comments on commit 3cd5a1a

Please sign in to comment.