diff --git a/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql index 66518891122e..059d27e85c4a 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql @@ -13,5 +13,5 @@ SELECT "t1"."year", "t1"."month" FROM "functional_alltypes" AS "t1" -INNER JOIN "functional_alltypes" AS "t3" - ON "t1"."id" = "t3"."id" \ No newline at end of file +INNER JOIN "functional_alltypes" AS "t2" + ON "t1"."id" = "t2"."id" \ No newline at end of file diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index a7c8fc7b5b6c..171359b48550 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -1155,6 +1155,8 @@ def visit_DatabaseTable( def visit_SelfReference(self, op, *, parent, identifier): return parent + visit_JoinReference = visit_SelfReference + def visit_JoinChain(self, op, *, first, rest, values): result = sg.select(*self._cleanup_names(values), copy=False).from_( first, copy=False @@ -1385,9 +1387,6 @@ def visit_SQLStringView(self, op, *, query: str, child, schema): def visit_SQLQueryResult(self, op, *, query, schema, source): return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) - def visit_JoinTable(self, op, *, parent, index): - return parent - def visit_RegexExtract(self, op, *, arg, pattern, index): return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql index ca9b89c74630..e94ea21a070d 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql @@ -19,5 +19,5 @@ SELECT "t3"."year", "t3"."month" FROM "t1" AS "t3" -INNER JOIN "t1" AS "t5" +INNER JOIN "t1" AS "t4" ON TRUE \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql index 2cda81f9a69c..bd5a9e2ba9db 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql @@ -15,20 +15,20 @@ FROM ( "t1"."timestamp_col", "t1"."year", "t1"."month", - "t3"."id" AS "id_right", - "t3"."bool_col" AS "bool_col_right", - "t3"."tinyint_col" AS "tinyint_col_right", - "t3"."smallint_col" AS "smallint_col_right", - "t3"."int_col" AS "int_col_right", - "t3"."bigint_col" AS "bigint_col_right", - "t3"."float_col" AS "float_col_right", - "t3"."double_col" AS "double_col_right", - "t3"."date_string_col" AS "date_string_col_right", - "t3"."string_col" AS "string_col_right", - "t3"."timestamp_col" AS "timestamp_col_right", - "t3"."year" AS "year_right", - "t3"."month" AS "month_right" + "t2"."id" AS "id_right", + "t2"."bool_col" AS "bool_col_right", + "t2"."tinyint_col" AS "tinyint_col_right", + "t2"."smallint_col" AS "smallint_col_right", + "t2"."int_col" AS "int_col_right", + "t2"."bigint_col" AS "bigint_col_right", + "t2"."float_col" AS "float_col_right", + "t2"."double_col" AS "double_col_right", + "t2"."date_string_col" AS "date_string_col_right", + "t2"."string_col" AS "string_col_right", + "t2"."timestamp_col" AS "timestamp_col_right", + "t2"."year" AS "year_right", + "t2"."month" AS "month_right" FROM "functional_alltypes" AS "t1" - INNER JOIN "functional_alltypes" AS "t3" - ON "t1"."tinyint_col" < EXTRACT(minute FROM "t3"."timestamp_col") -) AS "t4" \ No newline at end of file + INNER JOIN "functional_alltypes" AS "t2" + ON "t1"."tinyint_col" < EXTRACT(minute FROM "t2"."timestamp_col") +) AS "t3" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql index 52fe23dbf9e4..321a92d13028 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql @@ -1,6 +1,6 @@ SELECT "t2"."r_name", - "t6"."n_name" + "t5"."n_name" FROM "tpch_region" AS "t2" INNER JOIN "tpch_nation" AS "t3" ON "t2"."r_regionkey" = "t3"."n_regionkey" @@ -16,5 +16,5 @@ INNER JOIN ( FROM "tpch_region" AS "t2" INNER JOIN "tpch_nation" AS "t3" ON "t2"."r_regionkey" = "t3"."n_regionkey" -) AS "t6" - ON "t2"."r_regionkey" = "t6"."r_regionkey" \ No newline at end of file +) AS "t5" + ON "t2"."r_regionkey" = "t5"."r_regionkey" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql index 0d28860d3c25..6c4beaaccf15 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql @@ -7,25 +7,25 @@ WITH "t1" AS ( GROUP BY 1, 2 -), "t6" AS ( +), "t5" AS ( SELECT "t3"."a", "t3"."g", "t3"."metric" FROM "t1" AS "t3" - INNER JOIN "t1" AS "t5" - ON "t3"."g" = "t5"."g" + INNER JOIN "t1" AS "t4" + ON "t3"."g" = "t4"."g" ) SELECT - "t9"."a", - "t9"."g", - "t9"."metric" + "t8"."a", + "t8"."g", + "t8"."metric" FROM ( SELECT * - FROM "t6" AS "t7" + FROM "t5" AS "t6" UNION ALL SELECT * - FROM "t6" AS "t8" -) AS "t9" \ No newline at end of file + FROM "t5" AS "t7" +) AS "t8" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql index 4039868d3eac..a9b27450205b 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql @@ -11,21 +11,21 @@ WITH "t1" AS ( 3 ) SELECT - "t6"."g", - MAX("t6"."total" - "t6"."total_right") AS "metric" + "t5"."g", + MAX("t5"."total" - "t5"."total_right") AS "metric" FROM ( SELECT "t3"."g", "t3"."a", "t3"."b", "t3"."total", - "t5"."g" AS "g_right", - "t5"."a" AS "a_right", - "t5"."b" AS "b_right", - "t5"."total" AS "total_right" + "t4"."g" AS "g_right", + "t4"."a" AS "a_right", + "t4"."b" AS "b_right", + "t4"."total" AS "total_right" FROM "t1" AS "t3" - INNER JOIN "t1" AS "t5" - ON "t3"."a" = "t5"."b" -) AS "t6" + INNER JOIN "t1" AS "t4" + ON "t3"."a" = "t4"."b" +) AS "t5" GROUP BY 1 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql index bf5316341bdf..18f3ee847814 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql @@ -24,9 +24,9 @@ WITH "t9" AS ( SELECT "t11"."region", "t11"."year", - "t11"."total" - "t13"."total" AS "yoy_change" + "t11"."total" - "t12"."total" AS "yoy_change" FROM "t9" AS "t11" -INNER JOIN "t9" AS "t13" +INNER JOIN "t9" AS "t12" ON "t11"."year" = ( - "t13"."year" - CAST(1 AS TINYINT) + "t12"."year" - CAST(1 AS TINYINT) ) \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql index 47a9c20dfbed..d4c24d4f2192 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql @@ -16,5 +16,5 @@ INNER JOIN ( FROM "alltypes" AS "t1" GROUP BY 1 -) AS "t6" - ON "t3"."g" = "t6"."g" \ No newline at end of file +) AS "t5" + ON "t3"."g" = "t5"."g" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql index ac92348404fd..927c6c569ef6 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql @@ -4,5 +4,5 @@ SELECT "t1"."foo_id", "t1"."bar_id" FROM "star1" AS "t1" -INNER JOIN "star1" AS "t3" - ON "t1"."foo_id" = "t3"."bar_id" \ No newline at end of file +INNER JOIN "star1" AS "t2" + ON "t1"."foo_id" = "t2"."bar_id" \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql index fca7ba9abc97..8a5689ab25a2 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql @@ -1,27 +1,27 @@ SELECT - "t14"."supp_nation", - "t14"."cust_nation", - "t14"."l_year", - "t14"."revenue" + "t13"."supp_nation", + "t13"."cust_nation", + "t13"."l_year", + "t13"."revenue" FROM ( SELECT - "t13"."supp_nation", - "t13"."cust_nation", - "t13"."l_year", - SUM("t13"."volume") AS "revenue" + "t12"."supp_nation", + "t12"."cust_nation", + "t12"."l_year", + SUM("t12"."volume") AS "revenue" FROM ( SELECT - "t12"."supp_nation", - "t12"."cust_nation", - "t12"."l_shipdate", - "t12"."l_extendedprice", - "t12"."l_discount", - "t12"."l_year", - "t12"."volume" + "t11"."supp_nation", + "t11"."cust_nation", + "t11"."l_shipdate", + "t11"."l_extendedprice", + "t11"."l_discount", + "t11"."l_year", + "t11"."volume" FROM ( SELECT "t9"."n_name" AS "supp_nation", - "t11"."n_name" AS "cust_nation", + "t10"."n_name" AS "cust_nation", "t6"."l_shipdate", "t6"."l_extendedprice", "t6"."l_discount", @@ -38,34 +38,34 @@ FROM ( ON "t8"."c_custkey" = "t7"."o_custkey" INNER JOIN "nation" AS "t9" ON "t5"."s_nationkey" = "t9"."n_nationkey" - INNER JOIN "nation" AS "t11" - ON "t8"."c_nationkey" = "t11"."n_nationkey" - ) AS "t12" + INNER JOIN "nation" AS "t10" + ON "t8"."c_nationkey" = "t10"."n_nationkey" + ) AS "t11" WHERE ( ( ( - "t12"."cust_nation" = 'FRANCE' + "t11"."cust_nation" = 'FRANCE' ) AND ( - "t12"."supp_nation" = 'GERMANY' + "t11"."supp_nation" = 'GERMANY' ) ) OR ( ( - "t12"."cust_nation" = 'GERMANY' + "t11"."cust_nation" = 'GERMANY' ) AND ( - "t12"."supp_nation" = 'FRANCE' + "t11"."supp_nation" = 'FRANCE' ) ) ) - AND "t12"."l_shipdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) - ) AS "t13" + AND "t11"."l_shipdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) + ) AS "t12" GROUP BY 1, 2, 3 -) AS "t14" +) AS "t13" ORDER BY - "t14"."supp_nation" ASC, - "t14"."cust_nation" ASC, - "t14"."l_year" ASC \ No newline at end of file + "t13"."supp_nation" ASC, + "t13"."cust_nation" ASC, + "t13"."l_year" ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql index e2a61a509543..3fe20e69b03d 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql @@ -1,26 +1,26 @@ SELECT - "t18"."o_year", - "t18"."mkt_share" + "t17"."o_year", + "t17"."mkt_share" FROM ( SELECT - "t17"."o_year", - SUM("t17"."nation_volume") / SUM("t17"."volume") AS "mkt_share" + "t16"."o_year", + SUM("t16"."nation_volume") / SUM("t16"."volume") AS "mkt_share" FROM ( SELECT - "t16"."o_year", - "t16"."volume", - "t16"."nation", - "t16"."r_name", - "t16"."o_orderdate", - "t16"."p_type", - CASE WHEN "t16"."nation" = 'BRAZIL' THEN "t16"."volume" ELSE CAST(0 AS TINYINT) END AS "nation_volume" + "t15"."o_year", + "t15"."volume", + "t15"."nation", + "t15"."r_name", + "t15"."o_orderdate", + "t15"."p_type", + CASE WHEN "t15"."nation" = 'BRAZIL' THEN "t15"."volume" ELSE CAST(0 AS TINYINT) END AS "nation_volume" FROM ( SELECT EXTRACT(year FROM "t10"."o_orderdate") AS "o_year", "t8"."l_extendedprice" * ( CAST(1 AS TINYINT) - "t8"."l_discount" ) AS "volume", - "t15"."n_name" AS "nation", + "t13"."n_name" AS "nation", "t14"."r_name", "t10"."o_orderdate", "t7"."p_type" @@ -37,16 +37,16 @@ FROM ( ON "t11"."c_nationkey" = "t12"."n_nationkey" INNER JOIN "region" AS "t14" ON "t12"."n_regionkey" = "t14"."r_regionkey" - INNER JOIN "nation" AS "t15" - ON "t9"."s_nationkey" = "t15"."n_nationkey" - ) AS "t16" + INNER JOIN "nation" AS "t13" + ON "t9"."s_nationkey" = "t13"."n_nationkey" + ) AS "t15" WHERE - "t16"."r_name" = 'AMERICA' - AND "t16"."o_orderdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) - AND "t16"."p_type" = 'ECONOMY ANODIZED STEEL' - ) AS "t17" + "t15"."r_name" = 'AMERICA' + AND "t15"."o_orderdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) + AND "t15"."p_type" = 'ECONOMY ANODIZED STEEL' + ) AS "t16" GROUP BY 1 -) AS "t18" +) AS "t17" ORDER BY - "t18"."o_year" ASC \ No newline at end of file + "t17"."o_year" ASC \ No newline at end of file diff --git a/ibis/common/graph.py b/ibis/common/graph.py index 7654e4b54f39..2535a7308611 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -2,13 +2,13 @@ from __future__ import annotations +import itertools from abc import abstractmethod from collections import deque from collections.abc import Iterable, Iterator, KeysView, Mapping, Sequence from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from ibis.common.bases import Hashable -from ibis.common.collections import frozendict from ibis.common.patterns import NoMatch, Pattern from ibis.common.typing import _ClassInfo from ibis.util import experimental, promote_list @@ -66,8 +66,9 @@ def _flatten_collections(node: Any) -> Iterator[N]: yield item elif isinstance(item, (tuple, list)): yield from _flatten_collections(item) - elif isinstance(item, (dict, frozendict)): - yield from _flatten_collections(item.values()) + elif isinstance(item, dict): + items = itertools.chain.from_iterable(item.items()) + yield from _flatten_collections(items) def _recursive_lookup(obj: Any, dct: dict) -> Any: @@ -117,8 +118,10 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any: return dct.get(obj, obj) elif isinstance(obj, (tuple, list)): return tuple(_recursive_lookup(o, dct) for o in obj) - elif isinstance(obj, (dict, frozendict)): - return {k: _recursive_lookup(v, dct) for k, v in obj.items()} + elif isinstance(obj, dict): + return { + _recursive_lookup(k, dct): _recursive_lookup(v, dct) for k, v in obj.items() + } else: return obj diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 43452fa0c09a..c7ea1015975a 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -192,8 +192,8 @@ def self_reference(op, parent, identifier): return f"{parent}.view()" -@translate.register(ops.JoinTable) -def join_table(op, parent, index): +@translate.register(ops.JoinReference) +def join_reference(op, parent, identifier): return parent @@ -353,7 +353,7 @@ class CodeContext: ) always_ignore = ( - ops.JoinTable, + ops.JoinReference, ops.Field, dt.Primitive, dt.Variadic, diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 973a0fd1d847..fca4e0dbda26 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -15,6 +15,7 @@ import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import util +from ibis.common.graph import Node _infix_ops = { # comparison operations @@ -151,7 +152,8 @@ def get_defining_frame(expr): """Locate the outermost frame where `expr` is defined.""" for frame_info in inspect.stack()[::-1]: for var in frame_info.frame.f_locals.values(): - if isinstance(var, ir.Expr) and expr.equals(var): + # if isinstance(var, ir.Expr) and expr.equals(var): + if var is expr: return frame_info.frame raise ValueError(f"No defining frame found for {expr}") @@ -191,7 +193,7 @@ def pretty(expr: ops.Node | ir.Expr, scope: Optional[dict[str, ir.Expr]] = None) """ if isinstance(expr, ir.Expr): node = expr.op() - elif isinstance(expr, ops.Node): + elif isinstance(expr, Node): node = expr else: raise TypeError(f"Expected an expression or a node, got {type(expr)}") @@ -208,7 +210,7 @@ def mapper(op, _, **kwargs): if var := variables.get(op): refs[op] = result result = var - elif isinstance(op, ops.Relation) and not isinstance(op, ops.JoinTable): + elif isinstance(op, ops.Relation) and not isinstance(op, ops.JoinReference): refs[op] = result result = f"r{next(refcnt)}" return Rendered(result) @@ -391,8 +393,8 @@ def _self_reference(op, parent, **kwargs): return f"{op.__class__.__name__}[{parent}]" -@fmt.register(ops.JoinTable) -def _join_table(op, parent, index): +@fmt.register(ops.JoinReference) +def _join_reference(op, parent, **kwargs): return parent diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 4e38359520e8..4618de7a4d52 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -133,11 +133,10 @@ def schema(self): return self.parent.schema -# TODO(kszucs): remove in favor of View @public -class SelfReference(Simple): +class Reference(Relation): _uid_counter = itertools.count() - + parent: Relation identifier: Optional[int] = None def __init__(self, parent, identifier): @@ -145,9 +144,22 @@ def __init__(self, parent, identifier): identifier = next(self._uid_counter) super().__init__(parent=parent, identifier=identifier) + @attribute + def schema(self): + return self.parent.schema + + +# TODO(kszucs): remove in favor of View +@public +class SelfReference(Reference): + values = FrozenDict() + + +@public +class JoinReference(Reference): @attribute def values(self): - return FrozenDict() + return self.parent.fields JoinKind = Literal[ @@ -164,21 +176,16 @@ def values(self): ] -@public -class JoinTable(Simple): - index: int - - @public class JoinLink(Node): how: JoinKind - table: JoinTable + table: Reference predicates: VarTuple[Value[dt.Boolean]] @public class JoinChain(Relation): - first: JoinTable + first: Reference rest: VarTuple[JoinLink] values: FrozenDict[str, Unaliased[Value]] @@ -194,6 +201,10 @@ def __init__(self, first, rest, values): _check_integrity(values.values(), allowed_parents) super().__init__(first=first, rest=rest, values=values) + @property + def tables(self): + return [self.first] + [link.table for link in self.rest] + @property def length(self): return len(self.rest) + 1 diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 5b03b024d89e..787fd63c6870 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -3,16 +3,21 @@ from __future__ import annotations import functools +from collections import defaultdict from collections.abc import Mapping import toolz import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.collections import FrozenDict # noqa: TCH001 from ibis.common.deferred import Item, _, deferred, var -from ibis.common.exceptions import ExpressionError +from ibis.common.exceptions import ExpressionError, IbisInputError +from ibis.common.graph import Node as Traversable +from ibis.common.grounds import Concrete from ibis.common.patterns import Check, pattern, replace -from ibis.util import Namespace +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.util import Namespace, promote_list p = Namespace(pattern, module=ops) d = Namespace(deferred, module=ops) @@ -23,6 +28,55 @@ name = var("name") +class DerefMap(Concrete, Traversable): + rels: VarTuple[ops.Relation] + subs: FrozenDict[ops.Value, ops.Value] + ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]] + + @classmethod + def from_targets(cls, rels, extra=None): + rels = promote_list(rels) + mapping = defaultdict(dict) + for rel in rels: + for field in rel.fields.values(): + for value, distance in cls.backtrack(field): + mapping[value][field] = distance + + subs, ambigs = {}, {} + for from_, to in mapping.items(): + mindist = min(to.values()) + minkeys = [k for k, v in to.items() if v == mindist] + if len(minkeys) == 1: + subs[from_] = minkeys[0] + else: + ambigs[from_] = minkeys + + if extra is not None: + subs.update(extra) + + return cls(rels, subs, ambigs) + + @classmethod + def backtrack(cls, value): + distance = 0 + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + while isinstance(value, ops.Field): + yield value, distance + value = value.rel.values.get(value.name) + distance += 1 + if value is not None: + yield value, distance + + def dereference(self, value): + ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value) + if ambigs: + raise IbisInputError( + f"Ambiguous field reference {ambigs!r} in expression {value!r}" + ) + return value.replace(self.subs, filter=ops.Value) + + @replace(p.Field(p.JoinChain)) def peel_join_field(_): return _.rel.values[_.name] diff --git a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt index a29fc083cf3a..3e39b232a162 100644 --- a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt @@ -6,15 +6,17 @@ r1 := UnboundTable: right time2 int32 value2 float64 +r2 := SelfReference[r1] + JoinChain[r0] JoinLink[asof, r1] r0.time1 >= r1.time2 - JoinLink[inner, r1] - r0.value == r1.value2 + JoinLink[inner, r2] + r0.value == r2.value2 values: time1: r0.time1 value: r0.value time2: r1.time2 value2: r1.value2 - time2_right: r1.time2 - value2_right: r1.value2 \ No newline at end of file + time2_right: r2.time2 + value2_right: r2.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt index 672faadf9ba2..0facaef4d746 100644 --- a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt @@ -8,11 +8,13 @@ r1 := UnboundTable: right value2 float64 b string +r2 := SelfReference[r1] + JoinChain[r0] JoinLink[inner, r1] r0.a == r1.b - JoinLink[inner, r1] - r0.value == r1.value2 + JoinLink[inner, r2] + r0.value == r2.value2 values: time1: r0.time1 value: r0.value @@ -20,6 +22,6 @@ JoinChain[r0] time2: r1.time2 value2: r1.value2 b: r1.b - time2_right: r1.time2 - value2_right: r1.value2 - b_right: r1.b \ No newline at end of file + time2_right: r2.time2 + value2_right: r2.value2 + b_right: r2.b \ No newline at end of file diff --git a/ibis/expr/tests/test_dereference.py b/ibis/expr/tests/test_dereference.py index e19234f92084..4396c0ac5fb0 100644 --- a/ibis/expr/tests/test_dereference.py +++ b/ibis/expr/tests/test_dereference.py @@ -1,7 +1,7 @@ from __future__ import annotations import ibis -from ibis.expr.types.relations import dereference_mapping +from ibis.expr.types.relations import DerefMap t = ibis.table( [ @@ -20,7 +20,7 @@ def dereference_expect(expected): def test_dereference_project(): p = t.select([t.int_col, t.double_col]) - mapping = dereference_mapping([p.op()]) + mapping = DerefMap.from_targets([p.op()]) expected = dereference_expect( { p.int_col: p.int_col, @@ -29,13 +29,13 @@ def test_dereference_project(): t.double_col: p.double_col, } ) - assert mapping == expected + assert mapping.subs == expected def test_dereference_mapping_self_reference(): v = t.view() - mapping = dereference_mapping([v.op()]) + mapping = DerefMap.from_targets([v.op()]) expected = dereference_expect( { v.int_col: v.int_col, @@ -43,4 +43,4 @@ def test_dereference_mapping_self_reference(): v.string_col: v.string_col, } ) - assert mapping == expected + assert mapping.subs == expected diff --git a/ibis/expr/tests/test_format.py b/ibis/expr/tests/test_format.py index a9a03efd9877..cce86be6ea8a 100644 --- a/ibis/expr/tests/test_format.py +++ b/ibis/expr/tests/test_format.py @@ -308,8 +308,9 @@ def test_fillna(snapshot): def test_asof_join(snapshot): left = ibis.table([("time1", "int32"), ("value", "double")], name="left") right = ibis.table([("time2", "int32"), ("value2", "double")], name="right") + right_ = right.view() joined = left.asof_join(right, ("time1", "time2")).inner_join( - right, left.value == right.value2 + right_, left.value == right_.value2 ) result = repr(joined) @@ -323,8 +324,9 @@ def test_two_inner_joins(snapshot): right = ibis.table( [("time2", "int32"), ("value2", "double"), ("b", "string")], name="right" ) + right_ = right.view() joined = left.inner_join(right, left.a == right.b).inner_join( - right, left.value == right.value2 + right_, left.value == right_.value2 ) result = repr(joined) diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 9635459d4314..ef25a6a65bce 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -35,8 +35,8 @@ @contextlib.contextmanager -def join_tables(*tables): - yield tuple(ops.JoinTable(t, i).to_expr() for i, t in enumerate(tables)) +def join_tables(table): + yield [t.to_expr() for t in table.op().tables] def test_field(): @@ -565,7 +565,7 @@ def test_join(): assert isinstance(joined.op(), JoinChain) assert isinstance(joined.op().to_expr(), ir.Join) - with join_tables(t1, t2) as (t1, t2): + with join_tables(joined) as (t1, t2): assert result.op() == JoinChain( first=t1, rest=[ @@ -584,16 +584,16 @@ def test_join_integrity_checks(): t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) # correct example - r1 = ops.JoinTable(t1, 10) - r2 = ops.JoinTable(t1, 20) + r1 = ops.JoinReference(t1, 10) + r2 = ops.JoinReference(t1, 20) assert r1 != r2 assert hash(r1) != hash(r2) chain = ops.JoinChain(r1, [ops.JoinLink("inner", r2, [True])], values={}) assert isinstance(chain, JoinChain) # not unique tables - r1 = ops.JoinTable(t1, 10) - r2 = ops.JoinTable(t1, 10) + r1 = ops.JoinReference(t1, 10) + r2 = ops.JoinReference(t1, 10) assert r1 == r2 assert hash(r1) == hash(r2) with pytest.raises(IntegrityError): @@ -609,7 +609,7 @@ def test_join_unambiguous_select(): expr2 = join.select("a_int", "b_int") assert expr1.equals(expr2) - with join_tables(a, b) as (r1, r2): + with join_tables(join) as (r1, r2): assert expr1.op() == JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a_int == r2.b_int])], @@ -627,7 +627,7 @@ def test_join_with_subsequent_projection(): # a single computed value is pulled to a subsequent projection joined = t1.join(t2, [t1.a == t2.c]) expr = joined.select(t1.a, t1.b, col=t2.c + 1) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -645,7 +645,7 @@ def test_join_with_subsequent_projection(): baz=t2.d.name("bar") + "3", baz2=(t2.c + t1.a).name("foo"), ) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -674,7 +674,7 @@ def test_join_with_subsequent_projection_colliding_names(): foo=t2.a + 1, bar=t1.a + t2.a, ) - with join_tables(t1, t2) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.a])], @@ -695,7 +695,7 @@ def test_chained_join(): joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e]) result = joined._finish() - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -715,7 +715,7 @@ def test_chained_join(): joined = a.join(b, [a.a == b.c]).join(c, [b.c == c.e]) result = joined.select(a.a, b.d, c.f) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -739,7 +739,7 @@ def test_chained_join_referencing_intermediate_table(): abc = ab.join(c, [ab.a == c.e]) result = abc._finish() - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(abc) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -772,7 +772,7 @@ def test_join_predicate_dereferencing(): # dereference table.foo_id to filtered.foo_id j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) - with join_tables(filtered, table2) as (r1, r2): + with join_tables(j1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -793,7 +793,7 @@ def test_join_predicate_dereferencing(): j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) j2 = j1.inner_join(table3, filtered["bar_id"] == table3["bar_id"]) view = j2[[filtered, table2["value1"], table3["value2"]]] - with join_tables(filtered, table2, table3) as (r1, r2, r3): + with join_tables(j2) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -835,7 +835,7 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): j1 = ibis.join(t2, t3, [(t2.x, t3.x)]) j2 = ibis.join(t2, t4, [(t2.x, t4.x)]) - with join_tables(t2, t3) as (r1, r2): + with join_tables(j1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -849,7 +849,7 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): ) assert j1.op() == expected - with join_tables(t2, t4) as (r1, r2): + with join_tables(j2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -864,6 +864,44 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): assert j2.op() == expected +# +def test_join_rhs_dereferencing(): +# + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) +# + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) +# + +# + t3 = t2.mutate(e=t2.c + 1) +# + joined = t1.join(t3, [t1.a == t2.c]) +# + expected = JoinChain( +# + first=t1, +# + rest=[ +# + JoinLink("inner", t3, [t1.a == t3.c]), +# + ], +# + values={ +# + "a": t1.a, +# + "b": t1.b, +# + "c": t3.c, +# + "d": t3.d, +# + "e": t3.e, +# + }, +# + ) +# + assert joined.op() == expected +# + +# + joined = t1.join(t3, [t1.a == (t2.c + 1)]) +# + expected = JoinChain( +# + first=t1, +# + rest=[ +# + JoinLink("inner", t3, [t1.a == t3.e]), +# + ], +# + values={ +# + "a": t1.a, +# + "b": t1.b, +# + "c": t3.c, +# + "d": t3.d, +# + "e": t3.e, +# + }, +# + ) +# + assert joined.op() == expected + + def test_aggregate(): agg = t.aggregate(by=[t.bool_col], metrics=[t.int_col.sum()]) expected = Aggregate( @@ -1063,7 +1101,7 @@ def test_self_join(): t3 = t2.join(t2, ["key"]) t4 = t3.join(t3, ["key"]) - with join_tables(t2, t2, t3) as (r1, r2, r3): + with join_tables(t4) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1089,7 +1127,7 @@ def test_self_join_view(): t_view = t.view() expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") - with join_tables(t, t_view) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1105,7 +1143,7 @@ def test_self_join_with_view_projection(): t2 = t1.view() expr = t1.inner_join(t2, ["x"])[[t1]] - with join_tables(t1, t2) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1120,10 +1158,14 @@ def test_joining_same_table_twice(): left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str}) right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str}) - joined = left.inner_join(right, left.a == right.b).inner_join( - right, left.value == right.value2 - ) - with join_tables(left, right, right) as (r1, r2, r3): + first = left.inner_join(right, left.a == right.b) + + with pytest.raises(IbisInputError, match="Ambiguous field reference"): + first.inner_join(right, left.value == right.value2) + + right_ = right.view() + second = first.inner_join(right_, left.value == right_.value2) + with join_tables(second) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1142,7 +1184,7 @@ def test_joining_same_table_twice(): "b_right": r3.b, }, ) - assert joined.op() == expected + assert second.op() == expected def test_join_chain_gets_reused_and_continued_after_a_select(): @@ -1153,7 +1195,7 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): ab = a.join(b, [a.a == b.c]) abc = ab[a.b, b.d].join(c, [a.a == c.e]) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(abc) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1173,10 +1215,28 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): def test_self_join_extensive(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) - aa = a.join(a, [a.a == a.a]) + with pytest.raises(IbisInputError, match="Ambiguous field reference"): + a.join(a, [a.a == a.a]) + + a_ = a.view() + aa = a.join(a_, [a.a == a_.a]) + with join_tables(aa) as (r1, r2): + expected = ops.JoinChain( + first=r1, + rest=[ + ops.JoinLink("inner", r2, [r1.a == r2.a]), + ], + values={ + "a": r1.a, + "b": r1.b, + "b_right": r2.b, + }, + ) + assert aa.op() == expected + aa1 = a.join(a, "a") aa2 = a.join(a, [("a", "a")]) - with join_tables(a, a) as (r1, r2): + with join_tables(aa1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1188,15 +1248,13 @@ def test_self_join_extensive(): "b_right": r2.b, }, ) - assert aa.op() == expected assert aa1.op() == expected assert aa2.op() == expected - aaa = a.join(a, [a.a == a.a]).join(a, [a.a == a.a]) - aaa1 = aa.join(a, [aa.a == a.a]) - aaa2 = aa.join(a, "a") - aaa3 = aa.join(a, [("a", "a")]) - with join_tables(a, a, a) as (r1, r2, r3): + aaa = a.join(a, "a").join(a, "a") + aaa1 = aa1.join(a, "a") + aaa2 = aa1.join(a, [("a", "a")]) + with join_tables(aaa) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1212,14 +1270,32 @@ def test_self_join_extensive(): assert aaa.op() == expected assert aaa1.op() == expected assert aaa2.op() == expected - assert aaa3.op() == expected def test_self_join_with_intermediate_selection(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) proj = a[["b", "a"]] + # the predicate only references the original table, unless we enforce + # that the predicates must contain both sides of the join, we can't + # do much with this, perhaps raise a warning join = proj.join(a, [a.a == a.a]) - with join_tables(proj, a) as (r1, r2): + with join_tables(join) as (r1, r2): + expected = ops.JoinChain( + first=r1, + rest=[ + ops.JoinLink("inner", r2, [r2.a == r2.a]), + ], + values={ + "b": r1.b, + "a": r1.a, + "a_right": r2.a, + "b_right": r2.b, + }, + ) + assert join.op() == expected + + join = proj.join(a, [proj.a == a.a]) + with join_tables(join) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1233,9 +1309,11 @@ def test_self_join_with_intermediate_selection(): ) assert join.op() == expected - aa = a.join(a, [a.a == a.a])["a", "b_right"] - aaa = aa.join(a, [aa.a == a.a]) - with join_tables(a, a, a) as (r1, r2, r3): + a_ = a.view() + a__ = a.view() + aa = a.join(a_, [a.a == a_.a])["a", "b_right"] + aaa = aa.join(a__, [aa.a == a__.a]) + with join_tables(aaa) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1306,7 +1384,7 @@ def test_self_view_join_followed_by_aggregate_correctly_dereference_fields(): join = agged.inner_join(view, [agged.a == view.b]) agg = join.aggregate(metrics, by=[agged.g]) - with join_tables(agged, view) as (r1, r2): + with join_tables(join) as (r1, r2): expected_join = ops.JoinChain( first=r1, rest=[ @@ -1367,7 +1445,7 @@ def test_join_between_joins(): exprs = [left, right.value3, right.value4] expr = joined.select(exprs) - with join_tables(t1, t2, right) as (r1, r2, r3): + with join_tables(expr) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1393,7 +1471,7 @@ def test_join_with_filtered_join_of_left(): joined = t1.left_join(t2, [t1.a == t2.a]).filter(t1.a < 5) expr = t1.left_join(joined, [t1.a == joined.a]).select(t1) - with join_tables(t1, joined) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1451,7 +1529,7 @@ def test_join_with_compound_predicate(): ], ) expr = joined[t1] - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1486,7 +1564,7 @@ def test_inner_join_convenience(): t5 = ibis.table(name="t5", schema={"a": "int64", "f": "string"}) first_join = t1.inner_join(t2, [t1.a == t2.a]) - with join_tables(t1, t2) as (r1, r2): + with join_tables(first_join) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1504,7 +1582,7 @@ def test_inner_join_convenience(): # note that we are joining on r2.a which isn't among the values second_join = first_join.inner_join(t3, [r2.a == t3.a]) - with join_tables(t1, t2, t3) as (r1, r2, r3): + with join_tables(second_join) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1523,7 +1601,7 @@ def test_inner_join_convenience(): assert result == expected third_join = second_join.left_join(t4, [r3.a == t4.a]) - with join_tables(t1, t2, t3, t4) as (r1, r2, r3, r4): + with join_tables(third_join) as (r1, r2, r3, r4): expected = ops.JoinChain( first=r1, rest=[ @@ -1545,7 +1623,7 @@ def test_inner_join_convenience(): assert result == expected fourth_join = third_join.inner_join(t5, [r3.a == t5.a], rname="{name}_") - with join_tables(t1, t2, t3, t4, t5) as (r1, r2, r3, r4, r5): + with join_tables(fourth_join) as (r1, r2, r3, r4, r5): # equality groups are being reset expected = ops.JoinChain( first=r1, @@ -1575,7 +1653,7 @@ def test_inner_join_convenience(): third_join.inner_join(t5, [r4.a == t5.a])._finish() fifth_join = third_join.inner_join(t5, [r4.a == t5.a], rname="{name}_") - with join_tables(t1, t2, t3, t4, t5) as (r1, r2, r3, r4, r5): + with join_tables(fifth_join) as (r1, r2, r3, r4, r5): # equality groups are being reset expected = ops.JoinChain( first=r1, diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 71a8b11ab061..83559847fd89 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -20,9 +20,9 @@ from ibis.expr.rewrites import peel_join_field from ibis.expr.types.generic import Value from ibis.expr.types.relations import ( + DerefMap, Table, bind, - dereference_mapping, unwrap_aliases, ) @@ -112,47 +112,18 @@ def disambiguate_fields( return fields, collisions, equalities -def dereference_mapping_left(chain): - # construct the list of join table we wish to dereference fields to - rels = [chain.first] - for link in chain.rest: - if link.how not in ("semi", "anti"): - rels.append(link.table) - - # create the dereference mapping suitable to disambiguate field references - # from earlier in the relation hierarchy to one of the join tables - subs = dereference_mapping(rels) - - # also allow to dereference fields of the join chain itself - for k, v in chain.values.items(): - subs[ops.Field(chain, k)] = v - - return subs - - -def dereference_mapping_right(right): - # the right table is wrapped in a JoinTable the uniqueness of the underlying - # table which requires the predicates to be dereferenced to the wrapped - return {v: ops.Field(right, k) for k, v in right.values.items()} - - -def dereference_sides(left, right, deref_left, deref_right): - left = left.replace(deref_left, filter=ops.Value) - right = right.replace(deref_right, filter=ops.Value) - return left, right - - -def dereference_value(pred, deref_left, deref_right): - deref_both = {**deref_left, **deref_right} - if isinstance(pred, ops.Comparison) and pred.left.relations == pred.right.relations: - left, right = dereference_sides(pred.left, pred.right, deref_left, deref_right) - return pred.copy(left=left, right=right) - else: - return pred.replace(deref_both, filter=ops.Value) +def dereference_value(pred, deref_left, deref_right, deref_both): + return deref_both.dereference(pred) + # if isinstance(pred, ops.Comparison) and pred.left.relations == pred.right.relations: + # left = deref_left.dereference(pred.left) + # right = deref_right.dereference(pred.right) + # return pred.copy(left=left, right=right) + # else: + # return deref_both.dereference(pred) def prepare_predicates( - left: ops.JoinChain, + chain: ops.JoinChain, right: ops.Relation, predicates: Sequence[Any], comparison: type[ops.Comparison] = ops.Equals, @@ -187,8 +158,8 @@ def prepare_predicates( Parameters ---------- - left - The left table + chain + The join chain right The right table predicates @@ -197,21 +168,25 @@ def prepare_predicates( The comparison operation to construct if the input is a pair of expression-like objects """ - deref_left = dereference_mapping_left(left) - deref_right = dereference_mapping_right(right) + reverse = {ops.Field(chain, k): v for k, v in chain.values.items()} + deref_right = DerefMap.from_targets(right) + deref_left = DerefMap.from_targets(chain.tables, extra=reverse) + deref_both = DerefMap.from_targets([*chain.tables, right], extra=reverse) - left, right = left.to_expr(), right.to_expr() + left, right = chain.to_expr(), right.to_expr() for pred in util.promote_list(predicates): if pred is True or pred is False: yield ops.Literal(pred, dtype="bool") elif isinstance(pred, Value): for node in flatten_predicates(pred.op()): - yield dereference_value(node, deref_left, deref_right) + yield dereference_value(node, deref_left, deref_right, deref_both) + # yield deref_both.dereference(node) elif isinstance(pred, Deferred): # resolve deferred expressions on the left table pred = pred.resolve(left).op() for node in flatten_predicates(pred): - yield dereference_value(node, deref_left, deref_right) + yield dereference_value(node, deref_left, deref_right, deref_both) + # yield deref_both.dereference(node) else: if isinstance(pred, tuple): if len(pred) != 2: @@ -225,9 +200,9 @@ def prepare_predicates( (right_value,) = bind(right, rk) # dereference the left value to one of the relations in the join chain - left_value, right_value = dereference_sides( - left_value.op(), right_value.op(), deref_left, deref_right - ) + left_value = deref_left.dereference(left_value.op()) + right_value = deref_right.dereference(right_value.op()) + yield comparison(left_value, right_value) @@ -248,11 +223,8 @@ class Join(Table): def __init__(self, arg, collisions=(), equalities=()): assert isinstance(arg, ops.Node) if not isinstance(arg, ops.JoinChain): - # coerce the input node to a join chain operation by first wrapping - # the input relation in a JoinTable so that we can join the same - # table with itself multiple times and to enable optimization - # passes later on - arg = ops.JoinTable(arg, index=0) + # coerce the input node to a join chain operation + arg = ops.JoinReference(arg, identifier=0) arg = ops.JoinChain(arg, rest=(), values=arg.fields) super().__init__(arg) # the collisions and equalities are used to track the name collisions @@ -297,11 +269,13 @@ def join( elif how == "asof": raise IbisInputError("use table.asof_join(...) instead") - left = self.op() - right = ops.JoinTable(right, index=left.length) + chain = self.op() + right = right.op() + if not isinstance(right, ops.Reference): + right = ops.JoinReference(right, identifier=chain.length) # bind and dereference the predicates - preds = list(prepare_predicates(left, right, predicates)) + preds = list(prepare_predicates(chain, right, predicates)) if not preds and how != "cross": # if there are no predicates, default to every row matching unless # the join is a cross join, because a cross join already has this @@ -316,7 +290,7 @@ def join( how=how, predicates=preds, equalities=self._equalities, - left_fields=left.values, + left_fields=chain.values, right_fields=right.fields, left_template=lname, right_template=rname, @@ -324,10 +298,10 @@ def join( # construct a new join link and add it to the join chain link = ops.JoinLink(how, table=right, predicates=preds) - left = left.copy(rest=left.rest + (link,), values=values) + chain = chain.copy(rest=chain.rest + (link,), values=values) # return with a new JoinExpr wrapping the new join chain - return self.__class__(left, collisions=collisions, equalities=equalities) + return self.__class__(chain, collisions=collisions, equalities=equalities) @functools.wraps(Table.asof_join) def asof_join( @@ -382,19 +356,21 @@ def asof_join( values = {**self.op().values, **filtered.op().values} return result.select(values) - left = self.op() - right = ops.JoinTable(right, index=left.length) + chain = self.op() + right = right.op() + if not isinstance(right, ops.Reference): + right = ops.JoinReference(right, identifier=chain.length) # TODO(kszucs): add extra validation for `on` with clear error messages - (on,) = prepare_predicates(left, right, [on], comparison=ops.GreaterEqual) - preds = prepare_predicates(left, right, predicates, comparison=ops.Equals) + (on,) = prepare_predicates(chain, right, [on], comparison=ops.GreaterEqual) + preds = prepare_predicates(chain, right, predicates, comparison=ops.Equals) preds = [on, *preds] values, collisions, equalities = disambiguate_fields( how="asof", predicates=preds, equalities=self._equalities, - left_fields=left.values, + left_fields=chain.values, right_fields=right.fields, left_template=lname, right_template=rname, @@ -402,10 +378,10 @@ def asof_join( # construct a new join link and add it to the join chain link = ops.JoinLink("asof", table=right, predicates=preds) - left = left.copy(rest=left.rest + (link,), values=values) + chain = chain.copy(rest=chain.rest + (link,), values=values) # return with a new JoinExpr wrapping the new join chain - return self.__class__(left, collisions=collisions, equalities=equalities) + return self.__class__(chain, collisions=collisions, equalities=equalities) @functools.wraps(Table.cross_join) def cross_join( @@ -428,13 +404,15 @@ def select(self, *args, **kwargs): values = bind(self, (args, kwargs)) values = unwrap_aliases(values) + links = [link.table for link in chain.rest if link.how not in ("semi", "anti")] + derefmap = DerefMap.from_targets([chain.first, *links]) + # if there are values referencing fields from the join chain constructed # so far, we need to replace them the fields from one of the join links - subs = dereference_mapping_left(chain) values = { k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items() } - values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + values = {k: derefmap.dereference(v) for k, v in values.items()} node = chain.copy(values=values) return Table(node) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 6aff29612a31..5eef12d0eb39 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -5,12 +5,7 @@ import re from collections.abc import Iterable, Iterator, Mapping, Sequence from keyword import iskeyword -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, -) +from typing import TYPE_CHECKING, Any, Callable, Literal import toolz from public import public @@ -22,6 +17,7 @@ import ibis.expr.schema as sch from ibis import util from ibis.common.deferred import Deferred, Resolver +from ibis.expr.rewrites import DerefMap from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import Value, literal from ibis.expr.types.pretty import to_rich @@ -149,33 +145,6 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: return result -def dereference_mapping(parents): - parents = util.promote_list(parents) - mapping = {} - - for parent in parents: - # do not defereference fields referencing the requested parents - for _, v in parent.fields.items(): - mapping[v] = v - - for parent in parents: - for k, v in parent.values.items(): - if isinstance(v, ops.Field): - # track down the field in the hierarchy until no modification - # is made so only follow ops.Field nodes not arbitrary values; - # also stop tracking if the field belongs to a parent which - # we want to dereference to, see the docstring of - # `dereference_values()` for more details - while isinstance(v, ops.Field) and v not in mapping: - mapping[v] = ops.Field(parent, k) - v = v.rel.values.get(v.name) - elif v not in mapping: - # do not dereference literal expressions - mapping[v] = ops.Field(parent, k) - - return mapping - - def dereference_values( parents: Iterable[ops.Parents], values: Mapping[str, ops.Value] ) -> Mapping[str, ops.Value]: @@ -214,8 +183,8 @@ def dereference_values( The same mapping as `values` but with all the dereferenceable fields replaced with the fields from the parents. """ - subs = dereference_mapping(parents) - return {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + dis = DerefMap.from_targets(parents) + return {k: dis.dereference(v) for k, v in values.items()} @public @@ -1185,23 +1154,21 @@ def aggregate( groups = bind(self, by) metrics = bind(self, (metrics, kwargs)) - having = bind(self, having) + having = tuple(bind(self, having)) groups = unwrap_aliases(groups) metrics = unwrap_aliases(metrics) - having = unwrap_aliases(having) groups = dereference_values(node, groups) metrics = dereference_values(node, metrics) - having = dereference_values(node, having) # the user doesn't need to specify the metrics used in the having clause # explicitly, we implicitly add them to the metrics list by looking for # any metrics depending on self which are not specified explicitly pattern = p.Reduction(relations=Contains(node)) & ~In(set(metrics.values())) original_metrics = metrics.copy() - for pred in having.values(): - for metric in pred.find_topmost(pattern): + for pred in having: + for metric in pred.op().find_topmost(pattern): if metric.name in metrics: metrics[util.get_name("metric")] = metric else: @@ -1212,7 +1179,7 @@ def aggregate( if having: # apply the having clause - agg = agg.filter(*having.values()) + agg = agg.filter(*having) # remove any metrics that were only used in the having clause if metrics != original_metrics: agg = agg.select(*groups.keys(), *original_metrics.keys()) diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index aeacf249edf0..527bbab84c8f 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -27,7 +27,7 @@ def test_rewrite_join_projection_without_other_ops(con): # Project out the desired fields view = j2[[filtered, table2["value1"], table3["value2"]]] - with join_tables(filtered, table2, table3) as (r1, r2, r3): + with join_tables(j2) as (r1, r2, r3): # Construct the thing we expect to obtain expected = ops.JoinChain( first=r1, @@ -170,7 +170,7 @@ def test_filter_self_join(): what = [left.region, metric] projected = joined.select(what) - with join_tables(left, right) as (r1, r2): + with join_tables(joined) as (r1, r2): join = ops.JoinChain( first=r1, rest=[ diff --git a/ibis/tests/expr/test_struct.py b/ibis/tests/expr/test_struct.py index 75707c650f2c..de297a4e4a58 100644 --- a/ibis/tests/expr/test_struct.py +++ b/ibis/tests/expr/test_struct.py @@ -72,16 +72,15 @@ def test_unpack_from_table(t): def test_lift_join(t, s): join = t.join(s, t.d == s.a.g) result = join.a_right.lift() - - with join_tables(t, s) as (r1, r2): - join = ops.JoinChain( - first=r1, + with join_tables(join) as (t, s): + expected = ops.JoinChain( + first=t, rest=[ - ops.JoinLink("inner", r2, [r1.d == r2.a.g]), + ops.JoinLink("inner", s, [t.d == s.a.g]), ], - values={"f": r2.a.f, "g": r2.a.g}, + values={"f": s.a.f, "g": s.a.g}, ) - assert result.op() == join + assert result.op() == expected def test_unpack_join_from_table(t, s): diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 420682ae1810..c8b2ed0103ce 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -901,7 +901,7 @@ def test_join_no_predicate_list(con): pred = region.r_regionkey == nation.n_regionkey joined = region.inner_join(nation, pred) - with join_tables(region, nation) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -924,7 +924,7 @@ def test_join_deferred(con): res = region.join(nation, _.r_regionkey == nation.n_regionkey) - with join_tables(region, nation) as (r1, r2): + with join_tables(res) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -970,7 +970,7 @@ def test_asof_join_with_by(): right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) join_without_by = api.asof_join(left, right, "time") - with join_tables(left, right) as (r1, r2): + with join_tables(join_without_by) as (r1, r2): r2 = join_without_by.op().rest[0].table.to_expr() expected = ops.JoinChain( first=r1, @@ -987,7 +987,7 @@ def test_asof_join_with_by(): assert join_without_by.op() == expected join_with_predicates = api.asof_join(left, right, "time", predicates="key") - with join_tables(left, right) as (r1, r2): + with join_tables(join_with_predicates) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1205,7 +1205,7 @@ def test_cross_join_multiple(table): c = table["f", "h"] joined = ibis.cross_join(a, b, c) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1292,7 +1292,7 @@ def test_join_key_alternatives(con, key_maker): key = key_maker(t1, t2) joined = t1.inner_join(t2, key) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1376,7 +1376,7 @@ def test_unravel_compound_equijoin(table): p3 = t1.key3 == t2.key3 joined = t1.inner_join(t2, [p1 & p2 & p3]) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[