Skip to content

Commit

Permalink
FIX: reverse index output of bottleneck move_argmax/move_argmin funct…
Browse files Browse the repository at this point in the history
…ions (#8552)

* FIX: reverse index output of bottleneck move_argmax/move_argmin functions, add move_argmax/move_argmin to bottleneck tests

* add whats-new.rst entry
  • Loading branch information
kmuehlbauer authored Dec 21, 2023
1 parent b444438 commit c35d6b6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ Deprecations
Bug fixes
~~~~~~~~~

- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.


Documentation
~~~~~~~~~~~~~
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
values = func(
padded.data, window=self.window[0], min_count=min_count, axis=axis
)
# index 0 is at the rightmost edge of the window
# need to reverse index here
# see GH #8541
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
values = self.window[0] - 1 - values

if self.center[0]:
values = values[valid]
Expand Down
12 changes: 10 additions & 2 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def test_rolling_properties(self, da) -> None:
):
da.rolling(foo=2)

@pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median"))
@pytest.mark.parametrize(
"name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax")
)
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
Expand All @@ -133,9 +135,15 @@ def test_rolling_wrapped_bottleneck(

func_name = f"move_{name}"
actual = getattr(rolling_obj, name)()
window = 7
expected = getattr(bn, func_name)(
da.values, window=7, axis=1, min_count=min_periods
da.values, window=window, axis=1, min_count=min_periods
)
# index 0 is at the rightmost edge of the window
# need to reverse index here
# see GH #8541
if func_name in ["move_argmin", "move_argmax"]:
expected = window - 1 - expected

# Using assert_allclose because we get tiny (1e-17) differences in numbagg.
np.testing.assert_allclose(actual.values, expected)
Expand Down

0 comments on commit c35d6b6

Please sign in to comment.