From 063e3aedd52127b4e58f228f33c49355a66133ee Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 7 Jun 2023 08:26:33 +0200 Subject: [PATCH] feat(rust, python): keep sorted flag after Expr::truncate (#9275) --- .../polars-plan/src/dsl/function_expr/datetime.rs | 6 ++++-- py-polars/tests/unit/namespaces/test_datetime.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/datetime.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/datetime.rs index bf50c52a77fb..3fc096c7e4ec 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/datetime.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/datetime.rs @@ -181,7 +181,7 @@ pub(super) fn timestamp(s: &Series, tu: TimeUnit) -> PolarsResult { pub(super) fn truncate(s: &Series, every: &str, offset: &str) -> PolarsResult { let every = Duration::parse(every); let offset = Duration::parse(offset); - Ok(match s.dtype() { + let mut out = match s.dtype() { DataType::Datetime(_, tz) => match tz { #[cfg(feature = "timezones")] Some(tz) => s @@ -201,7 +201,9 @@ pub(super) fn truncate(s: &Series, every: &str, offset: &str) -> PolarsResult polars_bail!(opq = round, got = dt, expected = "date/datetime"), - }) + }; + out.set_sorted_flag(s.is_sorted_flag()); + Ok(out) } #[cfg(feature = "date_offset")] diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index eaeddfa4d99b..28aaeb8257c4 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -540,7 +540,7 @@ def test_negative_offset_by_err_msg_8464() -> None: pl.Series([datetime(2022, 3, 30)]).dt.offset_by("-1mo") -def test_offset_by_sorted_flag() -> None: +def test_offset_by_truncate_sorted_flag() -> None: s = pl.Series([datetime(2001, 1, 1), datetime(2001, 1, 2)]) s = s.set_sorted() @@ -548,6 +548,8 @@ def test_offset_by_sorted_flag() -> None: s1 = s.dt.offset_by("1d") assert s1.to_list() == [datetime(2001, 1, 2), datetime(2001, 1, 3)] assert s1.flags["SORTED_ASC"] + s2 = s1.dt.truncate("1mo") + assert s2.flags["SORTED_ASC"] @pytest.mark.parametrize(