Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupBy(multiple strings) #9414

Merged
merged 16 commits into from
Sep 4, 2024
28 changes: 16 additions & 12 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -6706,10 +6707,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
*,
group: GroupInput = None,
dcherian marked this conversation as resolved.
Show resolved Hide resolved
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
**groupers: Grouper,
Expand All @@ -6718,7 +6716,7 @@ def groupby(

Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or iterable of Hashable or mapping of Hashable to Grouper
dcherian marked this conversation as resolved.
Show resolved Hide resolved
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -6788,29 +6786,35 @@ def groupby(
Dataset.resample
DataArray.resample
"""
from xarray.core.dataarray import DataArray
from xarray.core.groupby import (
DataArrayGroupBy,
ResolvedGrouper,
_validate_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)
_validate_group_and_groupers(group, groupers)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Iterable)
group_iter: Iterable[Hashable] = (
(group,) if isinstance(group, str) else group
)
groupers = {g: UniqueGrouper() for g in group_iter}

rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
Expand Down
28 changes: 16 additions & 12 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
JoinOptions,
PadModeOptions,
Expand Down Expand Up @@ -10331,10 +10332,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
*,
group: GroupInput = None,
dcherian marked this conversation as resolved.
Show resolved Hide resolved
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
**groupers: Grouper,
Expand All @@ -10343,7 +10341,7 @@ def groupby(

Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of Hashable to Grouper
dcherian marked this conversation as resolved.
Show resolved Hide resolved
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -10384,29 +10382,35 @@ def groupby(
Dataset.resample
DataArray.resample
"""
from xarray.core.dataarray import DataArray
from xarray.core.groupby import (
DatasetGroupBy,
ResolvedGrouper,
_validate_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)
_validate_group_and_groupers(group, groupers)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Iterable)
group_iter: Iterable[Hashable] = (
(group,) if isinstance(group, str) else group
)
groupers = {g: UniqueGrouper() for g in group_iter}

rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
Expand Down
19 changes: 17 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -319,6 +319,21 @@ def __len__(self) -> int:
return len(self.encoded.full_index)


def _validate_group_and_groupers(group: GroupInput, groupers: dict[str, Grouper]):
if group is not None and groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)

if group is None and not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")

if isinstance(group, np.ndarray | pd.Index):
raise TypeError(
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
)


def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
Expand All @@ -327,7 +342,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if squeeze is not False:
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")


def _resolve_group(
Expand Down
13 changes: 11 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,17 @@
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
from xarray.groupers import TimeResampler
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler

GroupInput: TypeAlias = (
str
| DataArray
| IndexVariable
| Sequence[Hashable]
| Mapping[Any, Grouper]
| None
)

try:
from dask.array import Array as DaskArray
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2635,6 +2635,31 @@ def test_weather_data_resample(use_flox):
assert expected.location.attrs == ds.location.attrs


@pytest.mark.parametrize("as_dataset", [True, False])
def test_multiple_groupers_string(as_dataset) -> None:
obj = DataArray(
np.array([1, 2, 3, 0, 2, np.nan]),
dims="d",
coords=dict(
labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
),
name="foo",
)

if as_dataset:
obj = obj.to_dataset()

expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean()
actual = obj.groupby(("labels1", "labels2")).mean()
assert_identical(expected, actual)

with pytest.raises(TypeError):
obj.groupby("labels1", "labels2")
with pytest.raises(ValueError):
obj.groupby("labels1", foo="bar")


@pytest.mark.parametrize("use_flox", [True, False])
def test_multiple_groupers(use_flox) -> None:
da = DataArray(
Expand Down
Loading