diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index 64797cf8885e..bdcc55b6e9f5 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -22,6 +22,8 @@ PandasLimit, PandasResetIndex, PandasScalarSubquery, + PandasWindowFrame, + PandasWindowFunction, plan, ) from ibis.common.exceptions import UnboundExpressionError @@ -271,7 +273,7 @@ def agg(df, order_keys): ############################ Window functions ############################# @classmethod - def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs): + def visit(cls, op: PandasWindowFrame, table, start, end, **kwargs): table = table.compute() if isinstance(start, dd.Series): start = start.compute() @@ -280,7 +282,7 @@ def visit(cls, op: ops.WindowFrame, table, start, end, **kwargs): return super().visit(op, table=table, start=start, end=end, **kwargs) @classmethod - def visit(cls, op: ops.WindowFunction, func, frame): + def visit(cls, op: PandasWindowFunction, func, frame): result = super().visit(op, func=func, frame=frame) return cls.asseries(result) diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index d4d71b666d47..c2c89124a31d 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -46,14 +46,10 @@ # * Boolean expressions MUST be used in a WHERE clause, i.e., SELECT * FROM t WHERE 1 is not allowed -@replace( - p.WindowFunction( - p.Reduction & ~p.ReductionVectorizedUDF, frame=y @ p.WindowFrame(order_by=()) - ) -) -def rewrite_rows_range_order_by_window(_, y, **kwargs): +@replace(p.WindowFunction(p.Reduction & ~p.ReductionVectorizedUDF, order_by=())) +def rewrite_rows_range_order_by_window(_, **kwargs): # MSSQL requires an order by in a window frame that has either ROWS or RANGE - return _.copy(frame=y.copy(order_by=(_.func.arg,))) + return _.copy(order_by=(_.func.arg,)) @public diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py index 1f4ecf27b450..4645467ace92 100644 --- a/ibis/backends/oracle/compiler.py +++ b/ibis/backends/oracle/compiler.py @@ -368,7 +368,7 @@ def visit_DateTruncate(self, op, *, arg, unit): visit_TimestampTruncate = visit_DateTruncate - def visit_Window(self, op, *, how, func, start, end, group_by, order_by): + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): # Oracle has two (more?) types of analytic functions you can use inside OVER. # # The first group accepts an "analytic clause" which is decomposed into the diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 7a422ef4dbcb..a97aa8512813 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -24,6 +24,8 @@ PandasRename, PandasResetIndex, PandasScalarSubquery, + PandasWindowFrame, + PandasWindowFunction, plan, ) from ibis.common.dispatch import Dispatched @@ -468,9 +470,7 @@ def visit(cls, op: ops.WindowBoundary, value, preceding): return value @classmethod - def visit( - cls, op: ops.WindowFrame, table, start, end, group_by, order_by, **kwargs - ): + def visit(cls, op: PandasWindowFrame, table, how, start, end, group_by, order_by): if start is not None and op.start.preceding: start = -start if end is not None and op.end.preceding: @@ -494,19 +494,19 @@ def visit( if start is None and end is None: return frame - elif op.how == "rows": + elif how == "rows": return RowsFrame(parent=frame) - elif op.how == "range": + elif how == "range": if len(order_keys) != 1: raise NotImplementedError( "Only single column order by is supported for range window frames" ) return RangeFrame(parent=frame, order_key=order_keys[0]) else: - raise NotImplementedError(f"Unsupported window frame type: {op.how}") + raise NotImplementedError(f"Unsupported window frame type: {how}") @classmethod - def visit(cls, op: ops.WindowFunction, func, frame): + def visit(cls, op: PandasWindowFunction, func, frame): if isinstance(op.func, ops.Analytic): order_keys = [key.name for key in op.frame.order_by] return frame.apply_analytic(func, order_keys=order_keys) diff --git a/ibis/backends/pandas/rewrites.py b/ibis/backends/pandas/rewrites.py index 84cd1050a305..53f2b6a813e3 100644 --- a/ibis/backends/pandas/rewrites.py +++ b/ibis/backends/pandas/rewrites.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + from public import public import ibis @@ -8,9 +10,9 @@ import ibis.expr.operations as ops from ibis.common.annotations import attribute from ibis.common.collections import FrozenDict -from ibis.common.patterns import replace +from ibis.common.patterns import InstanceOf, replace from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.rewrites import replace_parameter, rewrite_stringslice +from ibis.expr.rewrites import p, replace_parameter, rewrite_stringslice from ibis.expr.schema import Schema from ibis.util import gen_name @@ -19,10 +21,6 @@ class PandasRelation(ops.Relation): pass -class PandasValue(ops.Value): - pass - - @public class PandasRename(PandasRelation): parent: ops.Relation @@ -114,7 +112,7 @@ def schema(self): @public -class PandasScalarSubquery(PandasValue): +class PandasScalarSubquery(ops.Value): # variant with no integrity checks rel: ops.Relation @@ -125,10 +123,43 @@ def dtype(self): return self.rel.schema.types[0] +@public +class PandasWindowFrame(ops.Node): + table: ops.Relation + how: str + start: Optional[ops.Value] + end: Optional[ops.Value] + group_by: VarTuple[ops.Column] + order_by: VarTuple[ops.SortKey] + + +@public +class PandasWindowFunction(ops.Value): + func: ops.Value + frame: PandasWindowFrame + + shape = ds.columnar + + @property + def dtype(self): + return self.func.dtype + + def is_columnar(node): return isinstance(node, ops.Value) and node.shape.is_columnar() +computable_column = p.Value(shape=ds.columnar) & ~InstanceOf( + ( + ops.Reduction, + ops.Analytic, + ops.SortKey, + ops.WindowFunction, + ops.WindowBoundary, + ) +) + + @replace(ops.Project) def rewrite_project(_, **kwargs): unnests = [] @@ -143,17 +174,9 @@ def rewrite_project(_, **kwargs): selects = {ops.Field(_.parent, k): k for k in _.parent.schema} for node in winfuncs: # add computed values from the window function - values = list(node.func.__args__) - # add computed values from the window frame - values += node.frame.group_by - values += [key.expr for key in node.frame.order_by] - if node.frame.start is not None: - values.append(node.frame.start.value) - if node.frame.end is not None: - values.append(node.frame.end.value) - - for v in values: - if is_columnar(v) and v not in selects: + columns = node.find(computable_column, filter=ops.Value) + for v in columns: + if v not in selects: selects[v] = gen_name("value") # STEP 1: construct the pre-projection @@ -163,15 +186,16 @@ def rewrite_project(_, **kwargs): # STEP 2: construct new window function nodes metrics = {} for node in winfuncs: - frame = node.frame - start = None if frame.start is None else frame.start.replace(subs) - end = None if frame.end is None else frame.end.replace(subs) - order_by = [key.replace(subs) for key in frame.order_by] - group_by = [key.replace(subs) for key in frame.group_by] - frame = frame.__class__( - proj, start=start, end=end, group_by=group_by, order_by=order_by + subbed = node.replace(subs, filter=ops.Value) + frame = PandasWindowFrame( + table=proj, + how=subbed.how, + start=subbed.start, + end=subbed.end, + group_by=subbed.group_by, + order_by=subbed.order_by, ) - metrics[node] = ops.WindowFunction(node.func.replace(subs), frame) + metrics[node] = PandasWindowFunction(subbed.func, frame) # STEP 3: reconstruct the current projection with the window functions subs.update(metrics) @@ -190,9 +214,9 @@ def rewrite_aggregate(_, **kwargs): reductions = {} for v in _.metrics.values(): - for reduction in v.find_topmost(ops.Reduction): - for arg in reduction.__args__: - if is_columnar(arg) and arg not in selects: + for reduction in v.find(ops.Reduction, filter=ops.Value): + for arg in reduction.find(computable_column, filter=ops.Value): + if arg not in selects: selects[arg] = gen_name("value") if reduction not in reductions: reductions[reduction] = gen_name("reduction") diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index ce8a262123de..c9fa54b87bc8 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -405,7 +405,7 @@ def visit_JSONGetItem(self, op, *, arg, index): path = self.f.format_string(fmt, index) return self.f.get_json_object(arg, path) - def visit_Window(self, op, *, func, group_by, order_by, **kwargs): + def visit_WindowFunction(self, op, *, func, group_by, order_by, **kwargs): if isinstance(op.func, ops.Analytic) and not isinstance( op.func, (FirstValue, LastValue) ): @@ -417,7 +417,7 @@ def visit_Window(self, op, *, func, group_by, order_by, **kwargs): order = sge.Order(expressions=[NULL]) return sge.Window(this=func, partition_by=group_by, order=order) else: - return super().visit_Window( + return super().visit_WindowFunction( op, func=func, group_by=group_by, order_by=order_by, **kwargs ) diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index 2362096c1d24..00353ce42235 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -480,7 +480,7 @@ def visit_Xor(self, op, *, left, right): # boolxor accepts numerics ... and returns a boolean? wtf? return self.f.boolxor(self.cast(left, dt.int8), self.cast(right, dt.int8)) - def visit_Window(self, op, *, how, func, start, end, group_by, order_by): + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): if start is None: start = {} if end is None: diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 443ac1811324..b231c1c3e76f 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -177,14 +177,7 @@ class SQLGlotCompiler(abc.ABC): """A sequence of rewrites to apply to the expression tree before compilation.""" extra_supported_ops: frozenset = frozenset( - ( - ops.Project, - ops.Filter, - ops.Sort, - ops.WindowFunction, - ops.RowsWindowFrame, - ops.RangeWindowFrame, - ) + (ops.Project, ops.Filter, ops.Sort, ops.WindowFunction) ) """A frozenset of ops classes that are supported, but don't have explicit `visit_*` methods (usually due to being handled by rewrite rules). Used by @@ -975,7 +968,7 @@ def visit_WindowBoundary(self, op, *, value, preceding): # that corresponds to _only_ this information return {"value": value, "side": "preceding" if preceding else "following"} - def visit_Window(self, op, *, how, func, start, end, group_by, order_by): + def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): if start is None: start = {} if end is None: diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 8a082923968e..914059075106 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -4,13 +4,12 @@ import operator from functools import reduce -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any import toolz from public import public import ibis.common.exceptions as com -import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.annotations import attribute @@ -84,24 +83,6 @@ def dtype(self): return self.arg.dtype -@public -class Window(ops.Value): - """Window modelled after SQL's window statements.""" - - how: Literal["rows", "range"] - func: ops.Reduction | ops.Analytic - start: Optional[ops.WindowBoundary] = None - end: Optional[ops.WindowBoundary] = None - group_by: VarTuple[ops.Column] = () - order_by: VarTuple[ops.SortKey] = () - - shape = ds.columnar - - @attribute - def dtype(self): - return self.func.dtype - - # TODO(kszucs): there is a better strategy to rewrite the relational operations # to Select nodes by wrapping the leaf nodes in a Select node and then merging # Project, Filter, Sort, etc. incrementally into the Select node. This way we @@ -126,30 +107,16 @@ def sort_to_select(_, **kwargs): return Select(_.parent, selections=_.values, sort_keys=_.keys) -@replace(p.WindowFunction) -def window_function_to_window(_, **kwargs): - """Convert a WindowFunction node to a Window node. - - Also rewrites first -> first_value, last -> last_value. - """ - func = _.func - if isinstance(func, (ops.First, ops.Last)): - if func.where is not None: - raise com.UnsupportedOperationError( - f"`{type(func).__name__.lower()}` with `where` is unsupported " - "in a window function" - ) - cls = FirstValue if isinstance(func, ops.First) else LastValue - func = cls(func.arg) - - return Window( - how=_.frame.how, - func=func, - start=_.frame.start, - end=_.frame.end, - group_by=_.frame.group_by, - order_by=_.frame.order_by, - ) +@replace(p.WindowFunction(p.First | p.Last)) +def first_to_firstvalue(_, **kwargs): + """Convert a First or Last node to a FirstValue or LastValue node.""" + if _.func.where is not None: + raise com.UnsupportedOperationError( + f"`{type(_.func).__name__.lower()}` with `where` is unsupported " + "in a window function" + ) + klass = FirstValue if isinstance(_.func, ops.First) else LastValue + return _.copy(func=klass(_.func.arg)) @replace(Object(Select, Object(Select))) @@ -162,10 +129,10 @@ def merge_select_select(_, **kwargs): """ # don't merge if either the outer or the inner select has window functions for v in _.selections.values(): - if v.find(Window, filter=ops.Value): + if v.find(ops.WindowFunction, filter=ops.Value): return _ for v in _.parent.selections.values(): - if v.find((Window, ops.Unnest), filter=ops.Value): + if v.find((ops.WindowFunction, ops.Unnest), filter=ops.Value): return _ for v in _.predicates: if v.find((ops.ExistsSubquery, ops.InSubquery), filter=ops.Value): @@ -240,7 +207,7 @@ def sqlize( | project_to_select | filter_to_select | sort_to_select - | window_function_to_window, + | first_to_firstvalue, context=context, ) @@ -268,10 +235,12 @@ def wrap(node, _, **kwargs): """Add an ORDER BY clause to rank window functions that don't have one.""" -add_order_by_to_empty_ranking_window_functions = p.WindowFunction( - func=p.NTile(y), - frame=p.WindowFrame(order_by=()) >> _.copy(order_by=(y,)), -) + + +@replace(p.WindowFunction(func=p.NTile(y), order_by=())) +def add_order_by_to_empty_ranking_window_functions(_, **kwargs): + return _.copy(order_by=(y,)) + """Replace checks against an empty right side with `False`.""" empty_in_values_right_side = p.InValues(options=()) >> d.Literal(False, dtype=dt.bool) @@ -322,28 +291,27 @@ def rewrite_sample_as_filter(_, **kwargs): return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),)) -@replace(p.WindowFunction(frame=y @ p.WindowFrame(order_by=()))) -def rewrite_empty_order_by_window(_, y, **kwargs): - return _.copy(frame=y.copy(order_by=(ops.NULL,))) +@replace(p.WindowFunction(order_by=())) +def rewrite_empty_order_by_window(_, **kwargs): + return _.copy(order_by=(ops.NULL,)) -@replace(p.WindowFunction(p.RowNumber | p.NTile, y)) -def exclude_unsupported_window_frame_from_row_number(_, y): - return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1) +@replace(p.WindowFunction(p.RowNumber | p.NTile)) +def exclude_unsupported_window_frame_from_row_number(_, **kwargs): + return ops.Subtract(_.copy(start=None, end=0), 1) -@replace(p.WindowFunction(p.MinRank | p.DenseRank, y @ p.WindowFrame(start=None))) -def exclude_unsupported_window_frame_from_rank(_, y): +@replace(p.WindowFunction(p.MinRank | p.DenseRank, start=None)) +def exclude_unsupported_window_frame_from_rank(_, **kwargs): return ops.Subtract( - _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))), 1 + _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)), 1 ) @replace( p.WindowFunction( - p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, - y @ p.WindowFrame(start=None), + p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, start=None ) ) -def exclude_unsupported_window_frame_from_ops(_, y, **kwargs): - return _.copy(frame=y.copy(start=None, end=0, order_by=y.order_by or (ops.NULL,))) +def exclude_unsupported_window_frame_from_ops(_, **kwargs): + return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)) diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index b164f877bdf2..b98459927021 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -146,7 +146,7 @@ class WindowBuilder(Builder): start: Optional[RangeWindowBoundary] = None end: Optional[RangeWindowBoundary] = None groupings: VarTuple[Union[str, Resolver, ops.Value]] = () - orderings: VarTuple[Union[str, Resolver, ops.Value]] = () + orderings: VarTuple[Union[str, Resolver, ops.SortKey]] = () @attribute def _table(self): @@ -225,42 +225,18 @@ def group_by(self, expr) -> Self: def order_by(self, expr) -> Self: return self.copy(orderings=self.orderings + util.promote_tuple(expr)) - @annotated - def bind(self, table: Optional[ops.Relation]): - table = table or self._table - if table is None: - raise IbisInputError("Unable to bind window frame to a table") + def bind(self, table): + from ibis.expr.types.relations import bind - table = table.to_expr() - - def bind_value(value): - if isinstance(value, str): - return table._get_column(value) - elif isinstance(value, Resolver): - return value.resolve({"_": table}) + if table is None: + if self._table is None: + raise IbisInputError("Cannot bind window frame without a table") else: - return value - - groupings = map(bind_value, self.groupings) - orderings = map(bind_value, self.orderings) - if self.how == "rows": - return ops.RowsWindowFrame( - table=table, - start=self.start, - end=self.end, - group_by=groupings, - order_by=orderings, - ) - elif self.how == "range": - return ops.RangeWindowFrame( - table=table, - start=self.start, - end=self.end, - group_by=groupings, - order_by=orderings, - ) - else: - raise ValueError(f"Unsupported `{self.how}` window type") + table = self._table.to_expr() + + grouping = bind(table, self.groupings) + orderings = bind(table, self.orderings) + return self.copy(groupings=grouping, orderings=orderings) class LegacyWindowBuilder(WindowBuilder): diff --git a/ibis/expr/operations/window.py b/ibis/expr/operations/window.py index 724f4335007a..aed700d058e8 100644 --- a/ibis/expr/operations/window.py +++ b/ibis/expr/operations/window.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import abstractmethod +from typing import Literal as LiteralType from typing import Optional from public import public @@ -17,7 +17,6 @@ from ibis.expr.operations.generic import Literal from ibis.expr.operations.numeric import Negate from ibis.expr.operations.reductions import Reduction # noqa: TCH001 -from ibis.expr.operations.relations import Relation # noqa: TCH001 from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001 T = TypeVar("T", bound=dt.Numeric | dt.Interval, covariant=True) @@ -61,66 +60,44 @@ def __coerce__(cls, value, **kwargs): @public -class WindowFrame(Value): - """A window frame operation bound to a table.""" - - table: Relation +class WindowFunction(Value): + func: Analytic | Reduction + how: LiteralType["rows", "range"] = "rows" + start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None + end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None group_by: VarTuple[Column] = () order_by: VarTuple[SortKey] = () + dtype = rlz.dtype_like("func") shape = ds.columnar - def __init__(self, start, end, **kwargs): - if start and end: - if not ( - (start.dtype.is_interval() and end.dtype.is_interval()) - or (start.dtype.is_numeric() and end.dtype.is_numeric()) + def __init__(self, how, start, end, **kwargs): + if how == "rows": + if start and not start.dtype.is_integer(): + raise com.IbisTypeError( + "Row-based window frame start boundary must be an integer" + ) + if end and not end.dtype.is_integer(): + raise com.IbisTypeError( + "Row-based window frame end boundary must be an integer" + ) + elif how == "range": + if ( + start + and end + and not ( + (start.dtype.is_interval() and end.dtype.is_interval()) + or (start.dtype.is_numeric() and end.dtype.is_numeric()) + ) ): raise com.IbisTypeError( "Window frame start and end boundaries must have the same datatype" ) - super().__init__(start=start, end=end, **kwargs) - - def dtype(self) -> dt.DataType: - return dt.Array(dt.Struct.from_tuples(self.table.schema.items())) - - @property - @abstractmethod - def start(self): ... - - @property - @abstractmethod - def end(self): ... - - -@public -class RowsWindowFrame(WindowFrame): - how = "rows" - start: Optional[WindowBoundary[dt.Integer]] = None - end: Optional[WindowBoundary] = None - - -@public -class RangeWindowFrame(WindowFrame): - how = "range" - start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None - end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None - - -@public -class WindowFunction(Value): - func: Analytic | Reduction - frame: WindowFrame - - dtype = rlz.dtype_like("func") - shape = ds.columnar - - def __init__(self, func, frame): - if func.relations and frame.table not in func.relations: - raise com.RelationError( - "The reduction has different parent relation than the window" + else: + raise com.IbisTypeError( + f"Window frame type must be either 'rows' or 'range', got {how}" ) - super().__init__(func=func, frame=frame) + super().__init__(how=how, start=start, end=end, **kwargs) @property def name(self): diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 25e9b9c9a3b0..5b03b024d89e 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -105,7 +105,7 @@ def rewrite_stringslice(_, **kwargs): @replace(p.Analytic) def project_wrap_analytic(_, rel): # Wrap analytic functions in a window function - return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) + return ops.WindowFunction(_) @replace(p.Reduction) @@ -114,7 +114,7 @@ def project_wrap_reduction(_, rel): if _.relations == {rel}: # The reduction is fully originating from the `rel`, so turn # it into a window function of `rel` - return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) + return ops.WindowFunction(_) else: # 1. The reduction doesn't depend on any table, constructed from # scalar values, so turn it into a scalar subquery. @@ -156,37 +156,47 @@ def rewrite_filter_input(value): @replace(p.Analytic | p.Reduction) -def window_wrap_reduction(_, frame): +def window_wrap_reduction(_, window): # Wrap analytic and reduction functions in a window function. Used in the # value.over() API. - return ops.WindowFunction(_, frame) + return ops.WindowFunction( + _, + how=window.how, + start=window.start, + end=window.end, + group_by=window.groupings, + order_by=window.orderings, + ) @replace(p.WindowFunction) -def window_merge_frames(_, frame): +def window_merge_frames(_, window): # Merge window frames, used in the value.over() and groupby.select() APIs. - if _.frame.start and frame.start and _.frame.start != frame.start: + if _.how != window.how: + raise ExpressionError( + f"Unable to merge {_.how} window with {window.how} window" + ) + elif _.start and window.start and _.start != window.start: raise ExpressionError( "Unable to merge windows with conflicting `start` boundary" ) - if _.frame.end and frame.end and _.frame.end != frame.end: + elif _.end and window.end and _.end != window.end: raise ExpressionError("Unable to merge windows with conflicting `end` boundary") - start = _.frame.start or frame.start - end = _.frame.end or frame.end - group_by = tuple(toolz.unique(_.frame.group_by + frame.group_by)) + start = _.start or window.start + end = _.end or window.end + group_by = tuple(toolz.unique(_.group_by + window.groupings)) - order_by = {} - for sort_key in frame.order_by + _.frame.order_by: - order_by[sort_key.expr] = sort_key.ascending - order_by = tuple(ops.SortKey(k, v) for k, v in order_by.items()) + order_keys = {} + for sort_key in window.orderings + _.order_by: + order_keys[sort_key.expr] = sort_key.ascending + order_by = (ops.SortKey(k, v) for k, v in order_keys.items()) - frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by) - return ops.WindowFunction(_.func, frame) + return _.copy(start=start, end=end, group_by=group_by, order_by=order_by) -def rewrite_window_input(value, frame): - context = {"frame": frame} +def rewrite_window_input(value, window): + context = {"window": window} # if self is a reduction or analytic function, wrap it in a window function node = value.replace( window_wrap_reduction, diff --git a/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt index a062b06c486d..b79ee8160531 100644 --- a/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_window_group_by/repr.txt @@ -2,4 +2,4 @@ r0 := UnboundTable: t a int64 b string -Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, group_by=[r0.b])) \ No newline at end of file +Mean(a): WindowFunction(func=Mean(r0.a), how='rows', group_by=[r0.b]) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt index b46b19f10b7a..acd8865e7735 100644 --- a/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_window_no_group_by/repr.txt @@ -2,4 +2,4 @@ r0 := UnboundTable: t a int64 b string -Mean(a): WindowFunction(func=Mean(r0.a), frame=RowsWindowFrame(table=r0, start=WindowBoundary(value=0, preceding=True))) \ No newline at end of file +Mean(a): WindowFunction(func=Mean(r0.a), how='rows', start=WindowBoundary(value=0, preceding=True)) \ No newline at end of file diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 7da5f63c98b6..6f359fa2c9a1 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -723,7 +723,7 @@ def over( A window function expression """ - node = self.op() + if window is None: window = ibis.window( rows=rows, @@ -734,17 +734,18 @@ def over( elif not isinstance(window, bl.WindowBuilder): raise com.IbisTypeError("Unexpected window type: {window!r}") + node = self.op() if len(node.relations) == 0: table = None elif len(node.relations) == 1: (table,) = node.relations + table = table.to_expr() else: raise com.RelationError("Cannot use window with multiple tables") @deferrable def bind(table): - frame = window.bind(table) - winfunc = rewrite_window_input(node, frame) + winfunc = rewrite_window_input(node, window.bind(table)) if winfunc == node: raise com.IbisTypeError( "No reduction or analytic function found to construct a window expression" diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 8f3fba1bab60..32bccaa20d7e 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -201,13 +201,9 @@ def _selectables(self, *exprs, **kwexprs): [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ table = self.table.to_expr() - frame = ops.RowsWindowFrame( - table=self.table, - group_by=self.groupings, - order_by=self.orderings, - ) values = bind(table, (exprs, kwexprs)) - return [rewrite_window_input(expr.op(), frame).to_expr() for expr in values] + window = ibis.window(group_by=self.groupings, order_by=self.orderings) + return [rewrite_window_input(expr.op(), window).to_expr() for expr in values] projection = select diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 716d38d4f956..3bd2090eb833 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -21,7 +21,7 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch from ibis import util -from ibis.common.deferred import Deferred +from ibis.common.deferred import Deferred, Resolver from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import ValueExpr, literal from ibis.selectors import Selector @@ -106,6 +106,8 @@ def bind(table: Table, value: Any) -> Iterator[ir.Value]: yield value._get_column(name) elif isinstance(value, Deferred): yield value.resolve(table) + elif isinstance(value, Resolver): + yield value.resolve({"_": table}) elif isinstance(value, Selector): yield from value.expand(table) elif isinstance(value, Mapping): diff --git a/ibis/expr/visualize.py b/ibis/expr/visualize.py index 2dd8faf367d2..7b75dcbb3d8b 100644 --- a/ibis/expr/visualize.py +++ b/ibis/expr/visualize.py @@ -60,23 +60,17 @@ def get_label(node): ops.Field, ops.Alias, ops.PhysicalTable, - ops.window.RangeWindowFrame, ), ) else None ) if nodename is not None: - # [TODO] Don't show nodename because it's too long and ruins the image - if isinstance(node, ops.window.RangeWindowFrame): - label_fmt = "<{}>" - label = label_fmt.format(escape(name)) + if isinstance(node, ops.Relation): + label_fmt = "<{}: {}{}>" else: - if isinstance(node, ops.Relation): - label_fmt = "<{}: {}{}>" - else: - label_fmt = '<{}: {}
:: {}>' - # typename is already escaped - label = label_fmt.format(escape(nodename), escape(name), typename) + label_fmt = '<{}: {}
:: {}>' + # typename is already escaped + label = label_fmt.format(escape(nodename), escape(name), typename) else: if isinstance(node, ops.Relation): label_fmt = "<{}{}>" diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt index 3514fa501a73..cc2dc4dddffe 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt @@ -36,8 +36,8 @@ r1 := Project[r0] r2 := Project[r1] arrdelay: r1.arrdelay dest: r1.dest - dest_avg: WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) - dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) + dest_avg: WindowFunction(func=Mean(r1.arrdelay), how='rows', group_by=[r1.dest]) + dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), how='rows', group_by=[r1.dest]) r3 := Filter[r2] NotNull(r2.dev) diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py index 967d28b20844..f44b739e1cc8 100644 --- a/ibis/tests/expr/test_window_frames.py +++ b/ibis/tests/expr/test_window_frames.py @@ -14,6 +14,11 @@ from ibis.common.patterns import NoMatch, Pattern +@pytest.fixture +def t(alltypes): + return alltypes + + def test_window_boundary(): # the boundary value must be either numeric or interval b = ops.WindowBoundary(5, preceding=False) @@ -214,39 +219,44 @@ def test_window_builder_between(): assert w9.how == "range" -def test_window_api_supports_value_expressions(alltypes): - t = alltypes - +def test_window_api_supports_value_expressions(t): w = ibis.window(between=(t.d, t.d + 1), group_by=t.b, order_by=t.c) - assert w.bind(t) == ops.RowsWindowFrame( - table=t, + func = t.a.sum().over(w).op() + expected = ops.WindowFunction( + func=t.a.sum(), + how="rows", start=ops.WindowBoundary(t.d, preceding=False), end=ops.WindowBoundary(t.d + 1, preceding=False), group_by=(t.b,), order_by=(t.c,), ) + assert func == expected -def test_window_api_supports_scalar_order_by(alltypes): - t = alltypes - - w = ibis.window(order_by=ibis.NA) - assert w.bind(t) == ops.RowsWindowFrame( - table=t, +def test_window_api_supports_scalar_order_by(t): + window = ibis.window(order_by=ibis.NA) + expr = t.a.sum().over(window).op() + expected = ops.WindowFunction( + t.a.sum(), + how="rows", start=None, end=None, group_by=(), order_by=(ibis.NA.op(),), ) + assert expr == expected - w = ibis.window(order_by=ibis.random()) - assert w.bind(t) == ops.RowsWindowFrame( - table=t, + window = ibis.window(order_by=ibis.random()) + expr = t.a.sum().over(window).op() + expected = ops.WindowFunction( + t.a.sum(), + how="rows", start=None, end=None, group_by=(), - order_by=(ibis.random().op(),), + order_by=[ibis.random()], ) + assert expr == expected def test_window_api_properly_determines_how(): @@ -281,8 +291,7 @@ def test_window_api_mutually_exclusive_options(): ibis.window(range=(None, 5), between=(None, 5)) -def test_window_builder_methods(alltypes): - t = alltypes +def test_window_builder_methods(t): w1 = ibis.window(preceding=5, following=1, group_by=t.a, order_by=t.b) w2 = w1.group_by(t.c) @@ -327,8 +336,21 @@ def test_window_api_preceding_following(method, is_preceding): def test_window_api_trailing_range(): t = ibis.table([("col", "int64")], name="t") - w = ibis.trailing_range_window(ibis.interval(days=1), order_by="col") - w.bind(t) + start = ibis.interval(days=1) + end = ibis.literal(0).cast(start.type()) + + window = ibis.trailing_range_window(start, order_by="col") + expr = t.col.sum().over(window).op() + + expected = ops.WindowFunction( + t.col.sum(), + how="range", + start=ops.WindowBoundary(start, preceding=True), + end=ops.WindowBoundary(end, preceding=False), + group_by=(), + order_by=[t.col], + ) + assert expr == expected @pytest.mark.parametrize( @@ -405,34 +427,39 @@ def test_window_api_preceding_following_invalid_tuple(kind, begin, end): ibis.window(**kwargs) -def test_window_bind_to_table(alltypes): - t = alltypes +def test_window_bind_to_table(t): spec = ibis.window(group_by="g", order_by=ibis.desc("f")) - frame = spec.bind(t) - expected = ops.RowsWindowFrame(table=t, group_by=[t.g], order_by=[t.f.desc()]) - - assert frame == expected + window = t.a.sum().over(spec).op() + expected = ops.WindowFunction( + t.a.sum(), + how="rows", + start=None, + end=None, + group_by=[t.g], + order_by=[t.f.desc()], + ) + assert window == expected -def test_window_bind_value_expression_using_over(alltypes): +def test_window_bind_value_expression_using_over(t): # GH #542 - t = alltypes - w = ibis.window(group_by="g", order_by="f") - expr = t.f.lag().over(w) - - frame = expr.op().frame - expected = ops.RowsWindowFrame(table=t, group_by=[t.g], order_by=[t.f.asc()]) - - assert frame == expected + window = t.f.lag().over(w).op() + expected = ops.WindowFunction( + t.f.lag(), + how="rows", + start=None, + end=None, + group_by=[t.g], + order_by=[t.f.asc()], + ) + assert window == expected -def test_window_analysis_propagate_nested_windows(alltypes): +def test_window_analysis_propagate_nested_windows(t): # GH #469 - t = alltypes - w = ibis.window(group_by=t.g, order_by=t.f) col = (t.f - t.f.lag()).lag() @@ -442,8 +469,7 @@ def test_window_analysis_propagate_nested_windows(alltypes): assert result.equals(expected) -def test_window_analysis_combine_group_by(alltypes): - t = alltypes +def test_window_analysis_combine_group_by(t): w = ibis.window(group_by=t.g, order_by=t.f) diff = t.d - t.d.lag() @@ -468,7 +494,7 @@ def test_window_analysis_combine_preserves_existing_window(): ) w = ibis.cumulative_window(order_by=t.one) mut = t.group_by(t.three).mutate(four=t.two.sum().over(w)) - assert mut.op().values["four"].frame.start is None + assert mut.op().values["four"].start is None def test_window_analysis_auto_windowize_bug(): @@ -496,20 +522,24 @@ def metric(x): assert enriched.equals(expected) -def test_windowization_wraps_reduction_inside_a_nested_value_expression(alltypes): - t = alltypes +def test_windowization_wraps_reduction_inside_a_nested_value_expression(t): win = ibis.window( following=0, group_by=[t.g], order_by=[t.a], ) expr = (t.f == 0).notany().over(win) - assert expr.op() == ops.Not( + expected = ops.Not( ops.WindowFunction( func=ops.Any(t.f == 0), - frame=ops.RowsWindowFrame(table=t, end=0, group_by=[t.g], order_by=[t.a]), + how="rows", + start=None, + end=0, + group_by=[t.g], + order_by=[t.a], ) ) + assert expr.op() == expected def test_group_by_with_window_function_preserves_range(alltypes): @@ -525,9 +555,11 @@ def test_group_by_with_window_function_preserves_range(alltypes): "three": t.three, "four": ops.WindowFunction( func=ops.Sum(t.two), - frame=ops.RowsWindowFrame( - table=t, end=0, group_by=[t.three], order_by=[t.one] - ), + how="rows", + start=None, + end=0, + group_by=[t.three], + order_by=[t.one], ), }, )