From 0796d8cce9367f861108f0c436a16a5557311693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 27 Sep 2023 14:57:03 +0200 Subject: [PATCH] refactor(ir): remove ops.Negatable, ops.NotAny, ops.NotAll, ops.UnresolvedNotExistsSubquery --- ibis/backends/base/sql/alchemy/registry.py | 4 -- ibis/backends/base/sql/compiler/translator.py | 12 ---- ibis/backends/base/sql/registry/main.py | 4 +- ibis/backends/clickhouse/compiler/core.py | 15 ----- ibis/backends/duckdb/compiler.py | 2 - ibis/backends/pandas/execution/generic.py | 32 ---------- ibis/backends/polars/compiler.py | 18 ------ ibis/backends/postgres/compiler.py | 2 - ibis/backends/postgres/registry.py | 2 - ibis/expr/analysis.py | 29 --------- ibis/expr/operations/generic.py | 8 --- ibis/expr/operations/logical.py | 53 ++------------- ibis/expr/operations/reductions.py | 31 +-------- ibis/expr/types/logical.py | 18 ++++-- ibis/expr/types/relations.py | 20 +++--- ibis/tests/expr/test_table.py | 64 ++----------------- 16 files changed, 38 insertions(+), 276 deletions(-) diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index cbd17837cf829..307172b4970fa 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -158,9 +158,6 @@ def _exists_subquery(t, op): sub_ctx = ctx.subcontext() clause = ctx.compiler.to_sql(filtered, sub_ctx, exists=True) - if isinstance(op, ops.NotExistsSubquery): - clause = sa.not_(clause) - return clause @@ -563,7 +560,6 @@ class array_filter(FunctionElement): ops.TableColumn: _table_column, ops.TableArrayView: _table_array_view, ops.ExistsSubquery: _exists_subquery, - ops.NotExistsSubquery: _exists_subquery, # miscellaneous varargs ops.Least: varargs(sa.func.least), ops.Greatest: varargs(sa.func.greatest), diff --git a/ibis/backends/base/sql/compiler/translator.py b/ibis/backends/base/sql/compiler/translator.py index 0dfba2c519d08..fbde625b24f6b 100644 --- a/ibis/backends/base/sql/compiler/translator.py +++ b/ibis/backends/base/sql/compiler/translator.py @@ -336,23 +336,11 @@ def _any_expand(op): return ops.Max(op.arg, where=op.where) -@rewrites(ops.NotAny) -def _notany_expand(op): - zero = ops.Literal(0, dtype=op.arg.dtype) - return ops.Min(ops.Equals(op.arg, zero), where=op.where) - - @rewrites(ops.All) def _all_expand(op): return ops.Min(op.arg, where=op.where) -@rewrites(ops.NotAll) -def _notall_expand(op): - zero = ops.Literal(0, dtype=op.arg.dtype) - return ops.Max(ops.Equals(op.arg, zero), where=op.where) - - @rewrites(ops.Cast) def _rewrite_cast(op): # TODO(kszucs): avoid the expression roundtrip diff --git a/ibis/backends/base/sql/registry/main.py b/ibis/backends/base/sql/registry/main.py index 92e935ef4db70..1192acc7dacac 100644 --- a/ibis/backends/base/sql/registry/main.py +++ b/ibis/backends/base/sql/registry/main.py @@ -161,8 +161,7 @@ def exists_subquery(translator, op): subquery = ctx.get_compiled_expr(node) - prefix = "NOT " * isinstance(op, ops.NotExistsSubquery) - return f"{prefix}EXISTS (\n{util.indent(subquery, ctx.indent)}\n)" + return f"EXISTS (\n{util.indent(subquery, ctx.indent)}\n)" # XXX this is not added to operation_registry, but looks like impala is @@ -350,7 +349,6 @@ def count_star(translator, op): ops.TimestampDiff: timestamp.timestamp_diff, ops.TimestampFromUNIX: timestamp.timestamp_from_unix, ops.ExistsSubquery: exists_subquery, - ops.NotExistsSubquery: exists_subquery, # RowNumber, and rank functions starts with 0 in Ibis-land ops.RowNumber: lambda *_: "row_number()", ops.DenseRank: lambda *_: "dense_rank()", diff --git a/ibis/backends/clickhouse/compiler/core.py b/ibis/backends/clickhouse/compiler/core.py index fa5fa162c854d..b919c8b20ea4a 100644 --- a/ibis/backends/clickhouse/compiler/core.py +++ b/ibis/backends/clickhouse/compiler/core.py @@ -101,18 +101,6 @@ def fn(node, _, **kwargs): False, dtype="bool" ) - # replace `NotExistsSubquery` with `Not(ExistsSubquery)` - # - # this allows to avoid having another rule to negate ExistsSubquery - replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(...) >> c.Not( - c.ExistsSubquery(...) - ) - - # clickhouse-specific rewrite to turn notany/notall into equivalent - # already-defined operations - replace_notany_with_min_not = p.NotAny(x, where=y) >> c.Min(c.Not(x), where=y) - replace_notall_with_max_not = p.NotAll(x, where=y) >> c.Max(c.Not(x), where=y) - # subtract one from ranking functions to convert from 1-indexed to 0-indexed subtract_one_from_ranking_functions = p.WindowFunction( p.RankBase | p.NTile @@ -124,9 +112,6 @@ def fn(node, _, **kwargs): replace_literals | replace_in_column_with_table_array_view | replace_empty_in_values_with_false - | replace_notexists_subquery_with_not_exists - | replace_notany_with_min_not - | replace_notall_with_max_not | subtract_one_from_ranking_functions | add_one_to_nth_value_input ) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index caf2d6c266a6a..a74136935059b 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -55,8 +55,6 @@ def compile_array(element, compiler, **kw): @rewrites(ops.Any) @rewrites(ops.All) -@rewrites(ops.NotAny) -@rewrites(ops.NotAll) @rewrites(ops.StringContains) def _no_op(expr): return expr diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py index 969f2819e9d4a..e23ecdd1aa8ea 100644 --- a/ibis/backends/pandas/execution/generic.py +++ b/ibis/backends/pandas/execution/generic.py @@ -847,38 +847,6 @@ def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs): return result -@execute_node.register((ops.NotAny, ops.NotAll), pd.Series, (pd.Series, type(None))) -def execute_notany_notall_series(op, data, mask, aggcontext=None, **kwargs): - name = type(op).__name__.lower()[len("Not") :] - if mask is not None: - data = data.loc[mask] - if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)): - result = ~aggcontext.agg(data, name) - else: - method = operator.methodcaller(name) - result = aggcontext.agg(data, lambda data: ~method(data)) - try: - return result.astype(bool) - except TypeError: - return result - - -@execute_node.register((ops.NotAny, ops.NotAll), SeriesGroupBy, type(None)) -def execute_notany_notall_series_group_by(op, data, mask, aggcontext=None, **kwargs): - name = type(op).__name__.lower()[len("Not") :] - if mask is not None: - data = data.obj.loc[mask].groupby(get_grouping(data.grouper.groupings)) - if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)): - result = ~aggcontext.agg(data, name) - else: - method = operator.methodcaller(name) - result = aggcontext.agg(data, lambda data: ~method(data)) - try: - return result.astype(bool) - except TypeError: - return result - - @execute_node.register(ops.CountStar, pd.DataFrame, type(None)) def execute_count_star_frame(op, data, _, **kwargs): return len(data) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 50fe0755c79b9..2e6e2dd6ece80 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1082,24 +1082,6 @@ def execute_hash(op, **kw): return translate(op.arg, **kw).hash() -@translate.register(ops.NotAll) -def execute_not_all(op, **kw): - arg = op.arg - if (op_where := op.where) is not None: - arg = ops.IfElse(op_where, arg, None) - - return translate(arg, **kw).all().not_() - - -@translate.register(ops.NotAny) -def execute_not_any(op, **kw): - arg = op.arg - if (op_where := op.where) is not None: - arg = ops.IfElse(op_where, arg, None) - - return translate(arg, **kw).any().not_() - - def _arg_min_max(op, func, **kw): key = op.key arg = op.arg diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 11252f5d488cc..e5efc76fa7d29 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -29,8 +29,6 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator): @rewrites(ops.Any) @rewrites(ops.All) -@rewrites(ops.NotAny) -@rewrites(ops.NotAll) def _any_all_no_op(expr): return expr diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index 02bb75fa1aa23..6c3fd2b38fe83 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -608,8 +608,6 @@ def _array_filter(t, op): # boolean reductions ops.Any: reduction(sa.func.bool_or), ops.All: reduction(sa.func.bool_and), - ops.NotAny: reduction(lambda x: sa.func.bool_and(~x)), - ops.NotAll: reduction(lambda x: sa.func.bool_or(~x)), # strings ops.GroupConcat: _string_agg, ops.Capitalize: unary(sa.func.initcap), diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index b8c1c28f6d654..210baca9ec777 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -625,12 +625,6 @@ def predicate(node): return any(g.traverse(predicate, node)) -_ANY_OP_MAPPING = { - ops.Any: ops.UnresolvedExistsSubquery, - ops.NotAny: ops.UnresolvedNotExistsSubquery, -} - - def find_predicates(node, flatten=True): # TODO(kszucs): consider to remove flatten argument and compose with # flatten_predicates instead @@ -663,28 +657,6 @@ def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]: ) -# TODO(kszucs): move to types/logical.py -def _make_any( - expr, - any_op_class: type[ops.Any] | type[ops.NotAny], - *, - where: ir.BooleanValue | None = None, -): - assert isinstance(expr, ir.Expr), type(expr) - - tables = find_immediate_parent_tables(expr.op()) - predicates = find_predicates(expr.op(), flatten=True) - - if len(tables) > 1: - op = _ANY_OP_MAPPING[any_op_class]( - tables=[t.to_expr() for t in tables], - predicates=predicates, - ) - else: - op = any_op_class(expr, where=expr._bind_reduction_filter(where)) - return op.to_expr() - - # TODO(kszucs): use substitute instead @functools.singledispatch def _rewrite_filter(op, **kwargs): @@ -708,7 +680,6 @@ def _rewrite_filter_reduction(op, name: str | None = None, **kwargs): @_rewrite_filter.register(ops.TableColumn) @_rewrite_filter.register(ops.Literal) @_rewrite_filter.register(ops.ExistsSubquery) -@_rewrite_filter.register(ops.NotExistsSubquery) @_rewrite_filter.register(ops.WindowFunction) def _rewrite_filter_subqueries(op, **kwargs): """Don't rewrite any of these operations in filters.""" diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index 02d25ba65b2cb..eb3398ef7cdfe 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc import itertools from typing import Annotated, Any, Optional, Union from typing import Literal as LiteralType @@ -13,7 +12,6 @@ import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute -from ibis.common.bases import Abstract from ibis.common.deferred import Deferred # noqa: TCH001 from ibis.common.grounds import Singleton from ibis.common.patterns import InstanceOf, Length # noqa: TCH001 @@ -313,9 +311,3 @@ def shape(self): def dtype(self): exprs = [*self.results, self.default] return rlz.highest_precedence_dtype(exprs) - - -class _Negatable(Abstract): - @abc.abstractmethod - def negate(self): # pragma: no cover - ... diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index 7e66956466fda..5330e4dd36a99 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -1,7 +1,5 @@ from __future__ import annotations -import abc - from public import public import ibis.expr.datashape as ds @@ -11,7 +9,6 @@ from ibis.common.exceptions import IbisTypeError from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Binary, Column, Unary, Value -from ibis.expr.operations.generic import _Negatable from ibis.expr.operations.relations import Relation # noqa: TCH001 @@ -170,30 +167,16 @@ def dtype(self): @public -class ExistsSubquery(Value, _Negatable): +class ExistsSubquery(Value): foreign_table: Relation predicates: VarTuple[Value[dt.Boolean]] dtype = dt.boolean shape = ds.columnar - def negate(self) -> NotExistsSubquery: - return NotExistsSubquery(*self.args) - @public -class NotExistsSubquery(Value, _Negatable): - foreign_table: Relation - predicates: VarTuple[Value[dt.Boolean]] - - dtype = dt.boolean - shape = ds.columnar - - def negate(self) -> ExistsSubquery: - return ExistsSubquery(*self.args) - - -class _UnresolvedSubquery(Value, _Negatable): +class UnresolvedExistsSubquery(Value): """An exists subquery whose outer leaf table is unknown. Notes @@ -231,8 +214,8 @@ class _UnresolvedSubquery(Value, _Negatable): Notably the correlated subquery cannot stand on its own. - The purpose of `_UnresolvedSubquery` is to capture enough information about - an exists predicate such that it can be resolved when predicates are + The purpose of `UnresolvedExistsSubquery` is to capture enough information + about an exists predicate such that it can be resolved when predicates are resolved against the outer leaf table when `Selection`s are constructed. """ @@ -242,36 +225,10 @@ class _UnresolvedSubquery(Value, _Negatable): dtype = dt.boolean shape = ds.columnar - @abc.abstractmethod - def _resolve( - self, table - ) -> type[ExistsSubquery] | type[NotExistsSubquery]: # pragma: no cover - ... - - -@public -class UnresolvedExistsSubquery(_UnresolvedSubquery): - def negate(self) -> UnresolvedNotExistsSubquery: - return UnresolvedNotExistsSubquery(*self.args) - - def _resolve(self, table) -> ExistsSubquery: + def resolve(self, table) -> ExistsSubquery: from ibis.expr.operations.relations import TableNode assert isinstance(table, TableNode) (foreign_table,) = (t for t in self.tables if t != table) return ExistsSubquery(foreign_table, self.predicates).to_expr() - - -@public -class UnresolvedNotExistsSubquery(_UnresolvedSubquery): - def negate(self) -> UnresolvedExistsSubquery: - return UnresolvedExistsSubquery(*self.args) - - def _resolve(self, table) -> NotExistsSubquery: - from ibis.expr.operations.relations import TableNode - - assert isinstance(table, TableNode) - - (foreign_table,) = (t for t in self.tables if t != table) - return NotExistsSubquery(foreign_table, self.predicates).to_expr() diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 747fe8bbd9815..d6370b9a41941 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -10,7 +10,6 @@ import ibis.expr.rules as rlz from ibis.common.annotations import attribute from ibis.expr.operations.core import Column, Value -from ibis.expr.operations.generic import _Negatable from ibis.expr.operations.relations import Relation # noqa: TCH001 @@ -327,40 +326,14 @@ def dtype(self): @public -class All(Filterable, Reduction, _Negatable): +class All(Filterable, Reduction): arg: Column[dt.Boolean] dtype = dt.boolean - def negate(self): - return NotAll(self.arg) - - -@public -class NotAll(Filterable, Reduction, _Negatable): - arg: Column[dt.Boolean] - - dtype = dt.boolean - - def negate(self) -> Any: - return All(*self.args) - @public -class Any(Filterable, Reduction, _Negatable): +class Any(Filterable, Reduction): arg: Column[dt.Boolean] dtype = dt.boolean - - def negate(self) -> NotAny: - return NotAny(*self.args) - - -@public -class NotAny(Filterable, Reduction, _Negatable): - arg: Column[dt.Boolean] - - dtype = dt.boolean - - def negate(self) -> Any: - return Any(*self.args) diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index 4f79ca210d36a..1aed08b290655 100644 --- a/ibis/expr/types/logical.py +++ b/ibis/expr/types/logical.py @@ -269,7 +269,17 @@ def any(self, where: BooleanValue | None = None) -> BooleanValue: """ import ibis.expr.analysis as an - return an._make_any(self, ops.Any, where=where) + tables = an.find_immediate_parent_tables(self.op()) + + if len(tables) > 1: + op = ops.UnresolvedExistsSubquery( + tables=[t.to_expr() for t in tables], + predicates=an.find_predicates(self.op(), flatten=True), + ) + else: + op = ops.Any(self, where=where) + + return op.to_expr() def notany(self, where: BooleanValue | None = None) -> BooleanValue: """Return whether no elements are `True`. @@ -297,9 +307,7 @@ def notany(self, where: BooleanValue | None = None) -> BooleanValue: >>> (t.arr == None).notany(where=t.arr != None) True """ - import ibis.expr.analysis as an - - return an._make_any(self, ops.NotAny, where=where) + return ~self.any(where=where) def all(self, where: BooleanValue | None = None) -> BooleanScalar: """Return whether all elements are `True`. @@ -358,7 +366,7 @@ def notall(self, where: BooleanValue | None = None) -> BooleanScalar: >>> (t.arr == 2).notall(where=t.arr >= 2) True """ - return ops.NotAll(self, where=self._bind_reduction_filter(where)).to_expr() + return ~self.all(where=where) def cumany(self, *, where=None, group_by=None, order_by=None) -> BooleanColumn: """Accumulate the `any` aggregate. diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 02656575875d0..85f5350b463ee 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -4161,8 +4161,9 @@ def release(self): def _resolve_predicates( table: Table, predicates ) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]: - import ibis.expr.analysis as an import ibis.expr.types as ir + from ibis.expr.analysis import p, flatten_predicate + from ibis.common.deferred import _, Attr, Call # TODO(kszucs): clean this up, too much flattening and resolving happens here predicates = [ @@ -4173,14 +4174,15 @@ def _resolve_predicates( ) for pred in util.promote_list(preds) ] - predicates = an.flatten_predicate(predicates) - - resolved_predicates = [] - for pred in predicates: - if isinstance(pred, ops.logical._UnresolvedSubquery): - resolved_predicates.append(pred._resolve(table.op())) - else: - resolved_predicates.append(pred) + predicates = flatten_predicate(predicates) + + # _.resolve is actually a non-deferred method, so it won't dispatch to + # the matched UnresolvedExistsSubquery.resolve() method + # TODO(kszucs): remove Deferred.resolve() method + replacement = Call(Attr(_, "resolve"), table.op()) + resolved_predicates = [ + pred.replace(p.UnresolvedExistsSubquery >> replacement) for pred in predicates + ] return resolved_predicates diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 2366792966dcf..83d377ee7a32d 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1260,71 +1260,19 @@ def t2(): return ibis.table([("key1", "string"), ("key2", "string")], "bar") -@pytest.mark.parametrize( - ("func", "expected_type"), - [ - param( - lambda t1, t2: (t1.key1 == t2.key1).any(), - ops.UnresolvedExistsSubquery, - id="exists", - ), - param( - lambda t1, t2: -(t1.key1 == t2.key1).any(), - ops.UnresolvedNotExistsSubquery, - id="not_exists", - ), - param( - lambda t1, t2: -(-(t1.key1 == t2.key1).any()), # noqa: B002 - ops.UnresolvedExistsSubquery, - id="not_not_exists", - ), - ], -) -def test_unresolved_existence_predicate(t1, t2, func, expected_type): - expr = func(t1, t2) +def test_unresolved_existence_predicate(t1, t2): + expr = (t1.key1 == t2.key1).any() assert isinstance(expr, ir.BooleanColumn) - - op = expr.op() - assert isinstance(op, expected_type) + assert isinstance(expr.op(), ops.UnresolvedExistsSubquery) -@pytest.mark.parametrize( - ("func", "expected_type", "expected_negated_type"), - [ - param( - lambda t1, t2: t1[(t1.key1 == t2.key1).any()], - ops.ExistsSubquery, - ops.NotExistsSubquery, - id="exists", - ), - param( - lambda t1, t2: t1[-(t1.key1 == t2.key1).any()], - ops.NotExistsSubquery, - ops.ExistsSubquery, - id="not_exists", - ), - param( - lambda t1, t2: t1[-(-(t1.key1 == t2.key1).any())], # noqa: B002 - ops.ExistsSubquery, - ops.NotExistsSubquery, - id="not_not_exists", - ), - ], -) -def test_resolve_existence_predicate( - t1, - t2, - func, - expected_type, - expected_negated_type, -): - expr = func(t1, t2) +def test_resolve_existence_predicate(t1, t2): + expr = t1[(t1.key1 == t2.key1).any()] op = expr.op() assert isinstance(op, ops.Selection) pred = op.predicates[0].to_expr() - assert isinstance(pred.op(), expected_type) - assert isinstance((-pred).op(), expected_negated_type) + assert isinstance(pred.op(), ops.ExistsSubquery) def test_aggregate_metrics(table):