Skip to content

Commit

Permalink
refactor(ir): remove ops.Negatable, ops.NotAny, ops.NotAll, ops.UnresolvedNotExistsSubquery
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Oct 12, 2023
1 parent f2ae7cc commit 0796d8c
Show file tree
Hide file tree
Showing 16 changed files with 38 additions and 276 deletions.
4 changes: 0 additions & 4 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ def _exists_subquery(t, op):
sub_ctx = ctx.subcontext()
clause = ctx.compiler.to_sql(filtered, sub_ctx, exists=True)

if isinstance(op, ops.NotExistsSubquery):
clause = sa.not_(clause)

return clause


Expand Down Expand Up @@ -563,7 +560,6 @@ class array_filter(FunctionElement):
ops.TableColumn: _table_column,
ops.TableArrayView: _table_array_view,
ops.ExistsSubquery: _exists_subquery,
ops.NotExistsSubquery: _exists_subquery,
# miscellaneous varargs
ops.Least: varargs(sa.func.least),
ops.Greatest: varargs(sa.func.greatest),
Expand Down
12 changes: 0 additions & 12 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,23 +336,11 @@ def _any_expand(op):
return ops.Max(op.arg, where=op.where)


@rewrites(ops.NotAny)
def _notany_expand(op):
    """Rewrite ``NotAny`` as ``Min(arg == 0)``.

    The comparison literal ``0`` is built with the argument's own dtype, and
    the reduction filter ``op.where`` is forwarded unchanged to ``Min``.
    """
    operand = op.arg
    is_zero = ops.Equals(operand, ops.Literal(0, dtype=operand.dtype))
    return ops.Min(is_zero, where=op.where)


@rewrites(ops.All)
def _all_expand(op):
    """Rewrite ``All`` as a ``Min`` reduction over the same argument and filter."""
    operand = op.arg
    condition = op.where
    return ops.Min(operand, where=condition)


@rewrites(ops.NotAll)
def _notall_expand(op):
    """Rewrite ``NotAll`` as ``Max(arg == 0)``.

    The comparison literal ``0`` is built with the argument's own dtype, and
    the reduction filter ``op.where`` is forwarded unchanged to ``Max``.
    """
    operand = op.arg
    is_zero = ops.Equals(operand, ops.Literal(0, dtype=operand.dtype))
    return ops.Max(is_zero, where=op.where)


@rewrites(ops.Cast)
def _rewrite_cast(op):
# TODO(kszucs): avoid the expression roundtrip
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ def exists_subquery(translator, op):

subquery = ctx.get_compiled_expr(node)

prefix = "NOT " * isinstance(op, ops.NotExistsSubquery)
return f"{prefix}EXISTS (\n{util.indent(subquery, ctx.indent)}\n)"
return f"EXISTS (\n{util.indent(subquery, ctx.indent)}\n)"


# XXX this is not added to operation_registry, but looks like impala is
Expand Down Expand Up @@ -350,7 +349,6 @@ def count_star(translator, op):
ops.TimestampDiff: timestamp.timestamp_diff,
ops.TimestampFromUNIX: timestamp.timestamp_from_unix,
ops.ExistsSubquery: exists_subquery,
ops.NotExistsSubquery: exists_subquery,
# RowNumber and rank functions start at 0 in Ibis-land
ops.RowNumber: lambda *_: "row_number()",
ops.DenseRank: lambda *_: "dense_rank()",
Expand Down
15 changes: 0 additions & 15 deletions ibis/backends/clickhouse/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,6 @@ def fn(node, _, **kwargs):
False, dtype="bool"
)

# replace `NotExistsSubquery` with `Not(ExistsSubquery)`
#
# this avoids the need for a separate rule to negate ExistsSubquery
replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(...) >> c.Not(
c.ExistsSubquery(...)
)

# clickhouse-specific rewrite to turn notany/notall into equivalent
# already-defined operations
replace_notany_with_min_not = p.NotAny(x, where=y) >> c.Min(c.Not(x), where=y)
replace_notall_with_max_not = p.NotAll(x, where=y) >> c.Max(c.Not(x), where=y)

