From 4e9a06360a0985dd5e7c7fdccaf3605b6e23d49a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 1 Dec 2023 21:10:14 -0700 Subject: [PATCH] Set squeeze=None for Dataset too --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 4 ++-- xarray/core/groupby.py | 8 +++---- xarray/tests/test_groupby.py | 41 ++++++++++++++++++++++++------------ xarray/tests/test_units.py | 11 ++++++---- 5 files changed, 41 insertions(+), 25 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index dfbc1317919..660ef70d5ee 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6709,7 +6709,7 @@ def groupby_bins( labels: ArrayLike | Literal[False] | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d010bfbade0..ddb0d31aa47 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10052,7 +10052,7 @@ def interp_calendar( def groupby( self, group: Hashable | DataArray | IndexVariable, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -10120,7 +10120,7 @@ def groupby_bins( labels: ArrayLike | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c79033cecd6..fd676b97462 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -77,7 +77,7 @@ def _maybe_squeeze_indices( indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool ): if squeeze in [None, True] and grouper.can_squeeze: - if squeeze is None and warn: + if (squeeze is None and warn) or squeeze is True: emit_user_level_warning( "The `squeeze` kwarg to GroupBy is being removed." "Pass .groupby(..., squeeze=False) to disable squeezing," @@ -727,7 +727,7 @@ def __init__( self, obj: T_Xarray, groupers: tuple[ResolvedGrouper], - squeeze: bool = False, + squeeze: bool | None = False, restore_coord_dims: bool = True, ) -> None: """Create a GroupBy object @@ -859,7 +859,7 @@ def _iter_grouped(self) -> Iterator[T_Xarray]: (grouper,) = self.groupers for idx, indices in enumerate(self._group_indices): indices = _maybe_squeeze_indices( - indices, self._squeeze, grouper, warn=idx > 0 + indices, self._squeeze, grouper, warn=idx == 0 ) yield self._obj.isel({self._group_dim: indices}) @@ -1363,7 +1363,7 @@ def _iter_grouped_shortcut(self): (grouper,) = self.groupers for idx, indices in enumerate(self._group_indices): indices = _maybe_squeeze_indices( - indices, self._squeeze, grouper, warn=idx > 0 + indices, self._squeeze, grouper, warn=idx == 0 ) yield var[{self._group_dim: indices}] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 07e2233c4fd..d1ae18674b1 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -59,27 +59,34 @@ def test_consolidate_slices() -> None: _consolidate_slices([slice(3), 4]) # type: ignore[list-item] -def test_groupby_dims_property(dataset) -> None: - assert dataset.groupby("x").dims == dataset.isel(x=1).dims - assert dataset.groupby("y").dims == dataset.isel(y=1).dims +def test_groupby_dims_property(dataset, recwarn) -> None: + # dims is sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("x").dims == dataset.isel(x=1).dims + assert dataset.groupby("y").dims == dataset.isel(y=1).dims + # when squeeze=False, no warning should be raised assert dataset.groupby("x", squeeze=False).dims == dataset.isel(x=slice(1, 2)).dims assert dataset.groupby("y", squeeze=False).dims == dataset.isel(y=slice(1, 2)).dims + assert len(recwarn) == 0 stacked = dataset.stack({"xy": ("x", "y")}) assert stacked.groupby("xy", squeeze=False).dims == stacked.isel(xy=[0]).dims + assert len(recwarn) == 0 def test_multi_index_groupby_map(dataset) -> None: # regression test for GH873 ds = dataset.isel(z=1, drop=True)[["foo"]] expected = 2 * ds - actual = ( - ds.stack(space=["x", "y"]) - .groupby("space") - .map(lambda x: 2 * x) - .unstack("space") - ) + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = ( + ds.stack(space=["x", "y"]) + .groupby("space") + .map(lambda x: 2 * x) + .unstack("space") + ) assert_equal(expected, actual) @@ -202,7 +209,9 @@ def func(arg1, arg2, arg3=0): dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]}) - actual = dataset.groupby("x").map(func, args=(1,), arg3=1) + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = dataset.groupby("x").map(func, args=(1,), arg3=1) assert_identical(expected, actual) @@ -887,7 +896,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: with xr.set_options(use_flox=use_flox): actual = da.groupby_bins( - "x", bins=x_bins, include_lowest=True, right=False + "x", bins=x_bins, include_lowest=True, right=False, squeeze=False ).mean() expected = xr.DataArray( np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]), @@ -1135,8 +1144,8 @@ def test_groupby_properties(self): "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] ) @pytest.mark.parametrize("shortcut", [True, False]) - @pytest.mark.parametrize("squeeze", [True, False]) - def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None: + @pytest.mark.parametrize("squeeze", [None]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None: expected = self.da if use_da: by = expected.coords[by] @@ -1148,6 +1157,10 @@ def identity(x): actual = grouped.map(identity, shortcut=shortcut) assert_identical(expected, actual) + # abc is not a dim coordinate so no warnings expected! + if (by.name if use_da else by) != "abc": + assert len(recwarn) == (1 if squeeze in [None, True] else 0) + def test_groupby_sum(self): array = self.da grouped = array.groupby("abc") @@ -1508,7 +1521,7 @@ def test_groupby_bins_ellipsis(self): da = xr.DataArray(np.ones((2, 3, 4))) bins = [-1, 0, 1, 2] with xr.set_options(use_flox=False): - actual = da.groupby_bins("dim_0", bins).mean(...) + actual = da.groupby_bins("dim_0", bins, squeeze=False).mean(...) with xr.set_options(use_flox=True): expected = da.groupby_bins("dim_0", bins).mean(...) assert_allclose(actual, expected) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index af86c18668f..21915a9a17c 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3933,9 +3933,12 @@ def test_grouped_operations(self, func, variant, dtype): for key, value in func.kwargs.items() } expected = attach_units( - func(strip_units(data_array).groupby("y"), **stripped_kwargs), units + func( + strip_units(data_array).groupby("y", squeeze=False), **stripped_kwargs + ), + units, ) - actual = func(data_array.groupby("y")) + actual = func(data_array.groupby("y", squeeze=False)) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -5440,9 +5443,9 @@ def test_grouped_operations(self, func, variant, dtype): name: strip_units(value) for name, value in func.kwargs.items() } expected = attach_units( - func(strip_units(ds).groupby("y"), **stripped_kwargs), units + func(strip_units(ds).groupby("y", squeeze=False), **stripped_kwargs), units ) - actual = func(ds.groupby("y")) + actual = func(ds.groupby("y", squeeze=False)) assert_units_equal(expected, actual) assert_equal(expected, actual)