diff --git a/polars/polars-time/src/windows/groupby.rs b/polars/polars-time/src/windows/groupby.rs index 5c51745d5302..518989b52ff2 100644 --- a/polars/polars-time/src/windows/groupby.rs +++ b/polars/polars-time/src/windows/groupby.rs @@ -369,7 +369,7 @@ pub(crate) fn groupby_values_iter_partial_lookbehind( } #[allow(clippy::too_many_arguments)] -pub(crate) fn groupby_values_iter_full_lookahead( +pub(crate) fn groupby_values_iter_partial_lookahead( period: Duration, offset: Duration, time: &[i64], @@ -405,6 +405,53 @@ pub(crate) fn groupby_values_iter_full_lookahead( Ok((i as IdxSize, len as IdxSize)) }) } +#[allow(clippy::too_many_arguments)] +pub(crate) fn groupby_values_iter_full_lookahead( + period: Duration, + offset: Duration, + time: &[i64], + closed_window: ClosedWindow, + tu: TimeUnit, + tz: Option, + start_offset: usize, + upper_bound: Option, +) -> impl Iterator> + TrustedLen + '_ { + let upper_bound = upper_bound.unwrap_or(time.len()); + debug_assert!(!offset.negative); + + let add = match tu { + TimeUnit::Nanoseconds => Duration::add_ns, + TimeUnit::Microseconds => Duration::add_us, + TimeUnit::Milliseconds => Duration::add_ms, + }; + + time[start_offset..upper_bound] + .iter() + .enumerate() + .map(move |(mut i, lower)| { + i += start_offset; + let lower = add(&offset, *lower, tz.as_ref())?; + let upper = add(&period, lower, tz.as_ref())?; + + let b = Bounds::new(lower, upper); + + // find starting point of window + for &t in &time[i..] { + if b.is_member(t, closed_window) { + break; + } + i += 1; + } + if i >= time.len() { + return Ok((i as IdxSize, 0)); + } + + let slice = unsafe { time.get_unchecked(i..) }; + let len = slice.partition_point(|v| b.is_member(*v, closed_window)); + + Ok((i as IdxSize, len as IdxSize)) + }) +} #[cfg(feature = "rolling_window")] pub(crate) fn groupby_values_iter<'a>( @@ -429,7 +476,11 @@ pub(crate) fn groupby_values_iter<'a>( groupby_values_iter_partial_lookbehind(period, offset, time, closed_window, tu, tz); Box::new(iter) } - } else { + } else if offset != Duration::parse("0ns") + || closed_window == ClosedWindow::Right + || closed_window == ClosedWindow::None + { + // only lookahead let iter = groupby_values_iter_full_lookahead( period, offset, @@ -441,6 +492,19 @@ pub(crate) fn groupby_values_iter<'a>( None, ); Box::new(iter) + } else { + // partial lookahead + let iter = groupby_values_iter_partial_lookahead( + period, + offset, + time, + closed_window, + tu, + tz, + 0, + None, + ); + Box::new(iter) } } @@ -517,7 +581,13 @@ pub fn groupby_values( iter.map(|result| result.map(|(offset, len)| [offset, len])) .collect::>() } - } else { + } else if offset != Duration::parse("0ns") + || closed_window == ClosedWindow::Right + || closed_window == ClosedWindow::None + { + // window is completely ahead of t and t itself is not a member + // --t----------- + // [---] let vals = POOL.install(|| { thread_offsets .par_iter() @@ -541,5 +611,33 @@ pub fn groupby_values( .collect::>>() })?; Ok(flatten(&vals, Some(time.len()))) + } else { + // Duration is 0 and window is closed on the left: + // it must be that the window starts at t and t is a member + // --t----------- + // [---] + let vals = POOL.install(|| { + thread_offsets + .par_iter() + .copied() + .map(|(base_offset, len)| { + let lower_bound = base_offset; + let upper_bound = base_offset + len; + let iter = groupby_values_iter_partial_lookahead( + period, + offset, + time, + closed_window, + tu, + tz, + lower_bound, + Some(upper_bound), + ); + iter.map(|result| result.map(|(offset, len)| [offset as IdxSize, len])) + .collect::>>() + }) + .collect::>>() + })?; + Ok(flatten(&vals, Some(time.len()))) } } diff --git a/polars/polars-time/src/windows/test.rs b/polars/polars-time/src/windows/test.rs index 9b4e184fea34..28d9613b9e8e 100644 --- a/polars/polars-time/src/windows/test.rs +++ b/polars/polars-time/src/windows/test.rs @@ -690,15 +690,15 @@ fn test_rolling_lookback() { ) .unwrap(); assert_eq!(dates.len(), groups.len()); - assert_eq!(groups[0], [0, 5]); - assert_eq!(groups[1], [1, 5]); - assert_eq!(groups[2], [2, 5]); - assert_eq!(groups[3], [3, 5]); - assert_eq!(groups[4], [4, 5]); - assert_eq!(groups[5], [5, 4]); - assert_eq!(groups[6], [6, 3]); - assert_eq!(groups[7], [7, 2]); - assert_eq!(groups[8], [8, 0]); + assert_eq!(groups[0], [1, 4]); // (00:00, 02:00] + assert_eq!(groups[1], [2, 4]); // (00:30, 02:30] + assert_eq!(groups[2], [3, 4]); // (01:00, 03:00] + assert_eq!(groups[3], [4, 4]); // (01:30, 03:30] + assert_eq!(groups[4], [5, 4]); // (02:00, 04:00] + assert_eq!(groups[5], [6, 3]); // (02:30, 04:30] + assert_eq!(groups[6], [7, 2]); // (03:00, 05:00] + assert_eq!(groups[7], [8, 1]); // (03:30, 05:30] + assert_eq!(groups[8], [9, 0]); // (04:00, 06:00] let period = Duration::parse("2h"); let tu = TimeUnit::Milliseconds; @@ -708,25 +708,6 @@ fn test_rolling_lookback() { ClosedWindow::Both, ClosedWindow::None, ] { - let offset = Duration::parse("0h"); - let g0 = groupby_values_iter_full_lookahead( - period, - offset, - &dates, - closed_window, - tu, - None, - 0, - None, - ) - .collect::>>() - .unwrap(); - let g1 = - groupby_values_iter_partial_lookbehind(period, offset, &dates, closed_window, tu, None) - .collect::>>() - .unwrap(); - assert_eq!(g0, g1); - let offset = Duration::parse("-2h"); let g0 = groupby_values_iter_full_lookbehind(period, offset, &dates, closed_window, tu, None, 0) diff --git a/py-polars/tests/unit/operations/test_groupby_rolling.py b/py-polars/tests/unit/operations/test_groupby_rolling.py index 57a4893e40da..13de2008754a 100644 --- a/py-polars/tests/unit/operations/test_groupby_rolling.py +++ b/py-polars/tests/unit/operations/test_groupby_rolling.py @@ -1,13 +1,16 @@ from __future__ import annotations from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any import pytest import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from polars.type_aliases import ClosedInterval + def bad_agg_parameters() -> list[Any]: """Currently, IntoExpr and Iterable[IntoExpr] are supported.""" @@ -176,6 +179,56 @@ def test_groupby_rolling_negative_offset_crossing_dst(time_zone: str | None) -> assert_frame_equal(result, expected) +@pytest.mark.parametrize("time_zone", [None, "US/Central"]) +@pytest.mark.parametrize( + ("offset", "closed", "expected_values"), + [ + ("0d", "left", [[1, 4], [4, 9], [9, 155], [155]]), + ("0d", "right", [[4, 9], [9, 155], [155], []]), + ("0d", "both", [[1, 4, 9], [4, 9, 155], [9, 155], [155]]), + ("0d", "none", [[4], [9], [155], []]), + ("1d", "left", [[4, 9], [9, 155], [155], []]), + ("1d", "right", [[9, 155], [155], [], []]), + ("1d", "both", [[4, 9, 155], [9, 155], [155], []]), + ("1d", "none", [[9], [155], [], []]), + ], +) +def test_groupby_rolling_non_negative_offset_9077( + time_zone: str | None, + offset: str, + closed: ClosedInterval, + expected_values: list[list[int]], +) -> None: + df = pl.DataFrame( + { + "datetime": pl.date_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": [1, 4, 9, 155], + } + ) + result = df.groupby_rolling( + index_column="datetime", period="2d", offset=offset, closed=closed + ).agg(pl.col("value")) + expected = pl.DataFrame( + { + "datetime": pl.date_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": expected_values, + } + ) + assert_frame_equal(result, expected) + + def test_groupby_rolling_dynamic_sortedness_check() -> None: # when the by argument is passed, the sortedness flag # will be unset as the take shuffles data, so we must explicitly