Skip to content
forked from pydata/xarray

Commit

Permalink
Set squeeze=None for Dataset too
Browse files (browse the repository at this point in the history)
  • Loading branch information
dcherian committed Dec 2, 2023
1 parent c2e576e commit 4e9a063
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 25 deletions.
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6709,7 +6709,7 @@ def groupby_bins(
labels: ArrayLike | Literal[False] | None = None,
precision: int = 3,
include_lowest: bool = False,
squeeze: bool = True,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10052,7 +10052,7 @@ def interp_calendar(
def groupby(
self,
group: Hashable | DataArray | IndexVariable,
squeeze: bool = True,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -10120,7 +10120,7 @@ def groupby_bins(
labels: ArrayLike | None = None,
precision: int = 3,
include_lowest: bool = False,
squeeze: bool = True,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _maybe_squeeze_indices(
indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
):
if squeeze in [None, True] and grouper.can_squeeze:
if squeeze is None and warn:
if (squeeze is None and warn) or squeeze is True:
emit_user_level_warning(
"The `squeeze` kwarg to GroupBy is being removed."
"Pass .groupby(..., squeeze=False) to disable squeezing,"
Expand Down Expand Up @@ -727,7 +727,7 @@ def __init__(
self,
obj: T_Xarray,
groupers: tuple[ResolvedGrouper],
squeeze: bool = False,
squeeze: bool | None = False,
restore_coord_dims: bool = True,
) -> None:
"""Create a GroupBy object
Expand Down Expand Up @@ -859,7 +859,7 @@ def _iter_grouped(self) -> Iterator[T_Xarray]:
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
indices, self._squeeze, grouper, warn=idx == 0
)
yield self._obj.isel({self._group_dim: indices})

Expand Down Expand Up @@ -1363,7 +1363,7 @@ def _iter_grouped_shortcut(self):
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
indices, self._squeeze, grouper, warn=idx == 0
)
yield var[{self._group_dim: indices}]

Expand Down
41 changes: 27 additions & 14 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,34 @@ def test_consolidate_slices() -> None:
_consolidate_slices([slice(3), 4]) # type: ignore[list-item]


def test_groupby_dims_property(dataset) -> None:
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
assert dataset.groupby("y").dims == dataset.isel(y=1).dims
def test_groupby_dims_property(dataset, recwarn) -> None:
# dims is sensitive to squeeze, always warn
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
assert dataset.groupby("y").dims == dataset.isel(y=1).dims

# when squeeze=False, no warning should be raised
assert dataset.groupby("x", squeeze=False).dims == dataset.isel(x=slice(1, 2)).dims
assert dataset.groupby("y", squeeze=False).dims == dataset.isel(y=slice(1, 2)).dims
assert len(recwarn) == 0

stacked = dataset.stack({"xy": ("x", "y")})
assert stacked.groupby("xy", squeeze=False).dims == stacked.isel(xy=[0]).dims
assert len(recwarn) == 0


def test_multi_index_groupby_map(dataset) -> None:
# regression test for GH873
ds = dataset.isel(z=1, drop=True)[["foo"]]
expected = 2 * ds
actual = (
ds.stack(space=["x", "y"])
.groupby("space")
.map(lambda x: 2 * x)
.unstack("space")
)
# The function in `map` may be sensitive to squeeze, always warn
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
actual = (
ds.stack(space=["x", "y"])
.groupby("space")
.map(lambda x: 2 * x)
.unstack("space")
)
assert_equal(expected, actual)


Expand Down Expand Up @@ -202,7 +209,9 @@ def func(arg1, arg2, arg3=0):

dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]})
actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
# The function in `map` may be sensitive to squeeze, always warn
with pytest.warns(UserWarning, match="The `squeeze` kwarg"):
actual = dataset.groupby("x").map(func, args=(1,), arg3=1)
assert_identical(expected, actual)


Expand Down Expand Up @@ -887,7 +896,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:

with xr.set_options(use_flox=use_flox):
actual = da.groupby_bins(
"x", bins=x_bins, include_lowest=True, right=False
"x", bins=x_bins, include_lowest=True, right=False, squeeze=False
).mean()
expected = xr.DataArray(
np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]),
Expand Down Expand Up @@ -1135,8 +1144,8 @@ def test_groupby_properties(self):
"by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)]
)
@pytest.mark.parametrize("shortcut", [True, False])
@pytest.mark.parametrize("squeeze", [True, False])
def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None:
@pytest.mark.parametrize("squeeze", [None])
def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None:
expected = self.da
if use_da:
by = expected.coords[by]
Expand All @@ -1148,6 +1157,10 @@ def identity(x):
actual = grouped.map(identity, shortcut=shortcut)
assert_identical(expected, actual)

# abc is not a dim coordinate so no warnings expected!
if (by.name if use_da else by) != "abc":
assert len(recwarn) == (1 if squeeze in [None, True] else 0)

def test_groupby_sum(self):
array = self.da
grouped = array.groupby("abc")
Expand Down Expand Up @@ -1508,7 +1521,7 @@ def test_groupby_bins_ellipsis(self):
da = xr.DataArray(np.ones((2, 3, 4)))
bins = [-1, 0, 1, 2]
with xr.set_options(use_flox=False):
actual = da.groupby_bins("dim_0", bins).mean(...)
actual = da.groupby_bins("dim_0", bins, squeeze=False).mean(...)
with xr.set_options(use_flox=True):
expected = da.groupby_bins("dim_0", bins).mean(...)
assert_allclose(actual, expected)
Expand Down
11 changes: 7 additions & 4 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3933,9 +3933,12 @@ def test_grouped_operations(self, func, variant, dtype):
for key, value in func.kwargs.items()
}
expected = attach_units(
func(strip_units(data_array).groupby("y"), **stripped_kwargs), units
func(
strip_units(data_array).groupby("y", squeeze=False), **stripped_kwargs
),
units,
)
actual = func(data_array.groupby("y"))
actual = func(data_array.groupby("y", squeeze=False))

assert_units_equal(expected, actual)
assert_identical(expected, actual)
Expand Down Expand Up @@ -5440,9 +5443,9 @@ def test_grouped_operations(self, func, variant, dtype):
name: strip_units(value) for name, value in func.kwargs.items()
}
expected = attach_units(
func(strip_units(ds).groupby("y"), **stripped_kwargs), units
func(strip_units(ds).groupby("y", squeeze=False), **stripped_kwargs), units
)
actual = func(ds.groupby("y"))
actual = func(ds.groupby("y", squeeze=False))

assert_units_equal(expected, actual)
assert_equal(expected, actual)
Expand Down

0 comments on commit 4e9a063

Please sign in to comment.