From d3779ae9e34712afaaa22de49126bc5efdac2e12 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 20 Jun 2023 21:32:30 +0100 Subject: [PATCH] fix(rust, python): groupby rolling with negative offset (#9428) --- polars/polars-time/src/windows/groupby.rs | 80 ++++++++--------- polars/polars-time/src/windows/test.rs | 7 +- .../polars/testing/parametric/primitives.py | 51 ++++++++++- .../polars/testing/parametric/strategies.py | 2 + .../tests/parametric/test_groupby_rolling.py | 76 ++++++++++++++++ .../tests/unit/operations/test_rolling.py | 90 +++++++++++++++++++ 6 files changed, 257 insertions(+), 49 deletions(-) create mode 100644 py-polars/tests/parametric/test_groupby_rolling.py diff --git a/polars/polars-time/src/windows/groupby.rs b/polars/polars-time/src/windows/groupby.rs index 44b7f5bbdc2d..3f393c54db9f 100644 --- a/polars/polars-time/src/windows/groupby.rs +++ b/polars/polars-time/src/windows/groupby.rs @@ -223,8 +223,8 @@ pub fn groupby_windows( (groups, lower_bound, upper_bound) } -// this assumes that the starting point is alwa -pub(crate) fn groupby_values_iter_full_lookbehind( +// this assumes that the given time point is the right endpoint of the window +pub(crate) fn groupby_values_iter_lookbehind( period: Duration, offset: Duration, time: &[i64], @@ -233,7 +233,7 @@ pub(crate) fn groupby_values_iter_full_lookbehind( tz: Option, start_offset: usize, ) -> impl Iterator> + TrustedLen + '_ { - debug_assert!(offset.duration_ns() >= period.duration_ns()); + debug_assert!(offset.duration_ns() == period.duration_ns()); debug_assert!(offset.negative); let add = match tu { TimeUnit::Nanoseconds => Duration::add_ns, @@ -465,8 +465,7 @@ pub(crate) fn groupby_values_iter<'a>( offset.negative = !period.negative; if offset.duration_ns() > 0 { // t is at the right endpoint of the window - let iter = - groupby_values_iter_full_lookbehind(period, offset, time, closed_window, tu, tz, 0); + let iter = groupby_values_iter_lookbehind(period, offset, time, closed_window, tu, tz, 0); Box::new(iter) } else if closed_window == ClosedWindow::Right || closed_window == ClosedWindow::None { // only lookahead @@ -514,49 +513,44 @@ pub fn groupby_values( // we have a (partial) lookbehind window if offset.negative { - if offset.duration_ns() >= period.duration_ns() { - // lookbehind - // window is within 2 periods length of t + // lookbehind + if offset.duration_ns() == period.duration_ns() { + // t is right at the end of the window // ------t--- // [------] - if offset.duration_ns() < period.duration_ns() * 2 { - POOL.install(|| { - let vals = thread_offsets - .par_iter() - .copied() - .map(|(base_offset, len)| { - let upper_bound = base_offset + len; - let iter = groupby_values_iter_full_lookbehind( - period, - offset, - &time[..upper_bound], - closed_window, - tu, - tz, - base_offset, - ); - iter.map(|result| result.map(|(offset, len)| [offset, len])) - .collect::>>() - }) - .collect::>>()?; - Ok(flatten_par(&vals)) - }) - } + POOL.install(|| { + let vals = thread_offsets + .par_iter() + .copied() + .map(|(base_offset, len)| { + let upper_bound = base_offset + len; + let iter = groupby_values_iter_lookbehind( + period, + offset, + &time[..upper_bound], + closed_window, + tu, + tz, + base_offset, + ); + iter.map(|result| result.map(|(offset, len)| [offset, len])) + .collect::>>() + }) + .collect::>>()?; + Ok(flatten_par(&vals)) + }) + } else if ((offset.duration_ns() >= period.duration_ns()) + && matches!(closed_window, ClosedWindow::Left | ClosedWindow::None)) + || ((offset.duration_ns() > period.duration_ns()) + && matches!(closed_window, ClosedWindow::Right | ClosedWindow::Both)) + { // window is completely behind t and t itself is not a member // ---------------t--- // [---] - else { - let iter = groupby_values_iter_window_behind_t( - period, - offset, - time, - closed_window, - tu, - tz, - ); - iter.map(|result| result.map(|(offset, len)| [offset, len])) - .collect::>() - } + let iter = + groupby_values_iter_window_behind_t(period, offset, time, closed_window, tu, tz); + iter.map(|result| result.map(|(offset, len)| [offset, len])) + .collect::>() } // partial lookbehind // this one is still single threaded diff --git a/polars/polars-time/src/windows/test.rs b/polars/polars-time/src/windows/test.rs index 28d9613b9e8e..adb820837230 100644 --- a/polars/polars-time/src/windows/test.rs +++ b/polars/polars-time/src/windows/test.rs @@ -709,10 +709,9 @@ fn test_rolling_lookback() { ClosedWindow::None, ] { let offset = Duration::parse("-2h"); - let g0 = - groupby_values_iter_full_lookbehind(period, offset, &dates, closed_window, tu, None, 0) - .collect::>>() - .unwrap(); + let g0 = groupby_values_iter_lookbehind(period, offset, &dates, closed_window, tu, None, 0) + .collect::>>() + .unwrap(); let g1 = groupby_values_iter_partial_lookbehind(period, offset, &dates, closed_window, tu, None) .collect::>>() diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index 05bad0ce9d3e..6df1fe38b866 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from math import isfinite from textwrap import dedent -from typing import TYPE_CHECKING, Any, Collection, Sequence +from typing import TYPE_CHECKING, Any, Collection, Sequence, overload from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning from hypothesis.strategies import ( @@ -41,11 +41,18 @@ ) if TYPE_CHECKING: + import sys + from hypothesis.strategies import DrawFn, SearchStrategy from polars import LazyFrame from polars.type_aliases import OneOrMoreDataTypes, PolarsDataType + if sys.version_info >= (3, 8): + from typing import Literal + else: + from typing_extensions import Literal + _time_units = list(DTYPE_TEMPORAL_UNITS) @@ -444,11 +451,51 @@ def draw_series(draw: DrawFn) -> Series: _failed_frame_init_msgs_: set[str] = set() +@overload +def dataframes( + cols: int | column | Sequence[column] | None = None, + *, + lazy: Literal[False] = ..., + min_cols: int | None = 0, + max_cols: int | None = MAX_COLS, + size: int | None = None, + min_size: int | None = 0, + max_size: int | None = MAX_DATA_SIZE, + chunked: bool | None = None, + include_cols: Sequence[column] | column | None = None, + null_probability: float | dict[str, float] = 0.0, + allow_infinities: bool = True, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, +) -> SearchStrategy[DataFrame]: + ... + + +@overload +def dataframes( + cols: int | column | Sequence[column] | None = None, + *, + lazy: Literal[True], + min_cols: int | None = 0, + max_cols: int | None = MAX_COLS, + size: int | None = None, + min_size: int | None = 0, + max_size: int | None = MAX_DATA_SIZE, + chunked: bool | None = None, + include_cols: Sequence[column] | column | None = None, + null_probability: float | dict[str, float] = 0.0, + allow_infinities: bool = True, + allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, + excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, +) -> SearchStrategy[LazyFrame]: + ... + + @defines_strategy() def dataframes( cols: int | column | Sequence[column] | None = None, - lazy: bool = False, *, + lazy: bool = False, min_cols: int | None = 0, max_cols: int | None = MAX_COLS, size: int | None = None, diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py index a966b4792c22..39d6cb762af0 100644 --- a/py-polars/polars/testing/parametric/strategies.py +++ b/py-polars/polars/testing/parametric/strategies.py @@ -109,6 +109,8 @@ def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any: min_value=timedelta(microseconds=-(2**46)), max_value=timedelta(microseconds=(2**46) - 1), ) +strategy_closed = sampled_from(["left", "right", "both", "none"]) +strategy_time_unit = sampled_from(["ns", "us", "ms"]) @composite diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py new file mode 100644 index 000000000000..9f98199af233 --- /dev/null +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +from hypothesis import given, reject + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric.primitives import column, dataframes +from polars.testing.parametric.strategies import strategy_closed, strategy_time_unit +from polars.utils.convert import _timedelta_to_pl_duration + +if TYPE_CHECKING: + from polars.type_aliases import ClosedInterval, TimeUnit + + +@given( + period=st.timedeltas(min_value=timedelta(microseconds=0)).map( + _timedelta_to_pl_duration + ), + offset=st.timedeltas().map(_timedelta_to_pl_duration), + closed=strategy_closed, + data=st.data(), + time_unit=strategy_time_unit, +) +def test_groupby_rolling( + period: str, + offset: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, +) -> None: + dataframe = data.draw( + dataframes( + [ + column("ts", dtype=pl.Datetime(time_unit)), + column("value", dtype=pl.Int64), + ], + ) + ) + df = dataframe.sort("ts").unique("ts") + try: + result = df.groupby_rolling( + "ts", period=period, offset=offset, closed=closed + ).agg(pl.col("value")) + except pl.exceptions.PolarsPanicError as exc: + assert any( # noqa: PT017 + msg in str(exc) + for msg in ( + "attempt to multiply with overflow", + "attempt to add with overflow", + ) + ) + reject() + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by(offset), + pl.lit(ts, dtype=pl.Datetime(time_unit)) + .dt.offset_by(offset) + .dt.offset_by(period), + closed=closed, + ) + ) + value = window["value"].to_list() + expected_dict["ts"].append(ts) + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(pl.List(pl.Int64)), + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py index 97f7cd18cac4..7c1f4232e4be 100644 --- a/py-polars/tests/unit/operations/test_rolling.py +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -72,6 +72,96 @@ def test_rolling_kernels_and_groupby_rolling( assert_frame_equal(out1, out2) +@pytest.mark.parametrize( + ("offset", "closed", "expected_values"), + [ + pytest.param( + "-1d", + "left", + [[1], [1, 2], [2, 3], [3, 4]], + id="partial lookbehind, left", + ), + pytest.param( + "-1d", + "right", + [[1, 2], [2, 3], [3, 4], [4]], + id="partial lookbehind, right", + ), + pytest.param( + "-1d", + "both", + [[1, 2], [1, 2, 3], [2, 3, 4], [3, 4]], + id="partial lookbehind, both", + ), + pytest.param( + "-1d", + "none", + [[1], [2], [3], [4]], + id="partial lookbehind, none", + ), + pytest.param( + "-2d", + "left", + [[], [1], [1, 2], [2, 3]], + id="full lookbehind, left", + ), + pytest.param( + "-3d", + "left", + [[], [], [1], [1, 2]], + id="full lookbehind, offset > period, left", + ), + pytest.param( + "-3d", + "right", + [[], [1], [1, 2], [2, 3]], + id="full lookbehind, right", + ), + pytest.param( + "-3d", + "both", + [[], [1], [1, 2], [1, 2, 3]], + id="full lookbehind, both", + ), + pytest.param( + "-2d", + "none", + [[], [1], [2], [3]], + id="full lookbehind, none", + ), + pytest.param( + "-3d", + "none", + [[], [], [1], [2]], + id="full lookbehind, offset > period, none", + ), + ], +) +def test_rolling_negative_offset( + offset: str, closed: ClosedInterval, expected_values: list[list[int]] +) -> None: + df = pl.DataFrame( + { + "ts": pl.date_range( + datetime(2021, 1, 1), datetime(2021, 1, 4), "1d", eager=True + ), + "value": [1, 2, 3, 4], + } + ) + result = df.groupby_rolling("ts", period="2d", offset=offset, closed=closed).agg( + pl.col("value") + ) + expected = pl.DataFrame( + { + "ts": pl.date_range( + datetime(2021, 1, 1), datetime(2021, 1, 4), "1d", eager=True + ), + "value": expected_values, + } + ) + assert_frame_equal(result, expected) + + def test_rolling_skew() -> None: s = pl.Series([1, 2, 3, 3, 2, 10, 8]) assert s.rolling_skew(window_size=4, bias=True).to_list() == pytest.approx(