# subtract one from ranking functions to convert from 1-indexed to 0-indexed
subtract_one_from_ranking_functions = p.WindowFunction(
p.RankBase | p.NTile
Expand All @@ -124,9 +112,6 @@ def fn(node, _, **kwargs):
replace_literals
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| replace_notexists_subquery_with_not_exists
| replace_notany_with_min_not
| replace_notall_with_max_not
| subtract_one_from_ranking_functions
| add_one_to_nth_value_input
)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def compile_array(element, compiler, **kw):

@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
@rewrites(ops.StringContains)
def _no_op(expr):
    """Identity rewrite: return the given operation unchanged.

    Registered for each operation listed in the decorators above.
    """
    return expr
Expand Down
32 changes: 0 additions & 32 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,38 +847,6 @@ def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs):
return result


@execute_node.register((ops.NotAny, ops.NotAll), pd.Series, (pd.Series, type(None)))
def execute_notany_notall_series(op, data, mask, aggcontext=None, **kwargs):
    """Evaluate ``NotAny`` / ``NotAll`` on a series as a negated ``any`` / ``all``.

    The reduction name is derived from the operation's class name by
    lower-casing it and dropping the leading ``"not"``. When ``mask`` is given,
    only the rows it selects take part in the reduction.
    """
    # "NotAny" -> "any", "NotAll" -> "all"
    reduction = type(op).__name__.lower()[len("Not") :]

    selected = data
    if mask is not None:
        selected = data.loc[mask]

    if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
        # These contexts take the reduction by name; negate the aggregate.
        negated = ~aggcontext.agg(selected, reduction)
    else:
        # Other contexts get a callable that negates inside the aggregation.
        reduce_ = operator.methodcaller(reduction)
        negated = aggcontext.agg(selected, lambda values: ~reduce_(values))

    # Cast to bool when the result supports it; otherwise return it as-is.
    try:
        return negated.astype(bool)
    except TypeError:
        return negated


@execute_node.register((ops.NotAny, ops.NotAll), SeriesGroupBy, type(None))
def execute_notany_notall_series_group_by(op, data, mask, aggcontext=None, **kwargs):
    """Evaluate ``NotAny`` / ``NotAll`` on a grouped series as a negated ``any`` / ``all``.

    The reduction name is derived from the operation's class name by
    lower-casing it and dropping the leading ``"not"``.
    """
    # "NotAny" -> "any", "NotAll" -> "all"
    reduction = type(op).__name__.lower()[len("Not") :]

    grouped = data
    # NOTE(review): the registration above lists only ``type(None)`` for the
    # mask, so this branch looks unreachable through dispatch -- confirm.
    if mask is not None:
        filtered = data.obj.loc[mask]
        keys = get_grouping(data.grouper.groupings)
        grouped = filtered.groupby(keys)

    if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
        # These contexts take the reduction by name; negate the aggregate.
        negated = ~aggcontext.agg(grouped, reduction)
    else:
        # Other contexts get a callable that negates inside the aggregation.
        reduce_ = operator.methodcaller(reduction)
        negated = aggcontext.agg(grouped, lambda values: ~reduce_(values))

    # Cast to bool when the result supports it; otherwise return it as-is.
    try:
        return negated.astype(bool)
    except TypeError:
        return negated


@execute_node.register(ops.CountStar, pd.DataFrame, type(None))
def execute_count_star_frame(op, data, _, **kwargs):
    """Evaluate ``CountStar`` on a data frame with no filter: return ``len(data)``."""
    row_count = len(data)
    return row_count
Expand Down
18 changes: 0 additions & 18 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,24 +1082,6 @@ def execute_hash(op, **kw):
return translate(op.arg, **kw).hash()


@translate.register(ops.NotAll)
def execute_not_all(op, **kw):
    """Translate ``NotAll`` as the negation of an ``all()`` reduction.

    When the operation carries a ``where`` filter, the argument is first
    wrapped in ``IfElse(where, arg, None)`` before being translated.
    """
    operand = op.arg
    condition = op.where
    if condition is not None:
        # Rows failing the filter are replaced by None.
        operand = ops.IfElse(condition, operand, None)

    translated = translate(operand, **kw)
    return translated.all().not_()


