From 57d163e11a30c7364e3daaa013cbf3c2d2331a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sat, 23 Dec 2023 02:27:58 +0100 Subject: [PATCH] refactor(ir): support join of joins while avoiding nesting --- ibis/expr/rewrites.py | 5 ++++ ibis/expr/tests/test_newrels.py | 41 +++++++++++++++++++++++++++++++++ ibis/expr/types/joins.py | 5 ++++ 3 files changed, 51 insertions(+) diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 74e2294ec3db..12a314379d9a 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -17,6 +17,11 @@ name = var("name") +@replace(p.Field(p.JoinChain)) +def peel_join_field(_): + return _.rel.values[_.name] + + @replace(ops.Analytic) def project_wrap_analytic(_, rel): # Wrap analytic functions in a window function diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 4735596a970d..906284c1a32c 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -1215,3 +1215,44 @@ def test_join_expressions_are_equal(): join1 = t1.inner_join(t2, [t1.a == t2.a]) join2 = t1.inner_join(t2, [t1.a == t2.a]) assert join1.equals(join2) + + +def test_join_between_joins(): + t1 = ibis.table( + [("key1", "string"), ("key2", "string"), ("value1", "double")], + "first", + ) + t2 = ibis.table([("key1", "string"), ("value2", "double")], "second") + t3 = ibis.table( + [("key2", "string"), ("key3", "string"), ("value3", "double")], + "third", + ) + t4 = ibis.table([("key3", "string"), ("value4", "double")], "fourth") + + left = t1.inner_join(t2, [("key1", "key1")])[t1, t2.value2] + right = t3.inner_join(t4, [("key3", "key3")])[t3, t4.value4] + + joined = left.inner_join(right, left.key2 == right.key2) + + # At one point, the expression simplification was resulting in bad refs + # here (right.value3 referencing the table inside the right join) + exprs = [left, right.value3, right.value4] + expr = joined.select(exprs) + + with join_tables(t1, t2, right) as (r1, r2, r3): + expected = ops.JoinChain( + first=r1, + rest=[ + ops.JoinLink("inner", r2, [r1.key1 == r2.key1]), + ops.JoinLink("inner", r3, [r1.key2 == r3.key2]), + ], + values={ + "key1": r1.key1, + "key2": r1.key2, + "value1": r1.value1, + "value2": r2.value2, + "value3": r3.value3, + "value4": r3.value4, + }, + ) + assert expr.op() == expected diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index e66b97732cd0..4afaca288980 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -17,6 +17,8 @@ from ibis.expr.types.relations import dereference_mapping import ibis +from ibis.expr.rewrites import peel_join_field + def disambiguate_fields(how, left_fields, right_fields, lname, rname): collisions = set() @@ -207,6 +209,9 @@ def select(self, *args, **kwargs): # if there are values referencing fields from the join chain constructed # so far, we need to replace them the fields from one of the join links subs = dereference_mapping_left(chain) + values = { + k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items() + } values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} node = chain.copy(values=values)