Skip to content

Commit

Permalink
fir(ir): asof join tolerance parameter should post-filter and pos…
Browse files Browse the repository at this point in the history
…t-join instead of adding a predicate
  • Loading branch information
kszucs authored and cpcloud committed Feb 4, 2024
1 parent 06659b3 commit a064cfb
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ r1 := UnboundTable: right

JoinChain[r0]
JoinLink[asof, r1]
r0.time1 <= r1.time2
r0.time1 >= r1.time2
JoinLink[inner, r1]
r0.value == r1.value2
values:
Expand Down
23 changes: 22 additions & 1 deletion ibis/expr/tests/test_dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,30 @@ def dereference_expect(expected):
return {k.op(): v.op() for k, v in expected.items()}


def test_dereference_project():
p = t.projection([t.int_col, t.double_col])

mapping = dereference_mapping([p.op()])
expected = dereference_expect(
{
p.int_col: p.int_col,
p.double_col: p.double_col,
t.int_col: p.int_col,
t.double_col: p.double_col,
}
)
assert mapping == expected


def test_dereference_mapping_self_reference():
v = t.view()

mapping = dereference_mapping([v.op()])
expected = dereference_expect({})
expected = dereference_expect(
{
v.int_col: v.int_col,
v.double_col: v.double_col,
v.string_col: v.string_col,
}
)
assert mapping == expected
21 changes: 21 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,27 @@ def test_join_between_joins():
assert expr.op() == expected


def test_join_with_filtered_join_of_left():
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
t2 = ibis.table(name="t2", schema={"a": "int64", "b": "string"})

joined = t1.left_join(t2, [t1.a == t2.a]).filter(t1.a < 5)
expr = t1.left_join(joined, [t1.a == joined.a]).select(t1)

with join_tables(t1, joined) as (r1, r2):
expected = ops.JoinChain(
first=r1,
rest=[
ops.JoinLink("left", r2, [r1.a == r2.a]),
],
values={
"a": r1.a,
"b": r1.b,
},
)
assert expr.op() == expected


def test_join_method_docstrings():
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
Expand Down
99 changes: 59 additions & 40 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ibis.expr.operations as ops

from ibis import util
from ibis.expr.types import Table, ValueExpr
from ibis.expr.types import Table, Value
from ibis.common.deferred import Deferred
from ibis.expr.analysis import flatten_predicates
from ibis.common.exceptions import ExpressionError, IntegrityError
Expand Down Expand Up @@ -88,7 +88,7 @@ def dereference_binop(pred, deref_left, deref_right):