@translate.register(ops.NotAny)
def execute_not_any(op, **kw):
    """Translate ``NotAny`` as the negation of an ``any()`` reduction.

    When the operation carries a ``where`` filter, the argument is first
    wrapped in ``IfElse(where, arg, None)`` before being translated.
    """
    operand = op.arg
    condition = op.where
    if condition is not None:
        # Rows failing the filter are replaced by None.
        operand = ops.IfElse(condition, operand, None)

    translated = translate(operand, **kw)
    return translated.any().not_()


def _arg_min_max(op, func, **kw):
key = op.key
arg = op.arg
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):

@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
def _any_all_no_op(expr):
    """Identity rewrite: return the given boolean reduction unchanged."""
    return expr

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,6 @@ def _array_filter(t, op):
# boolean reductions
ops.Any: reduction(sa.func.bool_or),
ops.All: reduction(sa.func.bool_and),
ops.NotAny: reduction(lambda x: sa.func.bool_and(~x)),
ops.NotAll: reduction(lambda x: sa.func.bool_or(~x)),
# strings
ops.GroupConcat: _string_agg,
ops.Capitalize: unary(sa.func.initcap),
Expand Down
29 changes: 0 additions & 29 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,6 @@ def predicate(node):
return any(g.traverse(predicate, node))


_ANY_OP_MAPPING = {
ops.Any: ops.UnresolvedExistsSubquery,
ops.NotAny: ops.UnresolvedNotExistsSubquery,
}


def find_predicates(node, flatten=True):
# TODO(kszucs): consider removing the flatten argument and composing with
# flatten_predicates instead
Expand Down Expand Up @@ -663,28 +657,6 @@ def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]:
)


# TODO(kszucs): move to types/logical.py
def _make_any(
    expr,
    any_op_class: type[ops.Any] | type[ops.NotAny],
    *,
    where: ir.BooleanValue | None = None,
):
    """Build an ``Any``/``NotAny`` reduction or an unresolved exists subquery.

    Parameters
    ----------
    expr
        The expression to reduce; must be an ``ir.Expr``.
    any_op_class
        The reduction class to instantiate, also used as the key into
        ``_ANY_OP_MAPPING`` when a subquery is produced.
    where
        Optional filter, bound through ``expr._bind_reduction_filter``.

    Returns
    -------
    The expression form (``.to_expr()``) of the constructed operation.
    """
    assert isinstance(expr, ir.Expr), type(expr)

    parent_tables = find_immediate_parent_tables(expr.op())
    found_predicates = find_predicates(expr.op(), flatten=True)

    # At most one parent table: a plain reduction over ``expr``.
    if len(parent_tables) <= 1:
        bound_filter = expr._bind_reduction_filter(where)
        return any_op_class(expr, where=bound_filter).to_expr()

    # More than one parent table: emit the mapped unresolved subquery op.
    subquery_cls = _ANY_OP_MAPPING[any_op_class]
    table_exprs = [table.to_expr() for table in parent_tables]
    subquery = subquery_cls(tables=table_exprs, predicates=found_predicates)
    return subquery.to_expr()


# TODO(kszucs): use substitute instead
@functools.singledispatch
def _rewrite_filter(op, **kwargs):
Expand All @@ -708,7 +680,6 @@ def _rewrite_filter_reduction(op, name: str | None = None, **kwargs):
@_rewrite_filter.register(ops.TableColumn)
@_rewrite_filter.register(ops.Literal)
@_rewrite_filter.register(ops.ExistsSubquery)
@_rewrite_filter.register(ops.NotExistsSubquery)
@_rewrite_filter.register(ops.WindowFunction)
def _rewrite_filter_subqueries(op, **kwargs):
"""Don't rewrite any of these operations in filters."""
Expand Down
8 changes: 0 additions & 8 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import abc
import itertools
from typing import Annotated, Any, Optional, Union
from typing import Literal as LiteralType
Expand All @@ -13,7 +12,6 @@
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.bases import Abstract
from ibis.common.deferred import Deferred # noqa: TCH001
from ibis.common.grounds import Singleton
from ibis.common.patterns import InstanceOf, Length # noqa: TCH001
Expand Down Expand Up @@ -313,9 +311,3 @@ def shape(self):
def dtype(self):
exprs = [*self.results, self.default]
return rlz.highest_precedence_dtype(exprs)


