Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed .isel for DatasetRolling.construct consistent rolling behavior #7578

Merged
merged 11 commits into from
Sep 20, 2023
9 changes: 9 additions & 0 deletions asv_bench/benchmarks/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
roll = self.ds.var3.rolling(t=100)
getattr(roll, func)()

@parameterized(["stride"], ([None, 5, 50]))
def peakmem_1drolling_construct(self, stride):
self.ds.var2.rolling(t=100).construct("w", stride=stride)
self.ds.var3.rolling(t=100).construct("w", stride=stride)


class DatasetRollingMemory(RollingMemory):
@parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False]))
Expand All @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
with xr.set_options(use_bottleneck=use_bottleneck):
roll = self.ds.rolling(t=100)
getattr(roll, func)()

@parameterized(["stride"], ([None, 5, 50]))
def peakmem_1drolling_construct(self, stride):
self.ds.rolling(t=100).construct("w", stride=stride)
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ Bug fixes
of :py:meth:`DataArray.__setitem__` lose dimension names.
(:issue:`7030`, :pull:`8067`) By `Darsh Ranjan <https://github.com/dranjan>`_.
- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()`
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar`
(:issue:`7928`, :pull:`8084`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes.
(:issue:`7021`, :pull:`7578`).
By `Amrest Chinkamol <https://github.com/p4perf4ce>`_ and `Michael Niklas <https://github.com/headtr1ck>`_.
- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments
(:issue:`7552`, :pull:`8174`).
By `Wiktor Kraśnicki <https://github.com/wkrasnicki>`_.
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,14 @@ def construct(
if not keep_attrs:
dataset[key].attrs = {}

# Need to stride coords as well. TODO: is there a better way?
coords = self.obj.isel(
mathause marked this conversation as resolved.
Show resolved Hide resolved
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
).coords

attrs = self.obj.attrs if keep_attrs else {}

return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
)
return Dataset(dataset, coords=coords, attrs=attrs)


class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
Expand Down
55 changes: 47 additions & 8 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct(self, center, window) -> None:
def test_rolling_construct(self, center: bool, window: int) -> None:
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)

Expand Down Expand Up @@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct(self, center, window) -> None:
def test_rolling_construct(self, center: bool, window: int) -> None:
df = pd.DataFrame(
{
"x": np.random.randn(20),
Expand All @@ -627,19 +627,58 @@ def test_rolling_construct(self, center, window) -> None:
np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values)
np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"])

# with stride
ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window")
np.testing.assert_allclose(
df_rolling["x"][::2].values, ds_rolling_mean["x"].values
)
np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"])
# with fill_value
ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean(
"window"
)
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct_stride(self, center: bool, window: int) -> None:
df = pd.DataFrame(
{
"x": np.random.randn(20),
"y": np.random.randn(20),
"time": np.linspace(0, 1, 20),
}
)
ds = Dataset.from_dataframe(df)
df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean()

# With an index (dimension coordinate)
ds_rolling = ds.rolling(index=window, center=center)
ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w")
np.testing.assert_allclose(
df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values
)
np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"])

# Without index (https://github.com/pydata/xarray/issues/7021)
ds2 = ds.drop_vars("index")
ds2_rolling = ds2.rolling(index=window, center=center)
ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w")
np.testing.assert_allclose(
df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values
)

# Mixed coordinates, indexes and 2D coordinates
ds3 = xr.Dataset(
{"x": ("t", range(20)), "x2": ("y", range(5))},
{
"t": range(20),
"y": ("y", range(5)),
"t2": ("t", range(20)),
"y2": ("y", range(5)),
"yt": (["t", "y"], np.ones((20, 5))),
},
)
ds3_rolling = ds3.rolling(t=window, center=center)
ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w")
for coord in ds3.coords:
assert coord in ds3_rolling_mean.coords

@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
Expand Down