diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index acad1bc8930c..3728bae23d28 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import operator from collections import defaultdict from typing import TYPE_CHECKING @@ -644,72 +642,6 @@ def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]: ) -# TODO(kszucs): use substitute instead -@functools.singledispatch -def _rewrite_filter(op, **kwargs): - raise NotImplementedError(type(op)) - - -@_rewrite_filter.register(ops.Reduction) -def _rewrite_filter_reduction(op, name: str | None = None, **kwargs): - """Turn a reduction inside of a filter into an aggregate.""" - # TODO: what about reductions that reference a join that isn't visible at - # this level? Means we probably have the wrong design, but will have to - # revisit when it becomes a problem. - if name is not None: - op = ops.Alias(op, name=name) - - agg = op.to_expr().as_table() - return ops.TableArrayView(agg) - - -@_rewrite_filter.register(ops.Any) -@_rewrite_filter.register(ops.TableColumn) -@_rewrite_filter.register(ops.Literal) -@_rewrite_filter.register(ops.ExistsSubquery) -@_rewrite_filter.register(ops.WindowFunction) -def _rewrite_filter_subqueries(op, **kwargs): - """Don't rewrite any of these operations in filters.""" - return op - - -@_rewrite_filter.register(ops.Alias) -def _rewrite_filter_alias(op, name: str | None = None, **kwargs): - """Rewrite filters on aliases.""" - return _rewrite_filter( - op.arg, - name=name if name is not None else op.name, - **kwargs, - ) - - -@_rewrite_filter.register(ops.Value) -def _rewrite_filter_value(op, **kwargs): - """Recursively apply filter rewriting on operations.""" - - visited = [ - _rewrite_filter(arg, **kwargs) if isinstance(arg, ops.Node) else arg - for arg in op.args - ] - if all(map(operator.is_, visited, op.args)): - return op - - return op.__class__(*visited) - - -@_rewrite_filter.register(tuple) -def _rewrite_filter_value_list(op, **kwargs): - visited = [ - _rewrite_filter(arg, **kwargs) if isinstance(arg, ops.Node) else arg - for arg in op.args - ] - - if all(map(operator.is_, visited, op.args)): - return op - - return op.__class__(*visited) - - def find_toplevel_unnest_children(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]: def finder(node): return ( diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 85f5350b463e..0e7239832768 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2281,11 +2281,7 @@ def filter( import ibis.expr.analysis as an resolved_predicates = _resolve_predicates(self, predicates) - predicates = [ - an._rewrite_filter(pred.op() if isinstance(pred, Expr) else pred) - for pred in resolved_predicates - ] - return an.apply_filter(self.op(), predicates).to_expr() + return an.apply_filter(self.op(), resolved_predicates).to_expr() def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of unique rows in the table. @@ -4158,12 +4154,13 @@ def release(self): return current_backend._release_cached(self) +# TODO(kszucs): used at a single place along with an.apply_filter(), should be +# consolidated into a single function def _resolve_predicates( table: Table, predicates ) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]: import ibis.expr.types as ir - from ibis.expr.analysis import p, flatten_predicate - from ibis.common.deferred import _, Attr, Call + from ibis.expr.analysis import p, flatten_predicate, _ # TODO(kszucs): clean this up, too much flattening and resolving happens here predicates = [ @@ -4176,15 +4173,15 @@ def _resolve_predicates( ] predicates = flatten_predicate(predicates) - # _.resolve is actually a non-deferred method, so it won't dispatch to - # the matched UnresolvedExistsSubquery.resolve() method - # TODO(kszucs): remove Deferred.resolve() method - replacement = Call(Attr(_, "resolve"), table.op()) - resolved_predicates = [ - pred.replace(p.UnresolvedExistsSubquery >> replacement) for pred in predicates - ] - - return resolved_predicates + rules = ( + # turn reductions into table array views so that they can be used as + # WHERE t1.`a` = (SELECT max(t1.`a`) AS `Max(a)` + p.Reduction >> (lambda _: ops.TableArrayView(_.to_expr().as_table())) + | + # resolve exists subqueries + p.UnresolvedExistsSubquery >> (lambda _: _.resolve(table.op())) + ) + return [pred.replace(rules, filter=p.Value) for pred in predicates] def bind_expr(table, expr): diff --git a/ibis/tests/sql/snapshots/test_compiler/test_agg_filter/out.sql b/ibis/tests/sql/snapshots/test_compiler/test_agg_filter/out.sql index 6953da9d317e..dccd920d55cb 100644 --- a/ibis/tests/sql/snapshots/test_compiler/test_agg_filter/out.sql +++ b/ibis/tests/sql/snapshots/test_compiler/test_agg_filter/out.sql @@ -10,6 +10,6 @@ t1 AS ( SELECT t1.* FROM t1 WHERE t1.`a` = ( - SELECT max(t1.`a`) AS `blah` + SELECT max(t1.`a`) AS `Max(a)` FROM t1 ) \ No newline at end of file diff --git a/ibis/tests/sql/snapshots/test_compiler/test_agg_filter_with_alias/out.sql b/ibis/tests/sql/snapshots/test_compiler/test_agg_filter_with_alias/out.sql index 6953da9d317e..dccd920d55cb 100644 --- a/ibis/tests/sql/snapshots/test_compiler/test_agg_filter_with_alias/out.sql +++ b/ibis/tests/sql/snapshots/test_compiler/test_agg_filter_with_alias/out.sql @@ -10,6 +10,6 @@ t1 AS ( SELECT t1.* FROM t1 WHERE t1.`a` = ( - SELECT max(t1.`a`) AS `blah` + SELECT max(t1.`a`) AS `Max(a)` FROM t1 ) \ No newline at end of file diff --git a/ibis/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py b/ibis/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py index 9b9df516bb17..a18437fa4d10 100644 --- a/ibis/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py +++ b/ibis/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py @@ -13,7 +13,7 @@ param = ibis.param("timestamp") proj = alltypes.select( [alltypes.float_col, alltypes.timestamp_col, alltypes.int_col, alltypes.string_col] -).filter(alltypes.timestamp_col < param) +).filter(alltypes.timestamp_col < param.name("my_param")) agg = proj.group_by(proj.string_col).aggregate(proj.float_col.sum().name("foo")) result = agg.foo.count()