class _Negatable(Abstract):
    """Abstract mixin for operations that expose a ``negate`` method."""

    @abc.abstractmethod
    def negate(self):  # pragma: no cover
        """Return the negated operation; concrete subclasses must override."""
53 changes: 5 additions & 48 deletions ibis/expr/operations/logical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import abc

from public import public

import ibis.expr.datashape as ds
Expand All @@ -11,7 +9,6 @@
from ibis.common.exceptions import IbisTypeError
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Binary, Column, Unary, Value
from ibis.expr.operations.generic import _Negatable
from ibis.expr.operations.relations import Relation # noqa: TCH001


Expand Down Expand Up @@ -170,30 +167,16 @@ def dtype(self):


@public
class ExistsSubquery(Value, _Negatable):
class ExistsSubquery(Value):
foreign_table: Relation
predicates: VarTuple[Value[dt.Boolean]]

dtype = dt.boolean
shape = ds.columnar

def negate(self) -> NotExistsSubquery:
return NotExistsSubquery(*self.args)


@public
class NotExistsSubquery(Value, _Negatable):
    """Boolean, columnar ``NOT EXISTS`` predicate over a foreign table."""

    # Field order is kept: ``negate`` forwards ``self.args`` positionally.
    foreign_table: Relation
    predicates: VarTuple[Value[dt.Boolean]]

    dtype = dt.boolean
    shape = ds.columnar

    def negate(self) -> ExistsSubquery:
        """Return an ``ExistsSubquery`` built from this node's arguments."""
        arguments = self.args
        return ExistsSubquery(*arguments)


class _UnresolvedSubquery(Value, _Negatable):
class UnresolvedExistsSubquery(Value):
"""An exists subquery whose outer leaf table is unknown.
Notes
Expand Down Expand Up @@ -231,8 +214,8 @@ class _UnresolvedSubquery(Value, _Negatable):
Notably the correlated subquery cannot stand on its own.
The purpose of `_UnresolvedSubquery` is to capture enough information about
an exists predicate such that it can be resolved when predicates are
The purpose of `UnresolvedExistsSubquery` is to capture enough information
about an exists predicate such that it can be resolved when predicates are
resolved against the outer leaf table when `Selection`s are constructed.
"""

Expand All @@ -242,36 +225,10 @@ class _UnresolvedSubquery(Value, _Negatable):
dtype = dt.boolean
shape = ds.columnar

@abc.abstractmethod
def _resolve(
self, table
) -> type[ExistsSubquery] | type[NotExistsSubquery]: # pragma: no cover
...


@public
class UnresolvedExistsSubquery(_UnresolvedSubquery):
def negate(self) -> UnresolvedNotExistsSubquery:
return UnresolvedNotExistsSubquery(*self.args)

def _resolve(self, table) -> ExistsSubquery:
def resolve(self, table) -> ExistsSubquery:
from ibis.expr.operations.relations import TableNode

assert isinstance(table, TableNode)

(foreign_table,) = (t for t in self.tables if t != table)
return ExistsSubquery(foreign_table, self.predicates).to_expr()


@public
class UnresolvedNotExistsSubquery(_UnresolvedSubquery):
def negate(self) -> UnresolvedExistsSubquery:
return UnresolvedExistsSubquery(*self.args)

def _resolve(self, table) -> NotExistsSubquery:
from ibis.expr.operations.relations import TableNode

assert isinstance(table, TableNode)

(foreign_table,) = (t for t in self.tables if t != table)
return NotExistsSubquery(foreign_table, self.predicates).to_expr()
Loading

0 comments on commit 0796d8c

Please sign in to comment.