Skip to content

Commit

Permalink
refactor(python): Warn for future change of closed default value in r…
Browse files Browse the repository at this point in the history
…olling functions (#9470)
  • Loading branch information
zundertj authored Jun 21, 2023
1 parent 1c7744d commit ca08fdf
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 2 deletions.
10 changes: 9 additions & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
parse_as_list_of_expressions,
)
from polars.utils.convert import _timedelta_to_pl_duration
from polars.utils.decorators import deprecated_alias
from polars.utils.decorators import deprecated_alias, warn_closed_future_change
from polars.utils.meta import threadpool_size
from polars.utils.various import sphinx_accessor

Expand Down Expand Up @@ -4776,6 +4776,7 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self:
"""
return self._from_pyexpr(self._pyexpr.interpolate(method))

@warn_closed_future_change()
def rolling_min(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -4973,6 +4974,7 @@ def rolling_min(
)
)

@warn_closed_future_change()
def rolling_max(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -5195,6 +5197,7 @@ def rolling_max(
)
)

@warn_closed_future_change()
def rolling_mean(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -5417,6 +5420,7 @@ def rolling_mean(
)
)

@warn_closed_future_change()
def rolling_sum(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -5639,6 +5643,7 @@ def rolling_sum(
)
)

@warn_closed_future_change()
def rolling_std(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -5864,6 +5869,7 @@ def rolling_std(
)
)

@warn_closed_future_change()
def rolling_var(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -6095,6 +6101,7 @@ def rolling_var(
)
)

@warn_closed_future_change()
def rolling_median(
self,
window_size: int | timedelta | str,
Expand Down Expand Up @@ -6242,6 +6249,7 @@ def rolling_median(
)
)

@warn_closed_future_change()
def rolling_quantile(
self,
quantile: float,
Expand Down
32 changes: 32 additions & 0 deletions py-polars/polars/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,38 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return deco


def warn_closed_future_change() -> Callable[[Callable[P, T]], Callable[P, T]]:
"""
Warn that user should pass in 'closed' as default value will change.
Decorator for rolling function. Use as follows:
@warn_closed_future_change()
def myfunc():
...
"""

def deco(function: Callable[P, T]) -> Callable[P, T]:
@wraps(function)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# we only warn if 'by' is passed in, otherwise 'closed' is not used
if (kwargs.get("by") is not None) and ("closed" not in kwargs):
warnings.warn(
message=(
"The default argument for closed, 'left', will be changed to 'right' in the future."
"Fix this warning by explicitly passing in a value for closed"
),
category=FutureWarning,
stacklevel=find_stacklevel(),
)

return function(*args, **kwargs)

return wrapper

return deco


def _rename_kwargs(
func_name: str,
kwargs: dict[str, object],
Expand Down
4 changes: 3 additions & 1 deletion py-polars/tests/unit/operations/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def test_rolling_crossing_dst(
datetime(2021, 11, 5), datetime(2021, 11, 10), "1d", time_zone="UTC", eager=True
).dt.replace_time_zone(time_zone)
df = pl.DataFrame({"ts": ts, "value": [1, 2, 3, 4, 5, 6]})
result = df.with_columns(getattr(pl.col("value"), rolling_fn)("1d", by="ts"))
result = df.with_columns(
getattr(pl.col("value"), rolling_fn)("1d", by="ts", closed="left")
)
expected = pl.DataFrame({"ts": ts, "value": expected_values})
assert_frame_equal(result, expected)

Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import date, datetime
from functools import reduce
from inspect import signature
Expand Down Expand Up @@ -725,6 +726,27 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None:
assert cast(float, out_single_val_variance[0, "var"]) == 0.0


def test_rolling_closed_decorator() -> None:
# no warning if we do not use by
with warnings.catch_warnings():
warnings.simplefilter("error")
_ = pl.col("a").rolling_min(2)

# if we pass in a by, but no closed, we expect a warning
with pytest.warns(FutureWarning):
_ = pl.col("a").rolling_min(2, by="b")

# if we pass in a by and a closed, we expect no warning
with warnings.catch_warnings():
warnings.simplefilter("error")
_ = pl.col("a").rolling_min(2, by="b", closed="left")

# regardless of the value
with warnings.catch_warnings():
warnings.simplefilter("error")
_ = pl.col("a").rolling_min(2, by="b", closed="right")


def test_arr_namespace(fruits_cars: pl.DataFrame) -> None:
ldf = fruits_cars.lazy()
out = ldf.select(
Expand Down

0 comments on commit ca08fdf

Please sign in to comment.