Skip to content

Commit

Permalink
feat(duckdb): support asof joins including tolerance parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Feb 2, 2024
1 parent 6c7c9ee commit ae04bf4
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,14 +868,14 @@ def visit_JoinLink(self, op, *, how, table, predicates):
"anti": "left",
"cross": None,
"outer": "full",
"asof": "left",
"asof": "asof",
"any_left": "left",
"any_inner": None,
}
kinds = {
"any_left": "any",
"any_inner": "any",
"asof": "asof",
"asof": "left",
"inner": "inner",
"left": "outer",
"right": "outer",
Expand Down
147 changes: 147 additions & 0 deletions ibis/backends/tests/test_asof_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import operator

import pandas as pd
import pandas.testing as tm
import pytest

import ibis


@pytest.fixture(scope="module")
def time_df1():
return pd.DataFrame(
{
"time": pd.to_datetime([1, 2, 3, 4], unit="s"),
"value": [1.1, 2.2, 3.3, 4.4],
"group": ["a", "a", "a", "a"],
}
)


@pytest.fixture(scope="module")
def time_df2():
return pd.DataFrame(
{
"time": pd.to_datetime([2, 4], unit="s"),
"other_value": [1.2, 2.0],
"group": ["a", "a"],
}
)


@pytest.fixture(scope="module")
def time_keyed_df1():
return pd.DataFrame(
{
"time": pd.Series(
pd.date_range(start="2017-01-02 01:02:03.234", periods=6)
),
"key": [1, 2, 3, 1, 2, 3],
"value": [1.2, 1.4, 2.0, 4.0, 8.0, 16.0],
}
)


@pytest.fixture(scope="module")
def time_keyed_df2():
return pd.DataFrame(
{
"time": pd.Series(
pd.date_range(start="2017-01-02 01:02:03.234", freq="3D", periods=3)
),
"key": [1, 2, 3],
"other_value": [1.1, 1.2, 2.2],
}
)


@pytest.fixture(scope="module")
def time_left(time_df1):
return ibis.memtable(time_df1)


@pytest.fixture(scope="module")
def time_right(time_df2):
return ibis.memtable(time_df2)


@pytest.fixture(scope="module")
def time_keyed_left(time_keyed_df1):
return ibis.memtable(time_keyed_df1)


@pytest.fixture(scope="module")
def time_keyed_right(time_keyed_df2):
return ibis.memtable(time_keyed_df2)


@pytest.mark.parametrize(
("direction", "op"),
[
("backward", operator.ge),
("forward", operator.le),
],
)
@pytest.mark.notimpl(["datafusion", "snowflake"])
def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op):
on = op(time_left["time"], time_right["time"])
expr = time_left.asof_join(time_right, on=on, predicates="group")

result = con.execute(expr)
expected = pd.merge_asof(
time_df1, time_df2, on="time", by="group", direction=direction
)

result = result.sort_values(["group", "time"]).reset_index(drop=True)
expected = expected.sort_values(["group", "time"]).reset_index(drop=True)

tm.assert_frame_equal(result[expected.columns], expected)
with pytest.raises(AssertionError):
tm.assert_series_equal(result["time"], result["time_right"])


@pytest.mark.parametrize(
("direction", "op"),
[
("backward", operator.ge),
("forward", operator.le),
],
)
@pytest.mark.broken(
["clickhouse"], raises=AssertionError, reason="`time` is truncated to seconds"
)
@pytest.mark.notimpl(["datafusion", "snowflake"])
def test_keyed_asof_join_with_tolerance(
con,
time_keyed_left,
time_keyed_right,
time_keyed_df1,
time_keyed_df2,
direction,
op,
):
on = op(time_keyed_left["time"], time_keyed_right["time"])
expr = time_keyed_left.asof_join(
time_keyed_right, on=on, by="key", tolerance=ibis.interval(days=2)
)

result = con.execute(expr)
expected = pd.merge_asof(
time_keyed_df1,
time_keyed_df2,
on="time",
by="key",
tolerance=pd.Timedelta("2D"),
direction=direction,
)

result = result.sort_values(["key", "time"]).reset_index(drop=True)
expected = expected.sort_values(["key", "time"]).reset_index(drop=True)

tm.assert_frame_equal(result[expected.columns], expected)
with pytest.raises(AssertionError):
tm.assert_series_equal(result["time"], result["time_right"])
with pytest.raises(AssertionError):
tm.assert_series_equal(result["key"], result["key_right"])
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players):
reason="polars doesn't support join predicates",
)
@pytest.mark.notimpl(
["dask", "pandas"],
["dask"],
raises=TypeError,
reason="dask and pandas don't support join predicates",
reason="dask doesn't support join predicates",
)
@pytest.mark.notimpl(
["exasol"],
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def dereference_expect(expected):


def test_dereference_project():
p = t.projection([t.int_col, t.double_col])
p = t.select([t.int_col, t.double_col])

mapping = dereference_mapping([p.op()])
expected = dereference_expect(
Expand Down
12 changes: 12 additions & 0 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ def asof_join( # noqa: D102
):
predicates = util.promote_list(predicates) + util.promote_list(by)
if tolerance is not None:
# `tolerance` parameter is mimicking the pandas API, but we express
# it at the expression level by a sequence of operations:
# 1. perform the `asof` join with the `on` an `predicates` parameters
# where the `on` parameter is an inequality predicate
# 2. filter the asof join result using the `tolerance` parameter and
# the `on` parameter
# 3. perform a left join between the original left table and the
# filtered asof join result using the `on` parameter but this
# time as an equality predicate
if isinstance(on, str):
# self is always a JoinChain so reference one of the join tables
left_on = self.op().values[on].to_expr()
Expand All @@ -260,6 +269,9 @@ def asof_join( # noqa: D102
)
right_on = right_on.op().replace({right.op(): filtered.op()}).to_expr()

# without joining twice the table would not contain the rows from
# the left table that do not match any row from the right table
# given the tolerance filter
result = self.left_join(
filtered, predicates=[left_on == right_on] + predicates
)
Expand Down

0 comments on commit ae04bf4

Please sign in to comment.