Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ir): more flexible dereferencing for join right-hand side #8916

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ INNER JOIN (
GROUP BY
1
) AS `t10`
ON `t8`.`d` = `t10`.`d`
ON `t10`.`d` = `t10`.`d`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look correct to me: the join predicate now compares `t10`.`d` with itself. It should reference `t8`, as it did before.

) AS `t11`
WHERE
`t11`.`row_count` < (
Expand Down
40 changes: 40 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,46 @@ def test_join_predicate_dereferencing_using_tuple_syntax():
assert j2.op() == expected


def test_join_rhs_dereferencing():
    """Predicates referencing ``t2`` must resolve against ``t3 = t2.mutate(...)``.

    ``t3`` is derived from ``t2``; the join is made against ``t3`` while the
    predicates are written in terms of ``t2``. Both a bare field reference and
    an expression built on a field are checked.
    """
    t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
    t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
    t3 = t2.mutate(e=t2.c + 1)

    def expected_chain(r1, r2, predicate):
        # Both scenarios expect the same chain shape and differ only in the
        # single inner-join predicate.
        return JoinChain(
            first=r1,
            rest=[
                JoinLink("inner", r2, [predicate]),
            ],
            values={
                "a": r1.a,
                "b": r1.b,
                "c": r2.c,
                "d": r2.d,
                "e": r2.e,
            },
        )

    # predicate is a plain field of t2
    joined = t1.join(t3, [t1.a == t2.c])
    with join_tables(t1, t3) as (r1, r2):
        expected = expected_chain(r1, r2, r1.a == r2.c)
        assert joined.op() == expected

    # predicate is an expression computed from a field of t2
    joined = t1.join(t3, [t1.a == (t2.c + 1)])
    with join_tables(t1, t3) as (r1, r2):
        expected = expected_chain(r1, r2, r1.a == (r2.c + 1))
        assert joined.op() == expected


def test_aggregate():
agg = t.aggregate(by=[t.bool_col], metrics=[t.int_col.sum()])
expected = Aggregate(
Expand Down
59 changes: 27 additions & 32 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,23 @@ def disambiguate_fields(
return fields, collisions, equalities


def dereference_mapping_left(chain):
    """Build the substitution mapping for dereferencing against a join chain.

    Parameters
    ----------
    chain
        The join chain; its ``first`` relation, its ``rest`` links and its
        ``values`` are read.

    Returns
    -------
    The mapping returned by ``dereference_mapping`` for the chain's join
    tables, extended with one entry per field of the chain itself.
    """
    # the join tables we wish to dereference fields to: the first relation
    # plus the table of every link that is not a semi or anti join
    targets = [chain.first]
    targets.extend(
        link.table for link in chain.rest if link.how not in ("semi", "anti")
    )

    # mapping suitable to disambiguate field references from earlier in the
    # relation hierarchy to one of the join tables
    mapping = dereference_mapping(targets)

    # fields of the join chain itself dereference to the chain's own values
    for name, value in chain.values.items():
        mapping[ops.Field(chain, name)] = value

    return mapping

def dereference_mapping_join(parents):
targets = []
reverse = {}
for parent in util.promote_list(parents):
if isinstance(parent, ops.JoinChain):
targets.append(parent.first)
for link in parent.rest:
if link.how not in ("semi", "anti"):
targets.append(link.table)
for k, v in parent.values.items():
reverse[ops.Field(parent, k)] = v
else:
targets.append(parent)

def dereference_mapping_right(right):
# the right table is wrapped in a JoinTable to ensure the uniqueness of the
# underlying table, which requires the predicates to be dereferenced to the
# fields of that JoinTable
return {v: ops.Field(right, k) for k, v in right.values.items()}
mapping = dereference_mapping(targets)
mapping.update(reverse)
return mapping


def dereference_sides(left, right, deref_left, deref_right):
Expand All @@ -142,8 +137,7 @@ def dereference_sides(left, right, deref_left, deref_right):
return left, right


def dereference_value(pred, deref_left, deref_right):
deref_both = {**deref_left, **deref_right}
def dereference_value(pred, deref_left, deref_right, deref_both):
if isinstance(pred, ops.Comparison) and pred.left.relations == pred.right.relations:
left, right = dereference_sides(pred.left, pred.right, deref_left, deref_right)
return pred.copy(left=left, right=right)
Expand All @@ -152,7 +146,7 @@ def dereference_value(pred, deref_left, deref_right):


def prepare_predicates(
left: ops.JoinChain,
chain: ops.JoinChain,
right: ops.Relation,
predicates: Sequence[Any],
comparison: type[ops.Comparison] = ops.Equals,
Expand Down Expand Up @@ -187,8 +181,8 @@ def prepare_predicates(

Parameters
----------
left
The left table
chain
The join chain
right
The right table
predicates
Expand All @@ -197,21 +191,22 @@ def prepare_predicates(
The comparison operation to construct if the input is a pair of
expression-like objects
"""
deref_left = dereference_mapping_left(left)
deref_right = dereference_mapping_right(right)
deref_left = dereference_mapping_join([chain])
deref_right = dereference_mapping_join([right])
deref_both = dereference_mapping_join([right, chain])

left, right = left.to_expr(), right.to_expr()
left, right = chain.to_expr(), right.to_expr()
for pred in util.promote_list(predicates):
if pred is True or pred is False:
yield ops.Literal(pred, dtype="bool")
elif isinstance(pred, Value):
for node in flatten_predicates(pred.op()):
yield dereference_value(node, deref_left, deref_right)
yield dereference_value(node, deref_left, deref_right, deref_both)
elif isinstance(pred, Deferred):
# resolve deferred expressions on the left table
pred = pred.resolve(left).op()
for node in flatten_predicates(pred):
yield dereference_value(node, deref_left, deref_right)
yield dereference_value(node, deref_left, deref_right, deref_both)
else:
if isinstance(pred, tuple):
if len(pred) != 2:
Expand Down Expand Up @@ -430,7 +425,7 @@ def select(self, *args, **kwargs):

# if there are values referencing fields from the join chain constructed
# so far, we need to replace them with the fields from one of the join links
subs = dereference_mapping_left(chain)
subs = dereference_mapping_join([chain])
values = {
k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items()
}
Expand Down
Loading