Skip to content

Commit

Permalink
refactor(ir): support join of joins while avoiding nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 23, 2023
1 parent 836b89a commit 57d163e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
name = var("name")


@replace(p.Field(p.JoinChain))
def peel_join_field(_):
return _.rel.values[_.name]


@replace(ops.Analytic)
def project_wrap_analytic(_, rel):
# Wrap analytic functions in a window function
Expand Down
41 changes: 41 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,3 +1215,44 @@ def test_join_expressions_are_equal():
join1 = t1.inner_join(t2, [t1.a == t2.a])
join2 = t1.inner_join(t2, [t1.a == t2.a])
assert join1.equals(join2)


def test_join_between_joins():
t1 = ibis.table(
[("key1", "string"), ("key2", "string"), ("value1", "double")],
"first",
)
t2 = ibis.table([("key1", "string"), ("value2", "double")], "second")
t3 = ibis.table(
[("key2", "string"), ("key3", "string"), ("value3", "double")],
"third",
)
t4 = ibis.table([("key3", "string"), ("value4", "double")], "fourth")

left = t1.inner_join(t2, [("key1", "key1")])[t1, t2.value2]
right = t3.inner_join(t4, [("key3", "key3")])[t3, t4.value4]

joined = left.inner_join(right, left.key2 == right.key2)

# At one point, the expression simplification was resulting in bad refs
# here (right.value3 referencing the table inside the right join)
exprs = [left, right.value3, right.value4]
expr = joined.select(exprs)

with join_tables(t1, t2, right) as (r1, r2, r3):
expected = ops.JoinChain(
first=r1,
rest=[
ops.JoinLink("inner", r2, [r1.key1 == r2.key1]),
ops.JoinLink("inner", r3, [r1.key2 == r3.key2]),
],
values={
"key1": r1.key1,
"key2": r1.key2,
"value1": r1.value1,
"value2": r2.value2,
"value3": r3.value3,
"value4": r3.value4,
},
)
assert expr.op() == expected
5 changes: 5 additions & 0 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from ibis.expr.types.relations import dereference_mapping
import ibis

from ibis.expr.rewrites import peel_join_field


def disambiguate_fields(how, left_fields, right_fields, lname, rname):
collisions = set()
Expand Down Expand Up @@ -207,6 +209,9 @@ def select(self, *args, **kwargs):
# if there are values referencing fields from the join chain constructed
# so far, we need to replace them the fields from one of the join links
subs = dereference_mapping_left(chain)
values = {
k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items()
}
values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()}

node = chain.copy(values=values)
Expand Down

0 comments on commit 57d163e

Please sign in to comment.