From 34618e453c8a8e50ef95aa6bece61b7d2209f9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 3 Jan 2024 15:50:18 +0100 Subject: [PATCH] feat(duckdb): support `asof` joins including `tolerance` parameter --- ibis/backends/base/sqlglot/compiler.py | 4 +- ibis/backends/tests/test_asof_join.py | 147 +++++++++++++++++++++++++ ibis/backends/tests/test_join.py | 4 +- ibis/expr/tests/test_dereference.py | 2 +- ibis/expr/types/joins.py | 12 ++ 5 files changed, 164 insertions(+), 5 deletions(-) create mode 100644 ibis/backends/tests/test_asof_join.py diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index b8a9a704cf68..16d94775fe3b 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -868,14 +868,14 @@ def visit_JoinLink(self, op, *, how, table, predicates): "anti": "left", "cross": None, "outer": "full", - "asof": "left", + "asof": "asof", "any_left": "left", "any_inner": None, } kinds = { "any_left": "any", "any_inner": "any", - "asof": "asof", + "asof": "left", "inner": "inner", "left": "outer", "right": "outer", diff --git a/ibis/backends/tests/test_asof_join.py b/ibis/backends/tests/test_asof_join.py new file mode 100644 index 000000000000..ad7678665678 --- /dev/null +++ b/ibis/backends/tests/test_asof_join.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import operator + +import pandas as pd +import pandas.testing as tm +import pytest + +import ibis + + +@pytest.fixture(scope="module") +def time_df1(): + return pd.DataFrame( + { + "time": pd.to_datetime([1, 2, 3, 4], unit="s"), + "value": [1.1, 2.2, 3.3, 4.4], + "group": ["a", "a", "a", "a"], + } + ) + + +@pytest.fixture(scope="module") +def time_df2(): + return pd.DataFrame( + { + "time": pd.to_datetime([2, 4], unit="s"), + "other_value": [1.2, 2.0], + "group": ["a", "a"], + } + ) + + +@pytest.fixture(scope="module") +def time_keyed_df1(): + return pd.DataFrame( + { + "time": 
pd.Series( + pd.date_range(start="2017-01-02 01:02:03.234", periods=6) + ), + "key": [1, 2, 3, 1, 2, 3], + "value": [1.2, 1.4, 2.0, 4.0, 8.0, 16.0], + } + ) + + +@pytest.fixture(scope="module") +def time_keyed_df2(): + return pd.DataFrame( + { + "time": pd.Series( + pd.date_range(start="2017-01-02 01:02:03.234", freq="3D", periods=3) + ), + "key": [1, 2, 3], + "other_value": [1.1, 1.2, 2.2], + } + ) + + +@pytest.fixture(scope="module") +def time_left(time_df1): + return ibis.memtable(time_df1) + + +@pytest.fixture(scope="module") +def time_right(time_df2): + return ibis.memtable(time_df2) + + +@pytest.fixture(scope="module") +def time_keyed_left(time_keyed_df1): + return ibis.memtable(time_keyed_df1) + + +@pytest.fixture(scope="module") +def time_keyed_right(time_keyed_df2): + return ibis.memtable(time_keyed_df2) + + +@pytest.mark.parametrize( + ("direction", "op"), + [ + ("backward", operator.ge), + ("forward", operator.le), + ], +) +@pytest.mark.notimpl(["datafusion", "snowflake"]) +def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op): + on = op(time_left["time"], time_right["time"]) + expr = time_left.asof_join(time_right, on=on, predicates="group") + + result = con.execute(expr) + expected = pd.merge_asof( + time_df1, time_df2, on="time", by="group", direction=direction + ) + + result = result.sort_values(["group", "time"]).reset_index(drop=True) + expected = expected.sort_values(["group", "time"]).reset_index(drop=True) + + tm.assert_frame_equal(result[expected.columns], expected) + with pytest.raises(AssertionError): + tm.assert_series_equal(result["time"], result["time_right"]) + + +@pytest.mark.parametrize( + ("direction", "op"), + [ + ("backward", operator.ge), + ("forward", operator.le), + ], +) +@pytest.mark.broken( + ["clickhouse"], raises=AssertionError, reason="`time` is truncated to seconds" +) +@pytest.mark.notimpl(["datafusion", "snowflake"]) +def test_keyed_asof_join_with_tolerance( + con, + time_keyed_left, + 
time_keyed_right, + time_keyed_df1, + time_keyed_df2, + direction, + op, +): + on = op(time_keyed_left["time"], time_keyed_right["time"]) + expr = time_keyed_left.asof_join( + time_keyed_right, on=on, by="key", tolerance=ibis.interval(days=2) + ) + + result = con.execute(expr) + expected = pd.merge_asof( + time_keyed_df1, + time_keyed_df2, + on="time", + by="key", + tolerance=pd.Timedelta("2D"), + direction=direction, + ) + + result = result.sort_values(["key", "time"]).reset_index(drop=True) + expected = expected.sort_values(["key", "time"]).reset_index(drop=True) + + tm.assert_frame_equal(result[expected.columns], expected) + with pytest.raises(AssertionError): + tm.assert_series_equal(result["time"], result["time_right"]) + with pytest.raises(AssertionError): + tm.assert_series_equal(result["key"], result["key_right"]) diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index 09d4dade0991..3343ff2e39bc 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -299,9 +299,9 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players): reason="polars doesn't support join predicates", ) @pytest.mark.notimpl( - ["dask", "pandas"], + ["dask"], raises=TypeError, - reason="dask and pandas don't support join predicates", + reason="dask doesn't support join predicates", ) @pytest.mark.notimpl( ["exasol"], diff --git a/ibis/expr/tests/test_dereference.py b/ibis/expr/tests/test_dereference.py index af569fab4a04..e19234f92084 100644 --- a/ibis/expr/tests/test_dereference.py +++ b/ibis/expr/tests/test_dereference.py @@ -18,7 +18,7 @@ def dereference_expect(expected): def test_dereference_project(): - p = t.projection([t.int_col, t.double_col]) + p = t.select([t.int_col, t.double_col]) mapping = dereference_mapping([p.op()]) expected = dereference_expect( diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 861c3e8095c2..957d16f4253f 100644 --- a/ibis/expr/types/joins.py +++ 
b/ibis/expr/types/joins.py @@ -238,6 +238,15 @@ def asof_join( # noqa: D102 ): predicates = util.promote_list(predicates) + util.promote_list(by) if tolerance is not None: + # `tolerance` parameter is mimicking the pandas API, but we express + # it at the expression level by a sequence of operations: + # 1. perform the `asof` join with the `on` and `predicates` parameters + # where the `on` parameter is an inequality predicate + # 2. filter the asof join result using the `tolerance` parameter and + # the `on` parameter + # 3. perform a left join between the original left table and the + # filtered asof join result using the `on` parameter but this + # time as an equality predicate if isinstance(on, str): # self is always a JoinChain so reference one of the join tables left_on = self.op().values[on].to_expr() @@ -260,6 +269,9 @@ def asof_join( # noqa: D102 ) right_on = right_on.op().replace({right.op(): filtered.op()}).to_expr() + # without joining twice, the table would not contain the rows from + # the left table that do not match any row from the right table + # given the tolerance filter result = self.left_join( filtered, predicates=[left_on == right_on] + predicates )