def dereference_value(pred, deref_left, deref_right):
deref_both = {**deref_left, **deref_right}
if isinstance(pred, ops.Binary) and pred.left == pred.right:
if isinstance(pred, ops.Binary) and pred.left.relations == pred.right.relations:
return dereference_binop(pred, deref_left, deref_right)
else:
return pred.replace(deref_both, filter=ops.Value)
Expand All @@ -103,13 +103,14 @@ def prepare_predicates(
for pred in util.promote_list(predicates):
if pred is True or pred is False:
yield ops.Literal(pred, dtype="bool")
elif isinstance(pred, ValueExpr):
node = pred.op()
yield dereference_value(node, deref_left, deref_right)
elif isinstance(pred, Value):
for node in flatten_predicates(pred.op()):
yield dereference_value(node, deref_left, deref_right)
elif isinstance(pred, Deferred):
# resolve deferred expressions on the left table
node = pred.resolve(left).op()
yield dereference_value(node, deref_left, deref_right)
pred = pred.resolve(left).op()
for node in flatten_predicates(pred):
yield dereference_value(node, deref_left, deref_right)
else:
if isinstance(pred, tuple):
if len(pred) != 2:
Expand Down Expand Up @@ -193,14 +194,15 @@ def join( # noqa: D102
subs_right = dereference_mapping_right(right)

# bind and dereference the predicates
preds = prepare_predicates(
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
preds = list(
prepare_predicates(
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
)
)
preds = flatten_predicates(list(preds))
if not preds and how != "cross":
# if there are no predicates, default to every row matching unless
# the join is a cross join, because a cross join already has this
Expand Down Expand Up @@ -236,41 +238,58 @@ def asof_join( # noqa: D102
):
predicates = util.promote_list(predicates) + util.promote_list(by)
if tolerance is not None:
if not isinstance(on, str):
raise TypeError(
"tolerance can only be specified when predicates is a string"
)
# construct a predicate with two sides from the two tables
predicates.append(self[on] <= right[on] + tolerance)
if isinstance(on, str):
# self is always a JoinChain so reference one of the join tables
left_on = self.op().values[on].to_expr()
right_on = right[on]
on = left_on >= right_on
elif isinstance(on, Value):
node = on.op()
if not isinstance(node, ops.Binary):
raise InputTypeError("`on` must be a comparison expression")
left_on = node.left.to_expr()
right_on = node.right.to_expr()
else:
raise TypeError("`on` must be a string or a ValueExpr")

joined = self.asof_join(
right, on=on, predicates=predicates, lname=lname, rname=rname
)
filtered = joined.filter(
left_on <= right_on + tolerance, left_on >= right_on - tolerance
)
right_on = right_on.op().replace({right.op(): filtered.op()}).to_expr()

result = self.left_join(
filtered, predicates=[left_on == right_on] + predicates
)
values = {**self.op().values, **filtered.op().values}
return result.select(values)

left = self.op()
right = ops.JoinTable(right, index=left.length)
subs_left = dereference_mapping_left(left)
subs_right = dereference_mapping_right(right)

# TODO(kszucs): add extra validation for `on` with clear error messages
preds = list(
prepare_predicates(
left,
right,
[on],
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.LessEqual,
)
(on,) = prepare_predicates(
left,
right,
[on],
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.GreaterEqual,
)
preds += flatten_predicates(
list(
prepare_predicates(
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.Equals,
)
)
predicates = prepare_predicates(
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.Equals,
)
preds = [on, *predicates]

values, collisions = disambiguate_fields(
"asof", left.values, right.fields, lname, rname
)
Expand Down
9 changes: 8 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,14 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]:


def dereference_mapping(parents):
mapping = {}
parents = util.promote_list(parents)
mapping = {}

for parent in parents:
# do not defereference fields referencing the requested parents
for k, v in parent.fields.items():
mapping[v] = v

for parent in parents:
for k, v in parent.values.items():
if isinstance(v, ops.Field):
Expand All @@ -171,6 +177,7 @@ def dereference_mapping(parents):
elif v.relations and v not in mapping:
# do not dereference literal expressions
mapping[v] = ops.Field(parent, k)

return mapping


Expand Down
41 changes: 19 additions & 22 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def test_asof_join_with_by():
r2 = join_without_by.op().rest[0].table.to_expr()
expected = ops.JoinChain(
first=r1,
rest=[ops.JoinLink("asof", r2, [r1.time <= r2.time])],
rest=[ops.JoinLink("asof", r2, [r1.time >= r2.time])],
values={
"time": r1.time,
"key": r1.key,
Expand All @@ -940,7 +940,7 @@ def test_asof_join_with_by():
expected = ops.JoinChain(
first=r1,
rest=[
ops.JoinLink("asof", r2, [r1.time <= r2.time, r1.key == r2.key]),
ops.JoinLink("asof", r2, [r1.time >= r2.time, r1.key == r2.key]),
],
values={
"time": r1.time,
Expand Down Expand Up @@ -978,26 +978,23 @@ def test_asof_join_with_tolerance(ibis_interval, timedelta_interval):

for interval in [ibis_interval, timedelta_interval]:
joined = api.asof_join(left, right, "time", tolerance=interval)
with join_tables(left, right) as (r1, r2):
expected = ops.JoinChain(
first=r1,
rest=[
ops.JoinLink(
"asof",
r2,
[r1.time <= r2.time, r1.time <= (r2.time + interval)],
)
],
values={
"time": r1.time,
"key": r1.key,
"value": r1.value,
"time_right": r2.time,
"key_right": r2.key,
"value2": r2.value2,
},
)
assert joined.op() == expected

asof = left.asof_join(right, "time")
filt = asof.filter(
[
asof.time <= asof.time_right + interval,
asof.time >= asof.time_right - interval,
]
)
join = left.left_join(filt, [left.time == filt.time])
expected = join.select(
left,
time_right=filt.time_right,
key_right=filt.key_right,
value2=filt.value2,
)

assert joined.equals(expected)


def test_equijoin_schema_merge():
Expand Down

0 comments on commit a064cfb

Please sign in to comment.