diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py
index 64797cf8885e..bdcc55b6e9f5 100644
--- a/ibis/backends/dask/executor.py
+++ b/ibis/backends/dask/executor.py
@@ -22,6 +22,8 @@
PandasLimit,
PandasResetIndex,
PandasScalarSubquery,
+ PandasWindowFrame,
+ PandasWindowFunction,
plan,
)
from ibis.common.exceptions import UnboundExpressionError
@@ -271,7 +273,7 @@ def agg(df, order_keys):
############################ Window functions #############################
@classmethod
- def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs):
+ def visit(cls, op: PandasWindowFrame, table, start, end, **kwargs):
table = table.compute()
if isinstance(start, dd.Series):
start = start.compute()
@@ -280,7 +282,7 @@ def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs):
return super().visit(op, table=table, start=start, end=end, **kwargs)
@classmethod
- def visit(cls, op: ops.WindowFunction, func, frame):
+ def visit(cls, op: PandasWindowFunction, func, frame):
result = super().visit(op, func=func, frame=frame)
return cls.asseries(result)
diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py
index d4d71b666d47..c2c89124a31d 100644
--- a/ibis/backends/mssql/compiler.py
+++ b/ibis/backends/mssql/compiler.py
@@ -46,14 +46,10 @@
# * Boolean expressions MUST be used in a WHERE clause, i.e., SELECT * FROM t WHERE 1 is not allowed
-@replace(
- p.WindowFunction(
- p.Reduction & ~p.ReductionVectorizedUDF, frame=y @ p.WindowFrame(order_by=())
- )
-)
-def rewrite_rows_range_order_by_window(_, y, **kwargs):
+@replace(p.WindowFunction(p.Reduction & ~p.ReductionVectorizedUDF, order_by=()))
+def rewrite_rows_range_order_by_window(_, **kwargs):
# MSSQL requires an order by in a window frame that has either ROWS or RANGE
- return _.copy(frame=y.copy(order_by=(_.func.arg,)))
+ return _.copy(order_by=(_.func.arg,))
@public
diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py
index 1f4ecf27b450..4645467ace92 100644
--- a/ibis/backends/oracle/compiler.py
+++ b/ibis/backends/oracle/compiler.py
@@ -368,7 +368,7 @@ def visit_DateTruncate(self, op, *, arg, unit):
visit_TimestampTruncate = visit_DateTruncate
- def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
+ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
# Oracle has two (more?) types of analytic functions you can use inside OVER.
#
# The first group accepts an "analytic clause" which is decomposed into the
diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py
index 7a422ef4dbcb..a97aa8512813 100644
--- a/ibis/backends/pandas/executor.py
+++ b/ibis/backends/pandas/executor.py
@@ -24,6 +24,8 @@
PandasRename,
PandasResetIndex,
PandasScalarSubquery,
+ PandasWindowFrame,
+ PandasWindowFunction,
plan,
)
from ibis.common.dispatch import Dispatched
@@ -468,9 +470,7 @@ def visit(cls, op: ops.WindowBoundary, value, preceding):
return value
@classmethod
- def visit(
- cls, op: ops.WindowFrame, table, start, end, group_by, order_by, **kwargs
- ):
+ def visit(cls, op: PandasWindowFrame, table, how, start, end, group_by, order_by):
if start is not None and op.start.preceding:
start = -start
if end is not None and op.end.preceding:
@@ -494,19 +494,19 @@ def visit(
if start is None and end is None:
return frame
- elif op.how == "rows":
+ elif how == "rows":
return RowsFrame(parent=frame)
- elif op.how == "range":
+ elif how == "range":
if len(order_keys) != 1:
raise NotImplementedError(
"Only single column order by is supported for range window frames"
)
return RangeFrame(parent=frame, order_key=order_keys[0])
else:
- raise NotImplementedError(f"Unsupported window frame type: {op.how}")
+ raise NotImplementedError(f"Unsupported window frame type: {how}")
@classmethod
- def visit(cls, op: ops.WindowFunction, func, frame):
+ def visit(cls, op: PandasWindowFunction, func, frame):
if isinstance(op.func, ops.Analytic):
order_keys = [key.name for key in op.frame.order_by]
return frame.apply_analytic(func, order_keys=order_keys)
diff --git a/ibis/backends/pandas/rewrites.py b/ibis/backends/pandas/rewrites.py
index 84cd1050a305..53f2b6a813e3 100644
--- a/ibis/backends/pandas/rewrites.py
+++ b/ibis/backends/pandas/rewrites.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from typing import Optional
+
from public import public
import ibis
@@ -8,9 +10,9 @@
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict
-from ibis.common.patterns import replace
+from ibis.common.patterns import InstanceOf, replace
from ibis.common.typing import VarTuple # noqa: TCH001
-from ibis.expr.rewrites import replace_parameter, rewrite_stringslice
+from ibis.expr.rewrites import p, replace_parameter, rewrite_stringslice
from ibis.expr.schema import Schema
from ibis.util import gen_name
@@ -19,10 +21,6 @@ class PandasRelation(ops.Relation):
pass
-class PandasValue(ops.Value):
- pass
-
-
@public
class PandasRename(PandasRelation):
parent: ops.Relation
@@ -114,7 +112,7 @@ def schema(self):
@public
-class PandasScalarSubquery(PandasValue):
+class PandasScalarSubquery(ops.Value):
# variant with no integrity checks
rel: ops.Relation
@@ -125,10 +123,43 @@ def dtype(self):
return self.rel.schema.types[0]
+@public
+class PandasWindowFrame(ops.Node):
+ table: ops.Relation
+ how: str
+ start: Optional[ops.Value]
+ end: Optional[ops.Value]
+ group_by: VarTuple[ops.Column]
+ order_by: VarTuple[ops.SortKey]
+
+
+@public
+class PandasWindowFunction(ops.Value):
+ func: ops.Value
+ frame: PandasWindowFrame
+
+ shape = ds.columnar
+
+ @property
+ def dtype(self):
+ return self.func.dtype
+
+
def is_columnar(node):
return isinstance(node, ops.Value) and node.shape.is_columnar()
+computable_column = p.Value(shape=ds.columnar) & ~InstanceOf(
+ (
+ ops.Reduction,
+ ops.Analytic,
+ ops.SortKey,
+ ops.WindowFunction,
+ ops.WindowBoundary,
+ )
+)
+
+
@replace(ops.Project)
def rewrite_project(_, **kwargs):
unnests = []
@@ -143,17 +174,9 @@ def rewrite_project(_, **kwargs):
selects = {ops.Field(_.parent, k): k for k in _.parent.schema}
for node in winfuncs:
# add computed values from the window function
- values = list(node.func.__args__)
- # add computed values from the window frame
- values += node.frame.group_by
- values += [key.expr for key in node.frame.order_by]
- if node.frame.start is not None:
- values.append(node.frame.start.value)
- if node.frame.end is not None:
- values.append(node.frame.end.value)
-
- for v in values:
- if is_columnar(v) and v not in selects:
+ columns = node.find(computable_column, filter=ops.Value)
+ for v in columns:
+ if v not in selects:
selects[v] = gen_name("value")
# STEP 1: construct the pre-projection
@@ -163,15 +186,16 @@ def rewrite_project(_, **kwargs):
# STEP 2: construct new window function nodes
metrics = {}
for node in winfuncs:
- frame = node.frame
- start = None if frame.start is None else frame.start.replace(subs)
- end = None if frame.end is None else frame.end.replace(subs)
- order_by = [key.replace(subs) for key in frame.order_by]
- group_by = [key.replace(subs) for key in frame.group_by]
- frame = frame.__class__(
- proj, start=start, end=end, group_by=group_by, order_by=order_by
+ subbed = node.replace(subs, filter=ops.Value)
+ frame = PandasWindowFrame(
+ table=proj,
+ how=subbed.how,
+ start=subbed.start,
+ end=subbed.end,
+ group_by=subbed.group_by,
+ order_by=subbed.order_by,
)
- metrics[node] = ops.WindowFunction(node.func.replace(subs), frame)
+ metrics[node] = PandasWindowFunction(subbed.func, frame)
# STEP 3: reconstruct the current projection with the window functions
subs.update(metrics)
@@ -190,9 +214,9 @@ def rewrite_aggregate(_, **kwargs):
reductions = {}
for v in _.metrics.values():
- for reduction in v.find_topmost(ops.Reduction):
- for arg in reduction.__args__:
- if is_columnar(arg) and arg not in selects:
+ for reduction in v.find(ops.Reduction, filter=ops.Value):
+ for arg in reduction.find(computable_column, filter=ops.Value):
+ if arg not in selects:
selects[arg] = gen_name("value")
if reduction not in reductions:
reductions[reduction] = gen_name("reduction")
diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index ce8a262123de..c9fa54b87bc8 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -405,7 +405,7 @@ def visit_JSONGetItem(self, op, *, arg, index):
path = self.f.format_string(fmt, index)
return self.f.get_json_object(arg, path)
- def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
+ def visit_WindowFunction(self, op, *, func, group_by, order_by, **kwargs):
if isinstance(op.func, ops.Analytic) and not isinstance(
op.func, (FirstValue, LastValue)
):
@@ -417,7 +417,7 @@ def visit_Window(self, op, *, func, group_by, order_by, **kwargs):
order = sge.Order(expressions=[NULL])
return sge.Window(this=func, partition_by=group_by, order=order)
else:
- return super().visit_Window(
+ return super().visit_WindowFunction(
op, func=func, group_by=group_by, order_by=order_by, **kwargs
)
diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py
index 2362096c1d24..00353ce42235 100644
--- a/ibis/backends/snowflake/compiler.py
+++ b/ibis/backends/snowflake/compiler.py
@@ -480,7 +480,7 @@ def visit_Xor(self, op, *, left, right):
# boolxor accepts numerics ... and returns a boolean? wtf?
return self.f.boolxor(self.cast(left, dt.int8), self.cast(right, dt.int8))
- def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
+ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
if start is None:
start = {}
if end is None:
diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py
index 443ac1811324..b231c1c3e76f 100644
--- a/ibis/backends/sql/compiler.py
+++ b/ibis/backends/sql/compiler.py
@@ -177,14 +177,7 @@ class SQLGlotCompiler(abc.ABC):
"""A sequence of rewrites to apply to the expression tree before compilation."""
extra_supported_ops: frozenset = frozenset(
- (
- ops.Project,
- ops.Filter,
- ops.Sort,
- ops.WindowFunction,
- ops.RowsWindowFrame,
- ops.RangeWindowFrame,
- )
+ (ops.Project, ops.Filter, ops.Sort, ops.WindowFunction)
)
"""A frozenset of ops classes that are supported, but don't have explicit
`visit_*` methods (usually due to being handled by rewrite rules). Used by
@@ -975,7 +968,7 @@ def visit_WindowBoundary(self, op, *, value, preceding):
# that corresponds to _only_ this information
return {"value": value, "side": "preceding" if preceding else "following"}
- def visit_Window(self, op, *, how, func, start, end, group_by, order_by):
+ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by):
if start is None:
start = {}
if end is None:
diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index 8a082923968e..914059075106 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -4,13 +4,12 @@
import operator
from functools import reduce
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any
import toolz
from public import public
import ibis.common.exceptions as com
-import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
@@ -84,24 +83,6 @@ def dtype(self):
return self.arg.dtype
-@public
-class Window(ops.Value):
- """Window modelled after SQL's window statements."""
-
- how: Literal["rows", "range"]
- func: ops.Reduction | ops.Analytic
- start: Optional[ops.WindowBoundary] = None
- end: Optional[ops.WindowBoundary] = None
- group_by: VarTuple[ops.Column] = ()
- order_by: VarTuple[ops.SortKey] = ()
-
- shape = ds.columnar
-
- @attribute
- def dtype(self):
- return self.func.dtype
-
-
# TODO(kszucs): there is a better strategy to rewrite the relational operations
# to Select nodes by wrapping the leaf nodes in a Select node and then merging
# Project, Filter, Sort, etc. incrementally into the Select node. This way we
@@ -126,30 +107,16 @@ def sort_to_select(_, **kwargs):
return Select(_.parent, selections=_.values, sort_keys=_.keys)
-@replace(p.WindowFunction)
-def window_function_to_window(_, **kwargs):
- """Convert a WindowFunction node to a Window node.
-
- Also rewrites first -> first_value, last -> last_value.
- """
- func = _.func
- if isinstance(func, (ops.First, ops.Last)):
- if func.where is not None:
- raise com.UnsupportedOperationError(
- f"`{type(func).__name__.lower()}` with `where` is unsupported "
- "in a window function"
- )
- cls = FirstValue if isinstance(func, ops.First) else LastValue
- func = cls(func.arg)
-
- return Window(
- how=_.frame.how,
- func=func,
- start=_.frame.start,
- end=_.frame.end,
- group_by=_.frame.group_by,
- order_by=_.frame.order_by,
- )
+@replace(p.WindowFunction(p.First | p.Last))
+def first_to_firstvalue(_, **kwargs):
+ """Convert a First or Last node to a FirstValue or LastValue node."""
+ if _.func.where is not None:
+ raise com.UnsupportedOperationError(
+ f"`{type(_.func).__name__.lower()}` with `where` is unsupported "
+ "in a window function"
+ )
+ klass = FirstValue if isinstance(_.func, ops.First) else LastValue
+ return _.copy(func=klass(_.func.arg))
@replace(Object(Select, Object(Select)))
@@ -162,10 +129,10 @@ def merge_select_select(_, **kwargs):
"""
# don't merge if either the outer or the inner select has window functions
for v in _.selections.values():
- if v.find(Window, filter=ops.Value):
+ if v.find(ops.WindowFunction, filter=ops.Value):
return _
for v in _.parent.selections.values():
- if v.find((Window, ops.Unnest), filter=ops.Value):
+ if v.find((ops.WindowFunction, ops.Unnest), filter=ops.Value):
return _
for v in _.predicates:
if v.find((ops.ExistsSubquery, ops.InSubquery), filter=ops.Value):
@@ -240,7 +207,7 @@ def sqlize(
| project_to_select
| filter_to_select
| sort_to_select
- | window_function_to_window,
+ | first_to_firstvalue,
context=context,
)
@@ -268,10 +235,12 @@ def wrap(node, _, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
-add_order_by_to_empty_ranking_window_functions = p.WindowFunction(
- func=p.NTile(y),
- frame=p.WindowFrame(order_by=()) >> _.copy(order_by=(y,)),
-)
+
+
+@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
+def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
+ return _.copy(order_by=(y,))
+
"""Replace checks against an empty right side with `False`."""
empty_in_values_right_side = p.InValues(options=()) >> d.Literal(False, dtype=dt.bool)
@@ -322,28 +291,27 @@ def rewrite_sample_as_filter(_, **kwargs):
return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))
-@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=())))
-def rewrite_empty_order_by_window(_, y, **kwargs):
- return _.copy(frame=y.copy(order_by=(ops.NULL,)))
+@replace(p.WindowFunction(order_by=()))
+def rewrite_empty_order_by_window(_, **kwargs):
+ return _.copy(order_by=(ops.NULL,))
-@replace(p.WindowFunction(p.RowNumber | p.NTile, y))
-def exclude_unsupported_window_frame_from_row_number(_, y):
- return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1)
+@replace(p.WindowFunction(p.RowNumber | p.NTile))
+def exclude_unsupported_window_frame_from_row_number(_, **kwargs):
+ return ops.Subtract(_.copy(start=None, end=0), 1)
-@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None)))
-def exclude_unsupported_window_frame_from_rank(_, y):
+@replace(p.WindowFunction(p.MinRank | p.DenseRank, start=None))
+def exclude_unsupported_window_frame_from_rank(_, **kwargs):
return ops.Subtract(
- _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1
+ _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)), 1
)
@replace(
p.WindowFunction(
- p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All,
- y @ p.WindowFrame(start=None),
+ p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, start=None
)
)
-def exclude_unsupported_window_frame_from_ops(_, y, **kwargs):
- return _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,)))
+def exclude_unsupported_window_frame_from_ops(_, **kwargs):
+ return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,))
diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py
index b164f877bdf2..b98459927021 100644
--- a/ibis/expr/builders.py
+++ b/ibis/expr/builders.py
@@ -146,7 +146,7 @@ class WindowBuilder(Builder):
start: Optional[RangeWindowBoundary] = None
end: Optional[RangeWindowBoundary] = None
groupings: VarTuple[Union[str, Resolver, ops.Value]] = ()
- orderings: VarTuple[Union[str, Resolver, ops.Value]] = ()
+ orderings: VarTuple[Union[str, Resolver, ops.SortKey]] = ()
@attribute
def _table(self):
@@ -225,42 +225,18 @@ def group_by(self, expr) -> Self:
def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))
- @annotated
- def bind(self, table: Optional[ops.Relation]):
- table = table or self._table
- if table is None:
- raise IbisInputError("Unable to bind window frame to a table")
+ def bind(self, table):
+ from ibis.expr.types.relations import bind
- table = table.to_expr()
-
- def bind_value(value):
- if isinstance(value, str):
- return table._get_column(value)
- elif isinstance(value, Resolver):
- return value.resolve({"_": table})
+ if table is None:
+ if self._table is None:
+ raise IbisInputError("Cannot bind window frame without a table")
else:
- return value
-
- groupings = map(bind_value, self.groupings)
- orderings = map(bind_value, self.orderings)
- if self.how == "rows":
- return ops.RowsWindowFrame(
- table=table,
- start=self.start,
- end=self.end,
- group_by=groupings,
- order_by=orderings,
- )
- elif self.how == "range":
- return ops.RangeWindowFrame(
- table=table,
- start=self.start,
- end=self.end,
- group_by=groupings,
- order_by=orderings,
- )
- else:
- raise ValueError(f"Unsupported `{self.how}` window type")
+ table = self._table.to_expr()
+
+ grouping = bind(table, self.groupings)
+ orderings = bind(table, self.orderings)
+ return self.copy(groupings=grouping, orderings=orderings)
class LegacyWindowBuilder(WindowBuilder):
diff --git a/ibis/expr/operations/window.py b/ibis/expr/operations/window.py
index 724f4335007a..aed700d058e8 100644
--- a/ibis/expr/operations/window.py
+++ b/ibis/expr/operations/window.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from abc import abstractmethod
+from typing import Literal as LiteralType
from typing import Optional
from public import public
@@ -17,7 +17,6 @@
from ibis.expr.operations.generic import Literal
from ibis.expr.operations.numeric import Negate
from ibis.expr.operations.reductions import Reduction # noqa: TCH001
-from ibis.expr.operations.relations import Relation # noqa: TCH001
from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001
T = TypeVar("T", bound=dt.Numeric | dt.Interval, covariant=True)
@@ -61,66 +60,44 @@ def __coerce__(cls, value, **kwargs):
@public
-class WindowFrame(Value):
- """A window frame operation bound to a table."""
-
- table: Relation
+class WindowFunction(Value):
+ func: Analytic | Reduction
+ how: LiteralType["rows", "range"] = "rows"
+ start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None
+ end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None
group_by: VarTuple[Column] = ()
order_by: VarTuple[SortKey] = ()
+ dtype = rlz.dtype_like("func")
shape = ds.columnar
- def __init__(self, start, end, **kwargs):
- if start and end:
- if not (
- (start.dtype.is_interval() and end.dtype.is_interval())
- or (start.dtype.is_numeric() and end.dtype.is_numeric())
+ def __init__(self, how, start, end, **kwargs):
+ if how == "rows":
+ if start and not start.dtype.is_integer():
+ raise com.IbisTypeError(
+ "Row-based window frame start boundary must be an integer"
+ )
+ if end and not end.dtype.is_integer():
+ raise com.IbisTypeError(
+ "Row-based window frame end boundary must be an integer"
+ )
+ elif how == "range":
+ if (
+ start
+ and end
+ and not (
+ (start.dtype.is_interval() and end.dtype.is_interval())
+ or (start.dtype.is_numeric() and end.dtype.is_numeric())
+ )
):
raise com.IbisTypeError(
"Window frame start and end boundaries must have the same datatype"
)
- super().__init__(start=start, end=end, **kwargs)
-
- def dtype(self) -> dt.DataType:
- return dt.Array(dt.Struct.from_tuples(self.table.schema.items()))
-
- @property
- @abstractmethod
- def start(self): ...
-
- @property
- @abstractmethod
- def end(self): ...
-
-
-@public
-class RowsWindowFrame(WindowFrame):
- how = "rows"
- start: Optional[WindowBoundary[dt.Integer]] = None
- end: Optional[WindowBoundary] = None
-
-
-@public
-class RangeWindowFrame(WindowFrame):
- how = "range"
- start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None
- end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None
-
-
-@public
-class WindowFunction(Value):
- func: Analytic | Reduction
- frame: WindowFrame
-
- dtype = rlz.dtype_like("func")
- shape = ds.columnar
-
- def __init__(self, func, frame):
- if func.relations and frame.table not in func.relations:
- raise com.RelationError(
- "The reduction has different parent relation than the window"
+ else:
+ raise com.IbisTypeError(
+ f"Window frame type must be either 'rows' or 'range', got {how}"
)
- super().__init__(func=func, frame=frame)
+ super().__init__(how=how, start=start, end=end, **kwargs)
@property
def name(self):
diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py
index 25e9b9c9a3b0..5b03b024d89e 100644
--- a/ibis/expr/rewrites.py
+++ b/ibis/expr/rewrites.py
@@ -105,7 +105,7 @@ def rewrite_stringslice(_, **kwargs):
@replace(p.Analytic)
def project_wrap_analytic(_, rel):
# Wrap analytic functions in a window function
- return ops.WindowFunction(_, ops.RowsWindowFrame(rel))
+ return ops.WindowFunction(_)
@replace(p.Reduction)
@@ -114,7 +114,7 @@ def project_wrap_reduction(_, rel):
if _.relations == {rel}:
# The reduction is fully originating from the `rel`, so turn
# it into a window function of `rel`
- return ops.WindowFunction(_, ops.RowsWindowFrame(rel))
+ return ops.WindowFunction(_)
else:
# 1. The reduction doesn't depend on any table, constructed from
# scalar values, so turn it into a scalar subquery.
@@ -156,37 +156,47 @@ def rewrite_filter_input(value):
@replace(p.Analytic | p.Reduction)
-def window_wrap_reduction(_, frame):
+def window_wrap_reduction(_, window):
# Wrap analytic and reduction functions in a window function. Used in the
# value.over() API.
- return ops.WindowFunction(_, frame)
+ return ops.WindowFunction(
+ _,
+ how=window.how,
+ start=window.start,
+ end=window.end,
+ group_by=window.groupings,
+ order_by=window.orderings,
+ )
@replace(p.WindowFunction)
-def window_merge_frames(_, frame):
+def window_merge_frames(_, window):
# Merge window frames, used in the value.over() and groupby.select() APIs.
- if _.frame.start and frame.start and _.frame.start != frame.start:
+ if _.how != window.how:
+ raise ExpressionError(
+ f"Unable to merge {_.how} window with {window.how} window"
+ )
+ elif _.start and window.start and _.start != window.start:
raise ExpressionError(
"Unable to merge windows with conflicting `start` boundary"
)
- if _.frame.end and frame.end and _.frame.end != frame.end:
+ elif _.end and window.end and _.end != window.end:
raise ExpressionError("Unable to merge windows with conflicting `end` boundary")
- start = _.frame.start or frame.start
- end = _.frame.end or frame.end
- group_by = tuple(toolz.unique(_.frame.group_by + frame.group_by))
+ start = _.start or window.start
+ end = _.end or window.end
+ group_by = tuple(toolz.unique(_.group_by + window.groupings))
- order_by = {}
- for sort_key in frame.order_by + _.frame.order_by:
- order_by[sort_key.expr] = sort_key.ascending
- order_by = tuple(ops.SortKey(k, v) for k, v in order_by.items())
+ order_keys = {}
+ for sort_key in window.orderings + _.order_by:
+ order_keys[sort_key.expr] = sort_key.ascending
+ order_by = (ops.SortKey(k, v) for k, v in order_keys.items())
- frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by)
- return ops.WindowFunction(_.func, frame)
+ return _.copy(start=start, end=end, group_by=group_by, order_by=order_by)
-def rewrite_window_input(value, frame):
- context = {"frame": frame}
+def rewrite_window_input(value, window):
+ context = {"window": window}
# if self is a reduction or analytic function, wrap it in a window function
node = value.replace(
window_wrap_reduction,
diff --git a/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt
index a062b06c486d..b79ee8160531 100644
--- a/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt
+++ b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt
@@ -2,4 +2,4 @@ r0 := UnboundTable: t
a int64
b string
-Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, group_by=[r0.b]))
\ No newline at end of file
+Mean(a): WindowFunction(func=Mean(r0.a), how='rows', group_by=[r0.b])
\ No newline at end of file
diff --git a/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt
index b46b19f10b7a..acd8865e7735 100644
--- a/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt
+++ b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt
@@ -2,4 +2,4 @@ r0 := UnboundTable: t
a int64
b string
-Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, start=WindowBoundary(value=0, preceding=True)))
\ No newline at end of file
+Mean(a): WindowFunction(func=Mean(r0.a), how='rows', start=WindowBoundary(value=0, preceding=True))
\ No newline at end of file
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py
index 7da5f63c98b6..6f359fa2c9a1 100644
--- a/ibis/expr/types/generic.py
+++ b/ibis/expr/types/generic.py
@@ -723,7 +723,7 @@ def over(
A window function expression
"""
- node = self.op()
+
if window is None:
window = ibis.window(
rows=rows,
@@ -734,17 +734,18 @@ def over(
elif not isinstance(window, bl.WindowBuilder):
raise com.IbisTypeError("Unexpected window type: {window!r}")
+ node = self.op()
if len(node.relations) == 0:
table = None
elif len(node.relations) == 1:
(table,) = node.relations
+ table = table.to_expr()
else:
raise com.RelationError("Cannot use window with multiple tables")
@deferrable
def bind(table):
- frame = window.bind(table)
- winfunc = rewrite_window_input(node, frame)
+ winfunc = rewrite_window_input(node, window.bind(table))
if winfunc == node:
raise com.IbisTypeError(
"No reduction or analytic function found to construct a window expression"
diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py
index 8f3fba1bab60..32bccaa20d7e 100644
--- a/ibis/expr/types/groupby.py
+++ b/ibis/expr/types/groupby.py
@@ -201,13 +201,9 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
- frame = ops.RowsWindowFrame(
- table=self.table,
- group_by=self.groupings,
- order_by=self.orderings,
- )
values = bind(table, (exprs, kwexprs))
- return [rewrite_window_input(expr.op(), frame).to_expr() for expr in values]
+ window = ibis.window(group_by=self.groupings, order_by=self.orderings)
+ return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]
projection = select
diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py
index 716d38d4f956..3bd2090eb833 100644
--- a/ibis/expr/types/relations.py
+++ b/ibis/expr/types/relations.py
@@ -21,7 +21,7 @@
import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis import util
-from ibis.common.deferred import Deferred
+from ibis.common.deferred import Deferred, Resolver
from ibis.expr.types.core import Expr, _FixedTextJupyterMixin
from ibis.expr.types.generic import ValueExpr, literal
from ibis.selectors import Selector
@@ -106,6 +106,8 @@ def bind(table: Table, value: Any) -> Iterator[ir.Value]:
yield value._get_column(name)
elif isinstance(value, Deferred):
yield value.resolve(table)
+ elif isinstance(value, Resolver):
+ yield value.resolve({"_": table})
elif isinstance(value, Selector):
yield from value.expand(table)
elif isinstance(value, Mapping):
diff --git a/ibis/expr/visualize.py b/ibis/expr/visualize.py
index 2dd8faf367d2..7b75dcbb3d8b 100644
--- a/ibis/expr/visualize.py
+++ b/ibis/expr/visualize.py
@@ -60,23 +60,17 @@ def get_label(node):
ops.Field,
ops.Alias,
ops.PhysicalTable,
- ops.window.RangeWindowFrame,
),
)
else None
)
if nodename is not None:
- # [TODO] Don't show nodename because it's too long and ruins the image
- if isinstance(node, ops.window.RangeWindowFrame):
- label_fmt = "<{}>"
- label = label_fmt.format(escape(name))
+ if isinstance(node, ops.Relation):
+ label_fmt = "<{}: {}{}>"
else:
- if isinstance(node, ops.Relation):
- label_fmt = "<{}: {}{}>"
- else:
- label_fmt = '<{}: {}
:: {}>'
- # typename is already escaped
- label = label_fmt.format(escape(nodename), escape(name), typename)
+ label_fmt = '<{}: {}
:: {}>'
+ # typename is already escaped
+ label = label_fmt.format(escape(nodename), escape(name), typename)
else:
if isinstance(node, ops.Relation):
label_fmt = "<{}{}>"
diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt
index 3514fa501a73..cc2dc4dddffe 100644
--- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt
+++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt
@@ -36,8 +36,8 @@ r1 := Project[r0]
r2 := Project[r1]
arrdelay: r1.arrdelay
dest: r1.dest
- dest_avg: WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest]))
- dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest]))
+ dest_avg: WindowFunction(func=Mean(r1.arrdelay), how='rows', group_by=[r1.dest])
+ dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), how='rows', group_by=[r1.dest])
r3 := Filter[r2]
NotNull(r2.dev)
diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py
index 967d28b20844..f44b739e1cc8 100644
--- a/ibis/tests/expr/test_window_frames.py
+++ b/ibis/tests/expr/test_window_frames.py
@@ -14,6 +14,11 @@
from ibis.common.patterns import NoMatch, Pattern
+@pytest.fixture
+def t(alltypes):
+ return alltypes
+
+
def test_window_boundary():
# the boundary value must be either numeric or interval
b = ops.WindowBoundary(5, preceding=False)
@@ -214,39 +219,44 @@ def test_window_builder_between():
assert w9.how == "range"
-def test_window_api_supports_value_expressions(alltypes):
- t = alltypes
-
+def test_window_api_supports_value_expressions(t):
w = ibis.window(between=(t.d, t.d + 1), group_by=t.b, order_by=t.c)
- assert w.bind(t) == ops.RowsWindowFrame(
- table=t,
+ func = t.a.sum().over(w).op()
+ expected = ops.WindowFunction(
+ func=t.a.sum(),
+ how="rows",
start=ops.WindowBoundary(t.d, preceding=False),
end=ops.WindowBoundary(t.d + 1, preceding=False),
group_by=(t.b,),
order_by=(t.c,),
)
+ assert func == expected
-def test_window_api_supports_scalar_order_by(alltypes):
- t = alltypes
-
- w = ibis.window(order_by=ibis.NA)
- assert w.bind(t) == ops.RowsWindowFrame(
- table=t,
+def test_window_api_supports_scalar_order_by(t):
+ window = ibis.window(order_by=ibis.NA)
+ expr = t.a.sum().over(window).op()
+ expected = ops.WindowFunction(
+ t.a.sum(),
+ how="rows",
start=None,
end=None,
group_by=(),
order_by=(ibis.NA.op(),),
)
+ assert expr == expected
- w = ibis.window(order_by=ibis.random())
- assert w.bind(t) == ops.RowsWindowFrame(
- table=t,
+ window = ibis.window(order_by=ibis.random())
+ expr = t.a.sum().over(window).op()
+ expected = ops.WindowFunction(
+ t.a.sum(),
+ how="rows",
start=None,
end=None,
group_by=(),
- order_by=(ibis.random().op(),),
+ order_by=[ibis.random()],
)
+ assert expr == expected
def test_window_api_properly_determines_how():
@@ -281,8 +291,7 @@ def test_window_api_mutually_exclusive_options():
ibis.window(range=(None, 5), between=(None, 5))
-def test_window_builder_methods(alltypes):
- t = alltypes
+def test_window_builder_methods(t):
w1 = ibis.window(preceding=5, following=1, group_by=t.a, order_by=t.b)
w2 = w1.group_by(t.c)
@@ -327,8 +336,21 @@ def test_window_api_preceding_following(method, is_preceding):
def test_window_api_trailing_range():
t = ibis.table([("col", "int64")], name="t")
- w = ibis.trailing_range_window(ibis.interval(days=1), order_by="col")
- w.bind(t)
+ start = ibis.interval(days=1)
+ end = ibis.literal(0).cast(start.type())
+
+ window = ibis.trailing_range_window(start, order_by="col")
+ expr = t.col.sum().over(window).op()
+
+ expected = ops.WindowFunction(
+ t.col.sum(),
+ how="range",
+ start=ops.WindowBoundary(start, preceding=True),
+ end=ops.WindowBoundary(end, preceding=False),
+ group_by=(),
+ order_by=[t.col],
+ )
+ assert expr == expected
@pytest.mark.parametrize(
@@ -405,34 +427,39 @@ def test_window_api_preceding_following_invalid_tuple(kind, begin, end):
ibis.window(**kwargs)
-def test_window_bind_to_table(alltypes):
- t = alltypes
+def test_window_bind_to_table(t):
spec = ibis.window(group_by="g", order_by=ibis.desc("f"))
- frame = spec.bind(t)
- expected = ops.RowsWindowFrame(table=t, group_by=[t.g], order_by=[t.f.desc()])
-
- assert frame == expected
+ window = t.a.sum().over(spec).op()
+ expected = ops.WindowFunction(
+ t.a.sum(),
+ how="rows",
+ start=None,
+ end=None,
+ group_by=[t.g],
+ order_by=[t.f.desc()],
+ )
+ assert window == expected
-def test_window_bind_value_expression_using_over(alltypes):
+def test_window_bind_value_expression_using_over(t):
# GH #542
- t = alltypes
-
w = ibis.window(group_by="g", order_by="f")
- expr = t.f.lag().over(w)
-
- frame = expr.op().frame
- expected = ops.RowsWindowFrame(table=t, group_by=[t.g], order_by=[t.f.asc()])
-
- assert frame == expected
+ window = t.f.lag().over(w).op()
+ expected = ops.WindowFunction(
+ t.f.lag(),
+ how="rows",
+ start=None,
+ end=None,
+ group_by=[t.g],
+ order_by=[t.f.asc()],
+ )
+ assert window == expected
-def test_window_analysis_propagate_nested_windows(alltypes):
+def test_window_analysis_propagate_nested_windows(t):
# GH #469
- t = alltypes
-
w = ibis.window(group_by=t.g, order_by=t.f)
col = (t.f - t.f.lag()).lag()
@@ -442,8 +469,7 @@ def test_window_analysis_propagate_nested_windows(alltypes):
assert result.equals(expected)
-def test_window_analysis_combine_group_by(alltypes):
- t = alltypes
+def test_window_analysis_combine_group_by(t):
w = ibis.window(group_by=t.g, order_by=t.f)
diff = t.d - t.d.lag()
@@ -468,7 +494,7 @@ def test_window_analysis_combine_preserves_existing_window():
)
w = ibis.cumulative_window(order_by=t.one)
mut = t.group_by(t.three).mutate(four=t.two.sum().over(w))
- assert mut.op().values["four"].frame.start is None
+ assert mut.op().values["four"].start is None
def test_window_analysis_auto_windowize_bug():
@@ -496,20 +522,24 @@ def metric(x):
assert enriched.equals(expected)
-def test_windowization_wraps_reduction_inside_a_nested_value_expression(alltypes):
- t = alltypes
+def test_windowization_wraps_reduction_inside_a_nested_value_expression(t):
win = ibis.window(
following=0,
group_by=[t.g],
order_by=[t.a],
)
expr = (t.f == 0).notany().over(win)
- assert expr.op() == ops.Not(
+ expected = ops.Not(
ops.WindowFunction(
func=ops.Any(t.f == 0),
- frame=ops.RowsWindowFrame(table=t, end=0, group_by=[t.g], order_by=[t.a]),
+ how="rows",
+ start=None,
+ end=0,
+ group_by=[t.g],
+ order_by=[t.a],
)
)
+ assert expr.op() == expected
def test_group_by_with_window_function_preserves_range(alltypes):
@@ -525,9 +555,11 @@ def test_group_by_with_window_function_preserves_range(alltypes):
"three": t.three,
"four": ops.WindowFunction(
func=ops.Sum(t.two),
- frame=ops.RowsWindowFrame(
- table=t, end=0, group_by=[t.three], order_by=[t.one]
- ),
+ how="rows",
+ start=None,
+ end=0,
+ group_by=[t.three],
+ order_by=[t.one],
),
},
)