Skip to content

Commit

Permalink
refactor(analysis): remove _rewrite_filter() in favor of using repl…
Browse files Browse the repository at this point in the history
…acement patterns
  • Loading branch information
kszucs committed Oct 16, 2023
1 parent e966af8 commit c396cd2
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 87 deletions.
68 changes: 0 additions & 68 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import functools
import operator
from collections import defaultdict
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -644,72 +642,6 @@ def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]:
)


# TODO(kszucs): use substitute instead
@functools.singledispatch
def _rewrite_filter(op, **kwargs):
raise NotImplementedError(type(op))


@_rewrite_filter.register(ops.Reduction)
def _rewrite_filter_reduction(op, name: str | None = None, **kwargs):
"""Turn a reduction inside of a filter into an aggregate."""
# TODO: what about reductions that reference a join that isn't visible at
# this level? Means we probably have the wrong design, but will have to
# revisit when it becomes a problem.
if name is not None:
op = ops.Alias(op, name=name)

agg = op.to_expr().as_table()
return ops.TableArrayView(agg)


@_rewrite_filter.register(ops.Any)
@_rewrite_filter.register(ops.TableColumn)
@_rewrite_filter.register(ops.Literal)
@_rewrite_filter.register(ops.ExistsSubquery)
@_rewrite_filter.register(ops.WindowFunction)
def _rewrite_filter_subqueries(op, **kwargs):
"""Don't rewrite any of these operations in filters."""
return op


@_rewrite_filter.register(ops.Alias)
def _rewrite_filter_alias(op, name: str | None = None, **kwargs):
"""Rewrite filters on aliases."""
return _rewrite_filter(
op.arg,
name=name if name is not None else op.name,
**kwargs,
)


@_rewrite_filter.register(ops.Value)
def _rewrite_filter_value(op, **kwargs):
"""Recursively apply filter rewriting on operations."""

visited = [
_rewrite_filter(arg, **kwargs) if isinstance(arg, ops.Node) else arg
for arg in op.args
]
if all(map(operator.is_, visited, op.args)):
return op

return op.__class__(*visited)


@_rewrite_filter.register(tuple)
def _rewrite_filter_value_list(op, **kwargs):
visited = [
_rewrite_filter(arg, **kwargs) if isinstance(arg, ops.Node) else arg
for arg in op.args
]

if all(map(operator.is_, visited, op.args)):
return op

return op.__class__(*visited)


def find_toplevel_unnest_children(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]:
def finder(node):
return (
Expand Down
29 changes: 13 additions & 16 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,11 +2281,7 @@ def filter(
import ibis.expr.analysis as an

resolved_predicates = _resolve_predicates(self, predicates)
predicates = [
an._rewrite_filter(pred.op() if isinstance(pred, Expr) else pred)
for pred in resolved_predicates
]
return an.apply_filter(self.op(), predicates).to_expr()
return an.apply_filter(self.op(), resolved_predicates).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of unique rows in the table.
Expand Down Expand Up @@ -4158,12 +4154,13 @@ def release(self):
return current_backend._release_cached(self)


# TODO(kszucs): used at a single place along with an.apply_filter(), should be
# consolidated into a single function
def _resolve_predicates(
table: Table, predicates
) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]:
import ibis.expr.types as ir
from ibis.expr.analysis import p, flatten_predicate
from ibis.common.deferred import _, Attr, Call
from ibis.expr.analysis import p, flatten_predicate, _

# TODO(kszucs): clean this up, too much flattening and resolving happens here
predicates = [
Expand All @@ -4176,15 +4173,15 @@ def _resolve_predicates(
]
predicates = flatten_predicate(predicates)

# _.resolve is actually a non-deferred method, so it won't dispatch to
# the matched UnresolvedExistsSubquery.resolve() method
# TODO(kszucs): remove Deferred.resolve() method
replacement = Call(Attr(_, "resolve"), table.op())
resolved_predicates = [
pred.replace(p.UnresolvedExistsSubquery >> replacement) for pred in predicates
]

return resolved_predicates
rules = (
# turn reductions into table array views so that they can be used as
# WHERE t1.`a` = (SELECT max(t1.`a`) AS `Max(a)`
p.Reduction >> (lambda _: ops.TableArrayView(_.to_expr().as_table()))
|
# resolve exists subqueries
p.UnresolvedExistsSubquery >> (lambda _: _.resolve(table.op()))
)
return [pred.replace(rules, filter=p.Value) for pred in predicates]


def bind_expr(table, expr):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ t1 AS (
SELECT t1.*
FROM t1
WHERE t1.`a` = (
SELECT max(t1.`a`) AS `blah`
SELECT max(t1.`a`) AS `Max(a)`
FROM t1
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ t1 AS (
SELECT t1.*
FROM t1
WHERE t1.`a` = (
SELECT max(t1.`a`) AS `blah`
SELECT max(t1.`a`) AS `Max(a)`
FROM t1
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
param = ibis.param("timestamp")
proj = alltypes.select(
[alltypes.float_col, alltypes.timestamp_col, alltypes.int_col, alltypes.string_col]
).filter(alltypes.timestamp_col < param)
).filter(alltypes.timestamp_col < param.name("my_param"))
agg = proj.group_by(proj.string_col).aggregate(proj.float_col.sum().name("foo"))

result = agg.foo.count()

0 comments on commit c396cd2

Please sign in to comment.