Skip to content

Commit

Permalink
Group by multiple strings
Browse files Browse the repository at this point in the history
Closes #9396
  • Loading branch information
dcherian committed Aug 30, 2024
1 parent d33e4ad commit abfb63c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 28 deletions.
35 changes: 23 additions & 12 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -6706,10 +6707,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
*,
group: GroupInput = None,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
**groupers: Grouper,
Expand All @@ -6718,7 +6716,7 @@ def groupby(
Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or iterable of Hashable or mapping of Hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -6788,29 +6786,42 @@ def groupby(
Dataset.resample
DataArray.resample
"""
from xarray.core.dataarray import DataArray
from xarray.core.groupby import (
DataArrayGroupBy,
ResolvedGrouper,
_validate_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)
_validate_group_and_groupers(group, groupers)
if group is not None and groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)

if group is None and not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Iterable)
group_iter: Iterable[Hashable] = (
(group,) if isinstance(group, str) else group
)
groupers = {g: UniqueGrouper() for g in group_iter}

rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
Expand Down
28 changes: 16 additions & 12 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
JoinOptions,
PadModeOptions,
Expand Down Expand Up @@ -10331,10 +10332,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
*,
group: GroupInput = None,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
**groupers: Grouper,
Expand All @@ -10343,7 +10341,7 @@ def groupby(
Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of Hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -10384,29 +10382,35 @@ def groupby(
Dataset.resample
DataArray.resample
"""
from xarray.core.dataarray import DataArray
from xarray.core.groupby import (
DatasetGroupBy,
ResolvedGrouper,
_validate_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)
_validate_group_and_groupers(group, groupers)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Iterable)
group_iter: Iterable[Hashable] = (
(group,) if isinstance(group, str) else group
)
groupers = {g: UniqueGrouper() for g in group_iter}

rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
Expand Down
19 changes: 17 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -319,6 +319,21 @@ def __len__(self) -> int:
return len(self.encoded.full_index)


def _validate_group_and_groupers(group: GroupInput, groupers: dict[str, Grouper]):
if group is not None and groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)

if group is None and not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")

if isinstance(group, np.ndarray | pd.Index):
raise TypeError(
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
)


def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
Expand All @@ -327,7 +342,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if squeeze is not False:
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")


def _resolve_group(
Expand Down
13 changes: 11 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,17 @@
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
from xarray.groupers import TimeResampler
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler

# Type of the ``group`` argument accepted by ``groupby`` (used as the
# annotation of ``group`` in the groupby signatures and in the group/groupers
# validation helper shown in this diff).
# ``str`` is listed on its own even though a string is also a sequence: the
# groupby code wraps a lone ``str`` in a one-element tuple instead of iterating
# over its characters, while any other sequence is read as several names.
# ``None`` means "no positional group given" (keyword groupers are expected).
GroupInput: TypeAlias = (
str
| DataArray
| IndexVariable
| Sequence[Hashable]
| Mapping[Any, Grouper]
| None
)

try:
from dask.array import Array as DaskArray
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2635,6 +2635,31 @@ def test_weather_data_resample(use_flox):
assert expected.location.attrs == ds.location.attrs


@pytest.mark.parametrize("as_dataset", [True, False])
def test_multiple_groupers_string(as_dataset) -> None:
    """Grouping by a tuple of names matches one ``UniqueGrouper`` per name."""
    label_coords = dict(
        labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
        labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
    )
    array = DataArray(
        np.array([1, 2, 3, 0, 2, np.nan]),
        dims="d",
        coords=label_coords,
        name="foo",
    )
    # Exercise both container types with the same data.
    obj = array.to_dataset() if as_dataset else array

    # Reference result: one explicit UniqueGrouper per coordinate name.
    expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean()
    # Shorthand under test: the same names passed as a single tuple.
    actual = obj.groupby(("labels1", "labels2")).mean()
    assert_identical(expected, actual)

    # Several names must be passed as one sequence, not as separate
    # positional arguments.
    with pytest.raises(TypeError):
        obj.groupby("labels1", "labels2")

    # A positional group cannot be combined with keyword groupers.
    with pytest.raises(ValueError):
        obj.groupby("labels1", foo="bar")


@pytest.mark.parametrize("use_flox", [True, False])
def test_multiple_groupers(use_flox) -> None:
da = DataArray(
Expand Down

0 comments on commit abfb63c

Please sign in to comment.