diff --git a/ibis/expr/tests/test_format.py b/ibis/expr/tests/test_format.py index d96a908afe3fd..d2b4e067b788a 100644 --- a/ibis/expr/tests/test_format.py +++ b/ibis/expr/tests/test_format.py @@ -62,7 +62,6 @@ def test_table_type_output(snapshot): expr = foo.dept_id == foo.view().dept_id result = fmt(expr) - assert "SelfReference[r0]" in result assert "UnboundTable: foo" in result snapshot.assert_match(result, "repr.txt") diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 7b32e71fa5276..22141f14bd81e 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -960,12 +960,13 @@ def test_self_join_view(): t_view = t.view() expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") + t_view_ = expr.op().rest[0].table.to_expr() expected = ops.JoinChain( first=t, rest=[ - ops.JoinLink("inner", t_view, [t.x == t_view.y]), + ops.JoinLink("inner", t_view_, [t.x == t_view_.y]), ], - values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view.z}, + values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view_.z}, ) assert expr.op() == expected @@ -975,11 +976,43 @@ def test_self_join_with_view_projection(): t2 = t1.view() expr = t1.inner_join(t2, ["x"])[[t1]] + t2_ = expr.op().rest[0].table.to_expr() expected = ops.JoinChain( first=t1, rest=[ - ops.JoinLink("inner", t2, [t1.x == t2.x]), + ops.JoinLink("inner", t2_, [t1.x == t2_.x]), ], values={"x": t1.x, "y": t1.y, "z": t1.z}, ) assert expr.op() == expected + + +def test_joining_same_table_twice(): + left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str}) + right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str}) + + joined = left.inner_join(right, left.a == right.b).inner_join( + right, left.value == right.value2 + ) + + right_ = joined.op().rest[0].table.to_expr() + right__ = joined.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=left, + rest=[ + ops.JoinLink("inner", right_, [left.a == right_.b]), + ops.JoinLink("inner", right__, [left.value == right__.value2]), + ], + values={ + "time1": left.time1, + "value": left.value, + "a": left.a, + "time2": right_.time2, + "value2": right_.value2, + "b": right_.b, + "time2_right": right__.time2, + "value2_right": right__.value2, + "b_right": right__.b, + }, + ) + assert joined.op() == expected diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 30b7e8471577b..41dc92540ec26 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -2,7 +2,6 @@ bind, dereference_values, unwrap_aliases, - dereference_mapping, ) from public import public import ibis.expr.operations as ops @@ -40,7 +39,7 @@ def disambiguate_fields(how, left_fields, right_fields, lname, rname): return fields -def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_all): +def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_both): """Bind predicates to the left and right tables.""" for pred in util.promote_list(predicates): @@ -48,11 +47,11 @@ def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_a yield ops.Literal(pred, dtype="bool") elif isinstance(pred, ValueExpr): node = pred.op() - yield node.replace(deref_all, filter=ops.Value) + yield node.replace(deref_both, filter=ops.Value) elif isinstance(pred, Deferred): # resolve deferred expressions on the left table node = pred.resolve(left).op() - yield node.replace(deref_all, filter=ops.Value) + yield node.replace(deref_both, filter=ops.Value) else: if isinstance(pred, tuple): if len(pred) != 2: @@ -98,15 +97,18 @@ def join( left = self.op() right = right.op() if isinstance(right, ops.SelfReference): - tables = list(self._relations()) + deref_right = {} else: right = ops.SelfReference(right) - tables = list(self._relations()) + [right] + deref_right = {v: ops.Field(right, k) for k, v in right.fields.items()} # construct the various dereference mappings deref_left = {ops.Field(left, k): v for k, v in left.fields.items()} - deref_right = {v: ops.Field(right, k) for k, v in right.fields.items()} - deref_all = dereference_mapping(tables, extra=deref_left) + deref_both = deref_left.copy() + for table in self._relations(): + deref_both.update({v: ops.Field(table, k) for k, v in table.fields.items()}) + # finally add the right dereferences to take precedence + deref_both.update(deref_right) # bind and dereference the predicates preds = prepare_predicates( @@ -115,7 +117,7 @@ def join( predicates, deref_left=deref_left, deref_right=deref_right, - deref_all=deref_all, + deref_both=deref_both, ) preds = flatten_predicates(list(preds)) @@ -133,14 +135,19 @@ def join( def select(self, *args, **kwargs): """Select expressions.""" + chain = self.op() values = bind(self, (args, kwargs)) values = unwrap_aliases(values) # if there are values referencing fields from the join chain constructed # so far, we need to replace them the fields from one of the join links - tables = list(self._relations()) - extra = {ops.Field(self, k): v for k, v in self.op().fields.items()} - values = dereference_values(tables, values, extra=extra, include_top=True) + subs = {ops.Field(chain, k): v for k, v in chain.fields.items()} + tables = set(self._relations()) + for table in self._relations(): + for k, v in table.fields.items(): + if isinstance(v, ops.Field) and v.rel not in tables: + subs[v] = ops.Field(table, k) + values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} return self.finish(values) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 5dccb9b639286..46330594e257c 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -152,27 +152,18 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: return result -def dereference_mapping(parents, extra=None, include_top=False): - """Generate substitution mapping.""" - mapping = extra.copy() if extra is not None else {} +def dereference_values(parents, values): + """Trace and replace fields from earlier relations in the hierarchy.""" + mapping = {} for parent in util.promote_list(parents): - if include_top: - for k in parent.schema: - mapping[ops.Field(parent, k)] = ops.Field(parent, k) for k, v in parent.fields.items(): if isinstance(v, ops.Field): - while isinstance(v, ops.Field) and v not in mapping: + while isinstance(v, ops.Field): mapping[v] = ops.Field(parent, k) v = v.rel.fields.get(v.name) - elif v.relations and v not in mapping: + elif v.relations: # do not dereference literal expressions mapping[v] = ops.Field(parent, k) - return mapping - - -def dereference_values(parents, values, extra=None, include_top=False): - """Trace and replace fields from earlier relations in the hierarchy.""" - mapping = dereference_mapping(parents, extra=extra, include_top=include_top) return {k: v.replace(mapping, filter=ops.Value) for k, v in values.items()} @@ -923,10 +914,10 @@ def view(self) -> Table: Table Table expression """ - node = self.op() - if isinstance(node, ops.SelfReference): - node = node.parent - return ops.SelfReference(node).to_expr() + if isinstance(self.op(), ops.SelfReference): + return self + else: + return ops.SelfReference(self).to_expr() def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table: """Compute the set difference of multiple table expressions.