Skip to content

Commit

Permalink
refactor(ir): add JoinTable operation unique to JoinChain instead…
Browse files Browse the repository at this point in the history
… of using the globally unique `SelfReference`

This enables us to maintain join expression equality:
`a.join(b).equals(a.join(b))`

So far we have been using SelfReference to make join tables unique, but
it was globally unique which broke the equality check above. Therefore
we need to restrict the uniqueness to the scope of the join chain. The
simplest solution for that is to simply enumerate the join tables in
the join chain, hence now all join participants must be
`ops.JoinTable(rel, index)` instances.

`ops.SelfReference` is still required to distinguish between two
identical tables at the API level, but it is now decoupled from the
join internal representation.
  • Loading branch information
kszucs committed Dec 22, 2023
1 parent 011700d commit 593ec6a
Show file tree
Hide file tree
Showing 15 changed files with 238 additions and 296 deletions.
7 changes: 6 additions & 1 deletion ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def self_reference(op, parent, identifier):
return parent


@translate.register(ops.JoinTable)
def join_table(op, parent, index):
return parent


@translate.register(ops.JoinLink)
def join_link(op, table, predicates, how):
return f".{how}_join({table}, {_try_unwrap(predicates)})"
Expand Down Expand Up @@ -327,7 +332,7 @@ def isin(op, value, options):
class CodeContext:
always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregate)
always_ignore = (
ops.SelfReference,
ops.JoinTable,
ops.Field,
dt.Primitive,
dt.Variadic,
Expand Down
7 changes: 6 additions & 1 deletion ibis/expr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def pretty(node):

def mapper(op, _, **kwargs):
result = fmt(op, **kwargs)
if isinstance(op, ops.Relation):
if isinstance(op, ops.Relation) and not isinstance(op, ops.JoinTable):
tables[op] = result
result = f"r{next(refcnt)}"
return Rendered(result)
Expand Down Expand Up @@ -337,6 +337,11 @@ def _self_reference(op, parent, **kwargs):
return f"{op.__class__.__name__}[{parent}]"


@fmt.register(ops.JoinTable)
def _join_table(op, parent, index):
return parent


@fmt.register(ops.Literal)
def _literal(op, value, **kwargs):
if op.dtype.is_interval():
Expand Down
15 changes: 13 additions & 2 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,27 +229,38 @@ def name(self) -> str:
]


@public
class JoinTable(Simple):
index: int


@public
class JoinLink(Node):
how: JoinKind
table: SelfReference
table: JoinTable
predicates: VarTuple[Value[dt.Boolean]]


@public
class JoinChain(Relation):
first: SelfReference
first: JoinTable
rest: VarTuple[JoinLink]
values: FrozenDict[str, Unaliased[Value]]

def __init__(self, first, rest, values):
allowed_parents = {first}
assert first.index == 0
for join in rest:
assert join.table.index == len(allowed_parents)
allowed_parents.add(join.table)
_check_integrity(join.predicates, allowed_parents)
_check_integrity(values.values(), allowed_parents)
super().__init__(first=first, rest=rest, values=values)

@property
def length(self):
return len(self.rest) + 1

@attribute
def schema(self):
return Schema({k: v.dtype.copy(nullable=True) for k, v in self.values.items()})
Expand Down
28 changes: 11 additions & 17 deletions ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,15 @@ r1 := UnboundTable: right
time2 int32
value2 float64

r2 := SelfReference[r0]

r3 := SelfReference[r1]

r4 := SelfReference[r1]

JoinChain[r2]
JoinLink[asof, r3]
r2.time1 == r3.time2
JoinLink[inner, r4]
r2.value == r4.value2
JoinChain[r0]
JoinLink[asof, r1]
r0.time1 == r1.time2
JoinLink[inner, r1]
r0.value == r1.value2
values:
time1: r2.time1
value: r2.value
time2: r3.time2
value2: r3.value2
time2_right: r4.time2
value2_right: r4.value2
time1: r0.time1
value: r0.value
time2: r1.time2
value2: r1.value2
time2_right: r1.time2
value2_right: r1.value2
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,18 @@ r2 := UnboundTable: three
bar_id string
value2 float64

r3 := SelfReference[r1]

r4 := SelfReference[r2]

r5 := Filter[r0]
r3 := Filter[r0]
r0.f > 0

r6 := SelfReference[r5]

JoinChain[r6]
JoinLink[left, r3]
r6.foo_id == r3.foo_id
JoinLink[inner, r4]
r6.bar_id == r4.bar_id
JoinChain[r3]
JoinLink[left, r1]
r3.foo_id == r1.foo_id
JoinLink[inner, r2]
r3.bar_id == r2.bar_id
values:
c: r6.c
f: r6.f
foo_id: r6.foo_id
bar_id: r6.bar_id
value1: r3.value1
value2: r4.value2
c: r3.c
f: r3.f
foo_id: r3.foo_id
bar_id: r3.bar_id
value1: r1.value1
value2: r2.value2
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@ r2 := Filter[r1]
r3 := Filter[r1]
r1.kind == 'bar'

r4 := SelfReference[r2]

r5 := SelfReference[r3]

JoinChain[r4]
JoinLink[inner, r5]
r4.region == r5.region
JoinChain[r2]
JoinLink[inner, r3]
r2.region == r3.region
values:
region: r4.region
kind: r4.kind
total: r4.total
right_total: r5.total
region: r2.region
kind: r2.kind
total: r2.total
right_total: r3.total
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ r1 := UnboundTable: t2
a int64
b float64

r2 := SelfReference[r0]

r3 := SelfReference[r1]

r4 := JoinChain[r2]
JoinLink[inner, r3]
r2.a == r3.a
r2 := JoinChain[r0]
JoinLink[inner, r1]
r0.a == r1.a
values:
a: r2.a
b: r2.b
a_right: r3.a
b_right: r3.b
a: r0.a
b: r0.b
a_right: r1.a
b_right: r1.b

CountStar(): CountStar(r4)
CountStar(): CountStar(r2)
34 changes: 14 additions & 20 deletions ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@ r1 := UnboundTable: right
value2 float64
b string

r2 := SelfReference[r0]

r3 := SelfReference[r1]

r4 := SelfReference[r1]

JoinChain[r2]
JoinLink[inner, r3]
r2.a == r3.b
JoinLink[inner, r4]
r2.value == r4.value2
JoinChain[r0]
JoinLink[inner, r1]
r0.a == r1.b
JoinLink[inner, r1]
r0.value == r1.value2
values:
time1: r2.time1
value: r2.value
a: r2.a
time2: r3.time2
value2: r3.value2
b: r3.b
time2_right: r4.time2
value2_right: r4.value2
b_right: r4.b
time1: r0.time1
value: r0.value
a: r0.a
time2: r1.time2
value2: r1.value2
b: r1.b
time2_right: r1.time2
value2_right: r1.value2
b_right: r1.b
Loading

0 comments on commit 593ec6a

Please sign in to comment.