From 9e3d614ac2d379be97516d68445a3ce7644825b5 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 20 May 2024 18:01:26 +0200 Subject: [PATCH] perf: use is_sorted in ewm_mean_by, deprecate check_sorted (#16335) --- crates/polars-ops/src/series/ops/ewm_by.rs | 64 +++++++++++++------ .../src/dsl/function_expr/ewm_by.rs | 12 ++-- .../polars-plan/src/dsl/function_expr/mod.rs | 11 +--- crates/polars-plan/src/dsl/mod.rs | 7 +- py-polars/polars/expr/expr.py | 13 +++- py-polars/src/expr/general.rs | 4 +- py-polars/src/lazyframe/visitor/expr_nodes.rs | 7 +- py-polars/tests/unit/functions/test_ewm_by.py | 4 +- .../tests/unit/operations/test_ewm_by.py | 18 +++--- 9 files changed, 79 insertions(+), 61 deletions(-) diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs index 4b1d047269ef..1bc3630d6604 100644 --- a/crates/polars-ops/src/series/ops/ewm_by.rs +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -7,42 +7,68 @@ pub fn ewm_mean_by( s: &Series, times: &Series, half_life: i64, - assume_sorted: bool, + times_is_sorted: bool, ) -> PolarsResult { - match (s.dtype(), times.dtype()) { - (DataType::Float64, DataType::Int64) => Ok((if assume_sorted { - ewm_mean_by_impl_sorted(s.f64().unwrap(), times.i64().unwrap(), half_life) - } else { - ewm_mean_by_impl(s.f64().unwrap(), times.i64().unwrap(), half_life) - }) - .into_series()), - (DataType::Float32, DataType::Int64) => Ok((if assume_sorted { - ewm_mean_by_impl_sorted(s.f32().unwrap(), times.i64().unwrap(), half_life) + fn func( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, + times_is_sorted: bool, + ) -> PolarsResult + where + T: PolarsFloatType, + T::Native: Float + Zero + One, + ChunkedArray: IntoSeries, + { + if times_is_sorted { + Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series()) } else { - ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life) - }) - .into_series()), + Ok(ewm_mean_by_impl(values, times, half_life).into_series()) + } + } + + match (s.dtype(), times.dtype()) { + (DataType::Float64, DataType::Int64) => func( + s.f64().unwrap(), + times.i64().unwrap(), + half_life, + times_is_sorted, + ), + (DataType::Float32, DataType::Int64) => func( + s.f32().unwrap(), + times.i64().unwrap(), + half_life, + times_is_sorted, + ), #[cfg(feature = "dtype-datetime")] (_, DataType::Datetime(time_unit, _)) => { let half_life = adjust_half_life_to_time_unit(half_life, time_unit); - ewm_mean_by(s, ×.cast(&DataType::Int64)?, half_life, assume_sorted) + ewm_mean_by( + s, + ×.cast(&DataType::Int64)?, + half_life, + times_is_sorted, + ) }, #[cfg(feature = "dtype-date")] (_, DataType::Date) => ewm_mean_by( s, ×.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, half_life, - assume_sorted, + times_is_sorted, + ), + (_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by( + s, + ×.cast(&DataType::Int64)?, + half_life, + times_is_sorted, ), - (_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => { - ewm_mean_by(s, ×.cast(&DataType::Int64)?, half_life, assume_sorted) - }, (DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => { ewm_mean_by( &s.cast(&DataType::Float64)?, times, half_life, - assume_sorted, + times_is_sorted, ) }, _ => { diff --git a/crates/polars-plan/src/dsl/function_expr/ewm_by.rs b/crates/polars-plan/src/dsl/function_expr/ewm_by.rs index b47e4c25470e..c901dc22a25f 100644 --- a/crates/polars-plan/src/dsl/function_expr/ewm_by.rs +++ b/crates/polars-plan/src/dsl/function_expr/ewm_by.rs @@ -1,10 +1,8 @@ +use polars_ops::series::SeriesMethods; + use super::*; -pub(super) fn ewm_mean_by( - s: &[Series], - half_life: Duration, - check_sorted: bool, -) -> PolarsResult { +pub(super) fn ewm_mean_by(s: &[Series], half_life: Duration) -> PolarsResult { let time_zone = match s[1].dtype() { DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()), _ => None, @@ -15,6 +13,6 @@ pub(super) fn ewm_mean_by( let half_life = half_life.duration_ns(); let values = &s[0]; let times = &s[1]; - let assume_sorted = !check_sorted || times.is_sorted_flag() == IsSorted::Ascending; - polars_ops::prelude::ewm_mean_by(values, times, half_life, assume_sorted) + let times_is_sorted = times.is_sorted(Default::default())?; + polars_ops::prelude::ewm_mean_by(values, times, half_life, times_is_sorted) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 21a512e3db5a..e8746609c66a 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -328,7 +328,6 @@ pub enum FunctionExpr { #[cfg(feature = "ewma_by")] EwmMeanBy { half_life: Duration, - check_sorted: bool, }, #[cfg(feature = "ewma")] EwmStd { @@ -542,10 +541,7 @@ impl Hash for FunctionExpr { #[cfg(feature = "ewma")] EwmMean { options } => options.hash(state), #[cfg(feature = "ewma_by")] - EwmMeanBy { - half_life, - check_sorted, - } => (half_life, check_sorted).hash(state), + EwmMeanBy { half_life } => (half_life).hash(state), #[cfg(feature = "ewma")] EwmStd { options } => options.hash(state), #[cfg(feature = "ewma")] @@ -1118,10 +1114,7 @@ impl From for SpecialEq> { #[cfg(feature = "ewma")] EwmMean { options } => map!(ewm::ewm_mean, options), #[cfg(feature = "ewma_by")] - EwmMeanBy { - half_life, - check_sorted, - } => map_as_slice!(ewm_by::ewm_mean_by, half_life, check_sorted), + EwmMeanBy { half_life } => map_as_slice!(ewm_by::ewm_mean_by, half_life), #[cfg(feature = "ewma")] EwmStd { options } => map!(ewm::ewm_std, options), #[cfg(feature = "ewma")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 3357875e246a..633e14938dcb 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1647,12 +1647,9 @@ impl Expr { #[cfg(feature = "ewma_by")] /// Calculate the exponentially-weighted moving average by a time column. - pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self { + pub fn ewm_mean_by(self, times: Expr, half_life: Duration) -> Self { self.apply_many_private( - FunctionExpr::EwmMeanBy { - half_life, - check_sorted, - }, + FunctionExpr::EwmMeanBy { half_life }, &[times], false, false, diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 7e74ce31d0fb..082a6e9e7e24 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -10537,7 +10537,7 @@ def ewm_mean_by( by: str | IntoExpr, *, half_life: str | timedelta, - check_sorted: bool = True, + check_sorted: bool | None = None, ) -> Self: r""" Calculate time-based exponentially weighted moving average. @@ -10587,6 +10587,10 @@ def ewm_mean_by( Check whether `by` column is sorted. Incorrectly setting this to `False` will lead to incorrect output. + .. deprecated:: 0.20.27 + Sortedness is now verified in a quick manner, you can safely remove + this argument. + Returns ------- Expr @@ -10625,7 +10629,12 @@ def ewm_mean_by( """ by = parse_as_expression(by) half_life = parse_as_duration_string(half_life) - return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life, check_sorted)) + if check_sorted is not None: + issue_deprecation_warning( + "`check_sorted` is now deprecated in `ewm_mean_by`, you can safely remove this argument.", + version="0.20.27", + ) + return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life)) @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 21008b5b2128..6bd1a706c9cb 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -858,11 +858,11 @@ impl PyExpr { }; self.inner.clone().ewm_mean(options).into() } - fn ewm_mean_by(&self, times: PyExpr, half_life: &str, check_sorted: bool) -> Self { + fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> Self { let half_life = Duration::parse(half_life); self.inner .clone() - .ewm_mean_by(times.inner, half_life, check_sorted) + .ewm_mean_by(times.inner, half_life) .into() } diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 98fce2b512af..ed9b3ed72b41 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -1020,10 +1020,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::TopKBy { sort_options: _ } => { return Err(PyNotImplementedError::new_err("top_k_by")) }, - FunctionExpr::EwmMeanBy { - half_life: _, - check_sorted: _, - } => return Err(PyNotImplementedError::new_err("ewm_mean_by")), + FunctionExpr::EwmMeanBy { half_life: _ } => { + return Err(PyNotImplementedError::new_err("ewm_mean_by")) + }, }, options: py.None(), } diff --git a/py-polars/tests/unit/functions/test_ewm_by.py b/py-polars/tests/unit/functions/test_ewm_by.py index 6c7f1da4cf9f..fb56fa1f23cb 100644 --- a/py-polars/tests/unit/functions/test_ewm_by.py +++ b/py-polars/tests/unit/functions/test_ewm_by.py @@ -27,9 +27,7 @@ def test_ewm_by(data: st.DataObject, half_life: int) -> None: ) ) result = df.with_row_index().select( - pl.col("values").ewm_mean_by( - by="index", half_life=f"{half_life}i", check_sorted=False - ) + pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i") ) expected = df.select( pl.col("values").ewm_mean(half_life=half_life, ignore_nulls=False, adjust=False) diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py index 43884d7e0b82..ac0e6929df98 100644 --- a/py-polars/tests/unit/operations/test_ewm_by.py +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -173,22 +173,20 @@ def test_ewma_by_empty() -> None: assert_frame_equal(result, expected) -def test_ewma_by_warn_if_unsorted() -> None: +def test_ewma_by_if_unsorted() -> None: df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]}) - - # Check that with `check_sorted=False`, the user can get incorrect results - # if they really want to. - result = df.select( - pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False), - ) - expected = pl.DataFrame({"values": [3.0, 4.0]}) - assert_frame_equal(result, expected) - result = df.with_columns( pl.col("values").ewm_mean_by("by", half_life="2i"), ) expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]}) assert_frame_equal(result, expected) + + with pytest.deprecated_call(match="you can safely remove this argument"): + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False), + ) + assert_frame_equal(result, expected) + result = df.sort("by").with_columns( pl.col("values").ewm_mean_by("by", half_life="2i"), )