From 593ec6ae7cbe793c97d9fcc0d3e2cf6830b0f4d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 22 Dec 2023 18:42:30 +0100 Subject: [PATCH] refactor(ir): add `JoinTable` operation unique to `JoinChain` instead of using the globally unique `SelfReference` This enables us to maintain join expression equality: `a.join(b).equals(a.join(b))` So far we have been using SelfReference to make join tables unique, but it was globally unique which broke the equality check above. Therefore we need to restrict the uniqueness to the scope of the join chain. The simplest solution for that is to simply enumerate the join tables in the join chain, hence now all join participants must be `ops.JoinTable(rel, index)` instances. `ops.SelfReference` is still required to distinguish between two identical tables at the API level, but it is now decoupled from the join internal representation. --- ibis/expr/decompile.py | 7 +- ibis/expr/format.py | 7 +- ibis/expr/operations/relations.py | 15 +- .../test_format/test_asof_join/repr.txt | 28 +-- .../repr.txt | 30 +-- .../repr.txt | 18 +- .../test_table_count_expr/join_repr.txt | 20 +- .../test_format/test_two_inner_joins/repr.txt | 34 ++- ibis/expr/tests/test_newrels.py | 213 ++++++++---------- ibis/expr/types/joins.py | 38 ++-- ibis/expr/types/relations.py | 11 +- .../test_memoize_database_table/repr.txt | 32 ++- ibis/tests/expr/test_analysis.py | 25 +- ibis/tests/expr/test_struct.py | 10 +- ibis/tests/expr/test_table.py | 46 ++-- 15 files changed, 238 insertions(+), 296 deletions(-) diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index af279447bd76..79b53dbe4ba2 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -184,6 +184,11 @@ def self_reference(op, parent, identifier): return parent +@translate.register(ops.JoinTable) +def join_table(op, parent, index): + return parent + + @translate.register(ops.JoinLink) def join_link(op, table, predicates, how): return f".{how}_join({table}, 
{_try_unwrap(predicates)})" @@ -327,7 +332,7 @@ def isin(op, value, options): class CodeContext: always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregate) always_ignore = ( - ops.SelfReference, + ops.JoinTable, ops.Field, dt.Primitive, dt.Variadic, diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 6ac9dfeb7b8a..4a904a99462b 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -162,7 +162,7 @@ def pretty(node): def mapper(op, _, **kwargs): result = fmt(op, **kwargs) - if isinstance(op, ops.Relation): + if isinstance(op, ops.Relation) and not isinstance(op, ops.JoinTable): tables[op] = result result = f"r{next(refcnt)}" return Rendered(result) @@ -337,6 +337,11 @@ def _self_reference(op, parent, **kwargs): return f"{op.__class__.__name__}[{parent}]" +@fmt.register(ops.JoinTable) +def _join_table(op, parent, index): + return parent + + @fmt.register(ops.Literal) def _literal(op, value, **kwargs): if op.dtype.is_interval(): diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 9a2ecef54622..a4c37ce2912b 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -229,27 +229,38 @@ def name(self) -> str: ] +@public +class JoinTable(Simple): + index: int + + @public class JoinLink(Node): how: JoinKind - table: SelfReference + table: JoinTable predicates: VarTuple[Value[dt.Boolean]] @public class JoinChain(Relation): - first: SelfReference + first: JoinTable rest: VarTuple[JoinLink] values: FrozenDict[str, Unaliased[Value]] def __init__(self, first, rest, values): allowed_parents = {first} + assert first.index == 0 for join in rest: + assert join.table.index == len(allowed_parents) allowed_parents.add(join.table) _check_integrity(join.predicates, allowed_parents) _check_integrity(values.values(), allowed_parents) super().__init__(first=first, rest=rest, values=values) + @property + def length(self): + return len(self.rest) + 1 + @attribute def schema(self): return Schema({k: 
v.dtype.copy(nullable=True) for k, v in self.values.items()}) diff --git a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt index e28f3c5bb0df..6c43f0adfc6b 100644 --- a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt @@ -6,21 +6,15 @@ r1 := UnboundTable: right time2 int32 value2 float64 -r2 := SelfReference[r0] - -r3 := SelfReference[r1] - -r4 := SelfReference[r1] - -JoinChain[r2] - JoinLink[asof, r3] - r2.time1 == r3.time2 - JoinLink[inner, r4] - r2.value == r4.value2 +JoinChain[r0] + JoinLink[asof, r1] + r0.time1 == r1.time2 + JoinLink[inner, r1] + r0.value == r1.value2 values: - time1: r2.time1 - value: r2.value - time2: r3.time2 - value2: r3.value2 - time2_right: r4.time2 - value2_right: r4.value2 \ No newline at end of file + time1: r0.time1 + value: r0.value + time2: r1.time2 + value2: r1.value2 + time2_right: r1.time2 + value2_right: r1.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt index 2e5cd4a00c70..8879597be115 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt @@ -12,24 +12,18 @@ r2 := UnboundTable: three bar_id string value2 float64 -r3 := SelfReference[r1] - -r4 := SelfReference[r2] - -r5 := Filter[r0] +r3 := Filter[r0] r0.f > 0 -r6 := SelfReference[r5] - -JoinChain[r6] - JoinLink[left, r3] - r6.foo_id == r3.foo_id - JoinLink[inner, r4] - r6.bar_id == r4.bar_id +JoinChain[r3] + JoinLink[left, r1] + r3.foo_id == r1.foo_id + JoinLink[inner, r2] + r3.bar_id == r2.bar_id values: - c: r6.c - f: r6.f - foo_id: r6.foo_id - bar_id: r6.bar_id - value1: r3.value1 - value2: r4.value2 
\ No newline at end of file + c: r3.c + f: r3.f + foo_id: r3.foo_id + bar_id: r3.bar_id + value1: r1.value1 + value2: r2.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt index 758b88722b59..128ffd518dd6 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt @@ -17,15 +17,11 @@ r2 := Filter[r1] r3 := Filter[r1] r1.kind == 'bar' -r4 := SelfReference[r2] - -r5 := SelfReference[r3] - -JoinChain[r4] - JoinLink[inner, r5] - r4.region == r5.region +JoinChain[r2] + JoinLink[inner, r3] + r2.region == r3.region values: - region: r4.region - kind: r4.kind - total: r4.total - right_total: r5.total \ No newline at end of file + region: r2.region + kind: r2.kind + total: r2.total + right_total: r3.total \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt index 999c2664f114..6f7009dc8056 100644 --- a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt @@ -6,17 +6,13 @@ r1 := UnboundTable: t2 a int64 b float64 -r2 := SelfReference[r0] - -r3 := SelfReference[r1] - -r4 := JoinChain[r2] - JoinLink[inner, r3] - r2.a == r3.a +r2 := JoinChain[r0] + JoinLink[inner, r1] + r0.a == r1.a values: - a: r2.a - b: r2.b - a_right: r3.a - b_right: r3.b + a: r0.a + b: r0.b + a_right: r1.a + b_right: r1.b -CountStar(): CountStar(r4) \ No newline at end of file +CountStar(): CountStar(r2) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt 
index 37d25bcc6b54..672faadf9ba2 100644 --- a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt @@ -8,24 +8,18 @@ r1 := UnboundTable: right value2 float64 b string -r2 := SelfReference[r0] - -r3 := SelfReference[r1] - -r4 := SelfReference[r1] - -JoinChain[r2] - JoinLink[inner, r3] - r2.a == r3.b - JoinLink[inner, r4] - r2.value == r4.value2 +JoinChain[r0] + JoinLink[inner, r1] + r0.a == r1.b + JoinLink[inner, r1] + r0.value == r1.value2 values: - time1: r2.time1 - value: r2.value - a: r2.a - time2: r3.time2 - value2: r3.value2 - b: r3.b - time2_right: r4.time2 - value2_right: r4.value2 - b_right: r4.b \ No newline at end of file + time1: r0.time1 + value: r0.value + a: r0.a + time2: r1.time2 + value2: r1.value2 + b: r1.b + time2_right: r1.time2 + value2_right: r1.value2 + b_right: r1.b \ No newline at end of file diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index d3821cecd73f..4735596a970d 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import itertools import pytest @@ -36,16 +35,8 @@ @contextlib.contextmanager -def self_references(*tables): - old_counter = ops.SelfReference._uid_counter - # set a new counter with 1000 to avoid colliding with manually created - # self-references using t.view() - new_counter = itertools.count(1000) - try: - ops.SelfReference._uid_counter = new_counter - yield tuple(ops.SelfReference(t).to_expr() for t in tables) - finally: - ops.SelfReference._uid_counter = old_counter +def join_tables(*tables): + yield tuple(ops.JoinTable(t, i).to_expr() for i, t in enumerate(tables)) def test_field(): @@ -486,9 +477,7 @@ def test_project_before_and_after_filter(): def test_join(): t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) - - with 
self_references(): - joined = t1.join(t2, [t1.a == t2.c]) + joined = t1.join(t2, [t1.a == t2.c]) assert isinstance(joined, ir.JoinExpr) assert isinstance(joined.op(), JoinChain) @@ -499,7 +488,7 @@ def test_join(): assert isinstance(joined.op(), JoinChain) assert isinstance(joined.op().to_expr(), ir.JoinExpr) - with self_references(t1, t2) as (t1, t2): + with join_tables(t1, t2) as (t1, t2): assert result.op() == JoinChain( first=t1, rest=[ @@ -518,12 +507,12 @@ def test_join_unambiguous_select(): a = ibis.table(name="a", schema={"a_int": "int64", "a_str": "string"}) b = ibis.table(name="b", schema={"b_int": "int64", "b_str": "string"}) - with self_references(): - join = a.join(b, a.a_int == b.b_int) - expr1 = join["a_int", "b_int"] - expr2 = join.select("a_int", "b_int") - assert expr1.equals(expr2) - with self_references(a, b) as (r1, r2): + join = a.join(b, a.a_int == b.b_int) + expr1 = join["a_int", "b_int"] + expr2 = join.select("a_int", "b_int") + assert expr1.equals(expr2) + + with join_tables(a, b) as (r1, r2): assert expr1.op() == JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a_int == r2.b_int])], @@ -539,10 +528,9 @@ def test_join_with_subsequent_projection(): t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) # a single computed value is pulled to a subsequent projection - with self_references(): - joined = t1.join(t2, [t1.a == t2.c]) - expr = joined.select(t1.a, t1.b, col=t2.c + 1) - with self_references(t1, t2) as (r1, r2): + joined = t1.join(t2, [t1.a == t2.c]) + expr = joined.select(t1.a, t1.b, col=t2.c + 1) + with join_tables(t1, t2) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -551,17 +539,16 @@ def test_join_with_subsequent_projection(): assert expr.op() == expected # multiple computed values - with self_references(): - joined = t1.join(t2, [t1.a == t2.c]) - expr = joined.select( - t1.a, - t1.b, - foo=t2.c + 1, - bar=t2.c + 2, - baz=t2.d.name("bar") + "3", - baz2=(t2.c + 
t1.a).name("foo"), - ) - with self_references(t1, t2) as (r1, r2): + joined = t1.join(t2, [t1.a == t2.c]) + expr = joined.select( + t1.a, + t1.b, + foo=t2.c + 1, + bar=t2.c + 2, + baz=t2.d.name("bar") + "3", + baz2=(t2.c + t1.a).name("foo"), + ) + with join_tables(t1, t2) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -583,15 +570,14 @@ def test_join_with_subsequent_projection_colliding_names(): name="t2", schema={"a": "int64", "b": "string", "c": "float", "d": "string"} ) - with self_references(): - joined = t1.join(t2, [t1.a == t2.a]) - expr = joined.select( - t1.a, - t1.b, - foo=t2.a + 1, - bar=t1.a + t2.a, - ) - with self_references(t1, t2) as (r1, r2): + joined = t1.join(t2, [t1.a == t2.a]) + expr = joined.select( + t1.a, + t1.b, + foo=t2.a + 1, + bar=t1.a + t2.a, + ) + with join_tables(t1, t2) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.a])], @@ -609,12 +595,10 @@ def test_chained_join(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) + joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e]) + result = joined._finish() - with self_references(): - joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e]) - result = joined._finish() - - with self_references(a, b, c) as (r1, r2, r3): + with join_tables(a, b, c) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -631,11 +615,10 @@ def test_chained_join(): }, ) - with self_references(): - joined = a.join(b, [a.a == b.c]).join(c, [b.c == c.e]) - result = joined.select(a.a, b.d, c.f) + joined = a.join(b, [a.a == b.c]).join(c, [b.c == c.e]) + result = joined.select(a.a, b.d, c.f) - with self_references(a, b, c) as (r1, r2, r3): + with join_tables(a, b, c) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -655,11 +638,11 @@ def 
test_chained_join_referencing_intermediate_table(): b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) - with self_references(): - ab = a.join(b, [a.a == b.c]) - abc = ab.join(c, [ab.a == c.e]) - result = abc._finish() - with self_references(a, b, c) as (r1, r2, r3): + ab = a.join(b, [a.a == b.c]) + abc = ab.join(c, [ab.a == c.e]) + result = abc._finish() + + with join_tables(a, b, c) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -691,9 +674,8 @@ def test_join_predicate_dereferencing(): filtered = table[table["f"] > 0] # dereference table.foo_id to filtered.foo_id - with self_references(): - j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) - with self_references(filtered, table2) as (r1, r2): + j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) + with join_tables(filtered, table2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -711,11 +693,10 @@ def test_join_predicate_dereferencing(): ) assert j1.op() == expected - with self_references(): - j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) - j2 = j1.inner_join(table3, filtered["bar_id"] == table3["bar_id"]) - view = j2[[filtered, table2["value1"], table3["value2"]]] - with self_references(filtered, table2, table3) as (r1, r2, r3): + j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) + j2 = j1.inner_join(table3, filtered["bar_id"] == table3["bar_id"]) + view = j2[[filtered, table2["value1"], table3["value2"]]] + with join_tables(filtered, table2, table3) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -943,12 +924,10 @@ def test_self_join(): t0 = ibis.table(schema=ibis.schema(dict(key="int")), name="leaf") t1 = t0.filter(ibis.literal(True)) t2 = t1[["key"]] + t3 = t2.join(t2, ["key"]) + t4 = t3.join(t3, ["key"]) - with self_references(): - t3 = t2.join(t2, ["key"]) - t4 = t3.join(t3, ["key"]) - - with 
self_references(t2, t2, t3) as (r1, r2, r3): + with join_tables(t2, t2, t3) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -976,12 +955,9 @@ def test_self_join(): def test_self_join_view(): t = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) t_view = t.view() + expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") - with self_references(): - expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") - - with self_references(t) as (r1,): - r2 = t_view + with join_tables(t, t_view) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -995,11 +971,9 @@ def test_self_join_view(): def test_self_join_with_view_projection(): t1 = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) t2 = t1.view() + expr = t1.inner_join(t2, ["x"])[[t1]] - with self_references(): - expr = t1.inner_join(t2, ["x"])[[t1]] - with self_references(t1) as (r1,): - r2 = t2 + with join_tables(t1, t2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1014,11 +988,10 @@ def test_joining_same_table_twice(): left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str}) right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str}) - with self_references(): - joined = left.inner_join(right, left.a == right.b).inner_join( - right, left.value == right.value2 - ) - with self_references(left, right, right) as (r1, r2, r3): + joined = left.inner_join(right, left.a == right.b).inner_join( + right, left.value == right.value2 + ) + with join_tables(left, right, right) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1045,10 +1018,10 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) - with self_references(): - ab = a.join(b, [a.a == b.c]) - abc = ab[a.b, b.d].join(c, [a.a == c.e]) - with self_references(a, b, c) as (r1, r2, r3): + ab = 
a.join(b, [a.a == b.c]) + abc = ab[a.b, b.d].join(c, [a.a == c.e]) + + with join_tables(a, b, c) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1068,13 +1041,10 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): def test_self_join_extensive(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) - with self_references(): - aa = a.join(a, [a.a == a.a]) - with self_references(): - aa1 = a.join(a, "a") - with self_references(): - aa2 = a.join(a, [("a", "a")]) - with self_references(a, a) as (r1, r2): + aa = a.join(a, [a.a == a.a]) + aa1 = a.join(a, "a") + aa2 = a.join(a, [("a", "a")]) + with join_tables(a, a) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1091,18 +1061,11 @@ def test_self_join_extensive(): assert aa1.op() == expected assert aa2.op() == expected - with self_references(): - aaa = a.join(a, [a.a == a.a]).join(a, [a.a == a.a]) - with self_references(): - aa = a.join(a, [a.a == a.a]) - aaa1 = aa.join(a, [aa.a == a.a]) - with self_references(): - aa = a.join(a, [a.a == a.a]) - aaa2 = aa.join(a, "a") - with self_references(): - aa = a.join(a, [a.a == a.a]) - aaa3 = aa.join(a, [("a", "a")]) - with self_references(a, a, a) as (r1, r2, r3): + aaa = a.join(a, [a.a == a.a]).join(a, [a.a == a.a]) + aaa1 = aa.join(a, [aa.a == a.a]) + aaa2 = aa.join(a, "a") + aaa3 = aa.join(a, [("a", "a")]) + with join_tables(a, a, a) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1124,11 +1087,9 @@ def test_self_join_extensive(): def test_self_join_with_intermediate_selection(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) - - with self_references(): - proj = a[["b", "a"]] - join = proj.join(a, [a.a == a.a]) - with self_references(proj, a) as (r1, r2): + proj = a[["b", "a"]] + join = proj.join(a, [a.a == a.a]) + with join_tables(proj, a) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1143,10 +1104,9 @@ def test_self_join_with_intermediate_selection(): ) assert join.op() == 
expected - with self_references(): - aa = a.join(a, [a.a == a.a])["a", "b_right"] - aaa = aa.join(a, [aa.a == a.a]) - with self_references(a, a, a) as (r1, r2, r3): + aa = a.join(a, [a.a == a.a])["a", "b_right"] + aaa = aa.join(a, [aa.a == a.a]) + with join_tables(a, a, a) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1215,12 +1175,10 @@ def test_self_view_join_followed_by_aggregate_correctly_dereference_fields(): agged = t.aggregate([t.f.sum().name("total")], by=["g", "a", "b"]) view = agged.view() metrics = [(agged.total - view.total).max().name("metric")] + join = agged.inner_join(view, [agged.a == view.b]) + agg = join.aggregate(metrics, by=[agged.g]) - with self_references(): - join = agged.inner_join(view, [agged.a == view.b]) - agg = join.aggregate(metrics, by=[agged.g]) - with self_references(agged) as (r1,): - r2 = view + with join_tables(agged, view) as (r1, r2): expected_join = ops.JoinChain( first=r1, rest=[ @@ -1248,3 +1206,12 @@ def test_self_view_join_followed_by_aggregate_correctly_dereference_fields(): ).to_expr() assert join.equals(expected_join) assert agg.equals(expected_agg) + + +def test_join_expressions_are_equal(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "int64"}) + t2 = ibis.table(name="t2", schema={"a": "int64", "b": "int64"}) + + join1 = t1.inner_join(t2, [t1.a == t2.a]) + join2 = t1.inner_join(t2, [t1.a == t2.a]) + assert join1.equals(join2) diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 7568f05cb236..e66b97732cd0 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -46,34 +46,28 @@ def disambiguate_fields(how, left_fields, right_fields, lname, rname): return fields, collisions -def dereference_targets(chain): - yield chain.first - for join in chain.rest: - if join.how not in ("semi", "anti"): - yield join.table - - def dereference_mapping_left(chain): - rels = dereference_targets(chain) + # construct the list of join table we wish to dereference fields to + 
rels = [chain.first] + for link in chain.rest: + if link.how not in ("semi", "anti"): + rels.append(link.table) + + # create the dereference mapping suitable to disambiguate field references + # from earlier in the relation hierarchy to one of the join tables subs = dereference_mapping(rels) - # join chain fields => link table fields + + # also allow to dereference fields of the join chain itself for k, v in chain.values.items(): subs[ops.Field(chain, k)] = v + return subs def dereference_mapping_right(right): - if isinstance(right, ops.SelfReference): - # no support for dereferencing, the user must use the right table - # directly in the predicates - return {}, right - - # wrap the right table in a self reference to ensure its uniqueness in the - # join chain which requires dereferencing the predicates from - # right => SelfReference(right) - right = ops.SelfReference(right) - subs = {v: ops.Field(right, k) for k, v in right.values.items()} - return subs, right + # the right table is wrapped in a JoinTable to ensure its uniqueness in the + # join chain, which requires dereferencing the predicates to the wrapped table + return {v: ops.Field(right, k) for k, v in right.values.items()} def dereference_sides(left, right, deref_left, deref_right): @@ -175,9 +169,9 @@ def join( how = "semi" left = self.op() - right = right.op() + right = ops.JoinTable(right, index=left.length) subs_left = dereference_mapping_left(left) - subs_right, right = dereference_mapping_right(right) + subs_right = dereference_mapping_right(right) # bind and dereference the predicates preds = prepare_predicates( diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 63a0be061112..128e53cb4dff 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2982,13 +2982,10 @@ def join( # `ir.Table(ops.JoinChain())` expression, which we can reuse here expr = left.to_expr() else: - # all participants of the join must be wrapped in SelfReferences so - # that we can 
join the same table with itself multiple times and to - # enable optimization passes later on - if not isinstance(left, ops.SelfReference): - left = ops.SelfReference(left) - # construct an empty join chain and wrap it with a JoinExpr, the - # projected fields are the fields of the starting table + # all participants of the join must be wrapped in JoinTable nodes + # so that we can join the same table with itself multiple times and + # to enable optimization passes later on + left = ops.JoinTable(left, index=0) expr = ops.JoinChain(left, rest=(), values=left.fields).to_expr() return expr.join(right, predicates, how=how, lname=lname, rname=rname) diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt index b4761984f092..afa7b6662e2f 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt @@ -7,27 +7,23 @@ r1 := DatabaseTable: test1 f float64 g string -r2 := SelfReference[r0] - -r3 := Filter[r1] +r2 := Filter[r1] r1.f > 0 -r4 := SelfReference[r3] - -r5 := JoinChain[r2] - JoinLink[inner, r4] - r4.g == r2.key +r3 := JoinChain[r0] + JoinLink[inner, r2] + r2.g == r0.key values: - key: r2.key - value: r2.value - c: r4.c - f: r4.f - g: r4.g + key: r0.key + value: r0.value + c: r2.c + f: r2.f + g: r2.g -Aggregate[r5] +Aggregate[r3] groups: - g: r5.g - key: r5.key + g: r3.g + key: r3.key metrics: - foo: Mean(r5.f - r5.value) - bar: Sum(r5.f) \ No newline at end of file + foo: Mean(r3.f - r3.value) + bar: Sum(r3.f) \ No newline at end of file diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index 7db0ae06c67a..aeacf249edf0 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -6,7 +6,7 @@ import ibis.common.exceptions as com import 
ibis.expr.operations as ops from ibis.expr.rewrites import simplify -from ibis.expr.tests.test_newrels import self_references +from ibis.expr.tests.test_newrels import join_tables # Place to collect esoteric expression analysis bugs and tests @@ -22,12 +22,12 @@ def test_rewrite_join_projection_without_other_ops(con): pred1 = table["foo_id"] == table2["foo_id"] pred2 = filtered["bar_id"] == table3["bar_id"] - with self_references(): - j1 = filtered.left_join(table2, [pred1]) - j2 = j1.inner_join(table3, [pred2]) - # Project out the desired fields - view = j2[[filtered, table2["value1"], table3["value2"]]] - with self_references(filtered, table2, table3) as (r1, r2, r3): + j1 = filtered.left_join(table2, [pred1]) + j2 = j1.inner_join(table3, [pred2]) + # Project out the desired fields + view = j2[[filtered, table2["value1"], table3["value2"]]] + + with join_tables(filtered, table2, table3) as (r1, r2, r3): # Construct the thing we expect to obtain expected = ops.JoinChain( first=r1, @@ -165,13 +165,12 @@ def test_filter_self_join(): ) cond = left.region == right.region - with self_references(): - joined = left.join(right, cond) - metric = (left.total - right.total).name("diff") - what = [left.region, metric] - projected = joined.select(what) + joined = left.join(right, cond) + metric = (left.total - right.total).name("diff") + what = [left.region, metric] + projected = joined.select(what) - with self_references(left, right) as (r1, r2): + with join_tables(left, right) as (r1, r2): join = ops.JoinChain( first=r1, rest=[ diff --git a/ibis/tests/expr/test_struct.py b/ibis/tests/expr/test_struct.py index d7e5b05ddd5a..75707c650f2c 100644 --- a/ibis/tests/expr/test_struct.py +++ b/ibis/tests/expr/test_struct.py @@ -8,7 +8,7 @@ import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import _ -from ibis.expr.tests.test_newrels import self_references +from ibis.expr.tests.test_newrels import join_tables from ibis.tests.util import assert_pickle_roundtrip 
@@ -70,10 +70,10 @@ def test_unpack_from_table(t): def test_lift_join(t, s): - with self_references(): - join = t.join(s, t.d == s.a.g) - result = join.a_right.lift() - with self_references(t, s) as (r1, r2): + join = t.join(s, t.d == s.a.g) + result = join.a_right.lift() + + with join_tables(t, s) as (r1, r2): join = ops.JoinChain( first=r1, rest=[ diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 9f26fb41d205..b2ee00a2ec27 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -22,7 +22,7 @@ from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError from ibis.expr import api from ibis.expr.rewrites import simplify -from ibis.expr.tests.test_newrels import self_references +from ibis.expr.tests.test_newrels import join_tables from ibis.expr.types import Column, Table from ibis.tests.util import assert_equal, assert_pickle_roundtrip @@ -847,10 +847,10 @@ def test_join_no_predicate_list(con): region = con.table("tpch_region") nation = con.table("tpch_nation") - with self_references(): - pred = region.r_regionkey == nation.n_regionkey - joined = region.inner_join(nation, pred) - with self_references(region, nation) as (r1, r2): + pred = region.r_regionkey == nation.n_regionkey + joined = region.inner_join(nation, pred) + + with join_tables(region, nation) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -871,9 +871,9 @@ def test_join_deferred(con): region = con.table("tpch_region") nation = con.table("tpch_nation") - with self_references(): - res = region.join(nation, _.r_regionkey == nation.n_regionkey) - with self_references(region, nation) as (r1, r2): + res = region.join(nation, _.r_regionkey == nation.n_regionkey) + + with join_tables(region, nation) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -918,9 +918,8 @@ def 
test_asof_join_with_by(): left = ibis.table([("time", "int32"), ("key", "int32"), ("value", "double")]) right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) - with self_references(): - join_without_by = api.asof_join(left, right, "time") - with self_references(left, right) as (r1, r2): + join_without_by = api.asof_join(left, right, "time") + with join_tables(left, right) as (r1, r2): r2 = join_without_by.op().rest[0].table.to_expr() expected = ops.JoinChain( first=r1, @@ -936,9 +935,8 @@ def test_asof_join_with_by(): ) assert join_without_by.op() == expected - with self_references(): - join_with_by = api.asof_join(left, right, "time", by="key") - with self_references(left, right, right) as (r1, r2, r3): + join_with_by = api.asof_join(left, right, "time", by="key") + with join_tables(left, right, right) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -981,9 +979,8 @@ def test_asof_join_with_tolerance(ibis_interval, timedelta_interval): right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) for interval in [ibis_interval, timedelta_interval]: - with self_references(): - joined = api.asof_join(left, right, "time", tolerance=interval) - with self_references(left, right) as (r1, r2): + joined = api.asof_join(left, right, "time", tolerance=interval) + with join_tables(left, right) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1161,9 +1158,8 @@ def test_cross_join_multiple(table): b = table["d", "e"] c = table["f", "h"] - with self_references(): - joined = ibis.cross_join(a, b, c) - with self_references(a, b, c) as (r1, r2, r3): + joined = ibis.cross_join(a, b, c) + with join_tables(a, b, c) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1251,9 +1247,8 @@ def test_join_key_alternatives(con, key_maker): t2 = con.table("star2") key = key_maker(t1, t2) - with self_references(): - joined = t1.inner_join(t2, key) - with self_references(t1, t2) as (r1, r2): + joined = 
t1.inner_join(t2, key) + with join_tables(t1, t2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1337,9 +1332,8 @@ def test_unravel_compound_equijoin(table): p2 = t1.key2 == t2.key2 p3 = t1.key3 == t2.key3 - with self_references(): - joined = t1.inner_join(t2, [p1 & p2 & p3]) - with self_references(t1, t2) as (r1, r2): + joined = t1.inner_join(t2, [p1 & p2 & p3]) + with join_tables(t1, t2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[