Skip to content

Commit

Permalink
fix(ir): resolve the mind-bending self-join-dereferencing problem
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Dec 12, 2023
1 parent 1c22862 commit ef708f7
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 34 deletions.
1 change: 0 additions & 1 deletion ibis/expr/tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def test_table_type_output(snapshot):

expr = foo.dept_id == foo.view().dept_id
result = fmt(expr)
assert "SelfReference[r0]" in result
assert "UnboundTable: foo" in result
snapshot.assert_match(result, "repr.txt")

Expand Down
39 changes: 36 additions & 3 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,12 +960,13 @@ def test_self_join_view():
t_view = t.view()
expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right")

t_view_ = expr.op().rest[0].table.to_expr()
expected = ops.JoinChain(
first=t,
rest=[
ops.JoinLink("inner", t_view, [t.x == t_view.y]),
ops.JoinLink("inner", t_view_, [t.x == t_view_.y]),
],
values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view.z},
values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view_.z},
)
assert expr.op() == expected

Expand All @@ -975,11 +976,43 @@ def test_self_join_with_view_projection():
t2 = t1.view()
expr = t1.inner_join(t2, ["x"])[[t1]]

t2_ = expr.op().rest[0].table.to_expr()
expected = ops.JoinChain(
first=t1,
rest=[
ops.JoinLink("inner", t2, [t1.x == t2.x]),
ops.JoinLink("inner", t2_, [t1.x == t2_.x]),
],
values={"x": t1.x, "y": t1.y, "z": t1.z},
)
assert expr.op() == expected


def test_joining_same_table_twice():
left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str})
right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str})

joined = left.inner_join(right, left.a == right.b).inner_join(
right, left.value == right.value2
)

right_ = joined.op().rest[0].table.to_expr()
right__ = joined.op().rest[1].table.to_expr()
expected = ops.JoinChain(
first=left,
rest=[
ops.JoinLink("inner", right_, [left.a == right_.b]),
ops.JoinLink("inner", right__, [left.value == right__.value2]),
],
values={
"time1": left.time1,
"value": left.value,
"a": left.a,
"time2": right_.time2,
"value2": right_.value2,
"b": right_.b,
"time2_right": right__.time2,
"value2_right": right__.value2,
"b_right": right__.b,
},
)
assert joined.op() == expected
31 changes: 19 additions & 12 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
bind,
dereference_values,
unwrap_aliases,
dereference_mapping,
)
from public import public
import ibis.expr.operations as ops
Expand Down Expand Up @@ -40,19 +39,19 @@ def disambiguate_fields(how, left_fields, right_fields, lname, rname):
return fields


def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_all):
def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_both):
"""Bind predicates to the left and right tables."""

for pred in util.promote_list(predicates):
if pred is True or pred is False:
yield ops.Literal(pred, dtype="bool")
elif isinstance(pred, ValueExpr):
node = pred.op()
yield node.replace(deref_all, filter=ops.Value)
yield node.replace(deref_both, filter=ops.Value)
elif isinstance(pred, Deferred):
# resolve deferred expressions on the left table
node = pred.resolve(left).op()
yield node.replace(deref_all, filter=ops.Value)
yield node.replace(deref_both, filter=ops.Value)
else:
if isinstance(pred, tuple):
if len(pred) != 2:
Expand Down Expand Up @@ -98,15 +97,18 @@ def join(
left = self.op()
right = right.op()
if isinstance(right, ops.SelfReference):
tables = list(self._relations())
deref_right = {}
else:
right = ops.SelfReference(right)
tables = list(self._relations()) + [right]
deref_right = {v: ops.Field(right, k) for k, v in right.fields.items()}

# construct the various dereference mappings
deref_left = {ops.Field(left, k): v for k, v in left.fields.items()}
deref_right = {v: ops.Field(right, k) for k, v in right.fields.items()}
deref_all = dereference_mapping(tables, extra=deref_left)
deref_both = deref_left.copy()
for table in self._relations():
deref_both.update({v: ops.Field(table, k) for k, v in table.fields.items()})
# finally add the right dereferences to take precedence
deref_both.update(deref_right)

# bind and dereference the predicates
preds = prepare_predicates(
Expand All @@ -115,7 +117,7 @@ def join(
predicates,
deref_left=deref_left,
deref_right=deref_right,
deref_all=deref_all,
deref_both=deref_both,
)
preds = flatten_predicates(list(preds))

Expand All @@ -133,14 +135,19 @@ def join(

def select(self, *args, **kwargs):
"""Select expressions."""
chain = self.op()
values = bind(self, (args, kwargs))
values = unwrap_aliases(values)

# if there are values referencing fields from the join chain constructed
# so far, we need to replace them the fields from one of the join links
tables = list(self._relations())
extra = {ops.Field(self, k): v for k, v in self.op().fields.items()}
values = dereference_values(tables, values, extra=extra, include_top=True)
subs = {ops.Field(chain, k): v for k, v in chain.fields.items()}
tables = set(self._relations())
for table in self._relations():
for k, v in table.fields.items():
if isinstance(v, ops.Field) and v.rel not in tables:
subs[v] = ops.Field(table, k)
values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()}

return self.finish(values)

Expand Down
27 changes: 9 additions & 18 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,27 +152,18 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]:
return result


def dereference_mapping(parents, extra=None, include_top=False):
"""Generate substitution mapping."""
mapping = extra.copy() if extra is not None else {}
def dereference_values(parents, values):
"""Trace and replace fields from earlier relations in the hierarchy."""
mapping = {}
for parent in util.promote_list(parents):
if include_top:
for k in parent.schema:
mapping[ops.Field(parent, k)] = ops.Field(parent, k)
for k, v in parent.fields.items():
if isinstance(v, ops.Field):
while isinstance(v, ops.Field) and v not in mapping:
while isinstance(v, ops.Field):
mapping[v] = ops.Field(parent, k)
v = v.rel.fields.get(v.name)
elif v.relations and v not in mapping:
elif v.relations:
# do not dereference literal expressions
mapping[v] = ops.Field(parent, k)
return mapping


def dereference_values(parents, values, extra=None, include_top=False):
"""Trace and replace fields from earlier relations in the hierarchy."""
mapping = dereference_mapping(parents, extra=extra, include_top=include_top)
return {k: v.replace(mapping, filter=ops.Value) for k, v in values.items()}


Expand Down Expand Up @@ -923,10 +914,10 @@ def view(self) -> Table:
Table
Table expression
"""
node = self.op()
if isinstance(node, ops.SelfReference):
node = node.parent
return ops.SelfReference(node).to_expr()
if isinstance(self.op(), ops.SelfReference):
return self
else:
return ops.SelfReference(self).to_expr()

def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
"""Compute the set difference of multiple table expressions.
Expand Down

0 comments on commit ef708f7

Please sign in to comment.