From 1c22862929ec5da65156f18bc037307cf843e8f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 12 Dec 2023 14:46:51 +0100 Subject: [PATCH] e --- ibis/expr/tests/test_newrels.py | 15 +++++++++++++++ ibis/expr/types/joins.py | 2 +- ibis/expr/types/relations.py | 20 ++++++++++++++------ 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 2f427796bdfb5..7b32e71fa5276 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -968,3 +968,18 @@ def test_self_join_view(): values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view.z}, ) assert expr.op() == expected + + +def test_self_join_with_view_projection(): + t1 = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) + t2 = t1.view() + expr = t1.inner_join(t2, ["x"])[[t1]] + + expected = ops.JoinChain( + first=t1, + rest=[ + ops.JoinLink("inner", t2, [t1.x == t2.x]), + ], + values={"x": t1.x, "y": t1.y, "z": t1.z}, + ) + assert expr.op() == expected diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 53171ac3eb216..30b7e8471577b 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -140,7 +140,7 @@ def select(self, *args, **kwargs): # so far, we need to replace them the fields from one of the join links tables = list(self._relations()) extra = {ops.Field(self, k): v for k, v in self.op().fields.items()} - values = dereference_values(tables, values, extra=extra) + values = dereference_values(tables, values, extra=extra, include_top=True) return self.finish(values) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 5827fd9efebb2..5dccb9b639286 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -152,24 +152,27 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: return result -def dereference_mapping(parents, extra=None): +def dereference_mapping(parents, extra=None, include_top=False): """Generate substitution mapping.""" mapping = extra.copy() if extra is not None else {} for parent in util.promote_list(parents): + if include_top: + for k in parent.schema: + mapping[ops.Field(parent, k)] = ops.Field(parent, k) for k, v in parent.fields.items(): if isinstance(v, ops.Field): - while isinstance(v, ops.Field): + while isinstance(v, ops.Field) and v not in mapping: mapping[v] = ops.Field(parent, k) v = v.rel.fields.get(v.name) - elif v.relations: + elif v.relations and v not in mapping: # do not dereference literal expressions mapping[v] = ops.Field(parent, k) return mapping -def dereference_values(parents, values, extra=None): +def dereference_values(parents, values, extra=None, include_top=False): """Trace and replace fields from earlier relations in the hierarchy.""" - mapping = dereference_mapping(parents, extra=extra) + mapping = dereference_mapping(parents, extra=extra, include_top=include_top) return {k: v.replace(mapping, filter=ops.Value) for k, v in values.items()} @@ -2960,8 +2963,13 @@ def join( """ from ibis.expr.types.joins import JoinExpr + # the first participant of the join shouldn't be a self reference + left = left.op() + if isinstance(left, ops.SelfReference): + left = left.parent + # construct an empty join chain and wrap it with a JoinExpr - values = {k: ops.Field(left, k) for k in left.schema().names} + values = {k: ops.Field(left, k) for k in left.schema} node = ops.JoinChain(left, rest=(), values=values) # add the first join link to the join chain and return the result if how == "left_semi":