From 71f5e107e0a8e27821d440ff308e69b4bf99b54f Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 25 Feb 2023 22:29:41 -0700 Subject: [PATCH 01/36] Introduce Grouper objects. --- xarray/core/common.py | 13 +- xarray/core/computation.py | 7 +- xarray/core/dataarray.py | 32 +- xarray/core/dataset.py | 19 +- xarray/core/groupby.py | 623 +++++++++++++++++++---------------- xarray/core/resample.py | 12 +- xarray/tests/test_groupby.py | 4 +- 7 files changed, 400 insertions(+), 310 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index af935ae15d2..1c6118e8d4c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -996,11 +996,16 @@ def _resample( if base is not None and offset is not None: raise ValueError("base and offset cannot be present at the same time") + index = self._indexes[dim_name].to_pandas_index() if base is not None: - index = self._indexes[dim_name].to_pandas_index() offset = _convert_base_to_offset(base, freq, index) + group = DataArray( + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM + ) + grouper = TimeResampleGrouper( + group=group, freq=freq, closed=closed, label=label, @@ -1009,14 +1014,10 @@ def _resample( loffset=loffset, ) - group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM - ) return resample_cls( self, - group=group, - dim=dim_name, grouper=grouper, + dim=dim_name, resample_dim=RESAMPLE_DIM, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9af7fcd89a4..356f1029192 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -515,15 +515,16 @@ def apply_groupby_func(func, *args): groupbys = [arg for arg in args if isinstance(arg, GroupBy)] assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] - if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): + (grouper,) = first_groupby.groupers + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): raise ValueError( "apply_ufunc can only perform operations over " "multiple GroupBy objects at once if they are all " "grouped the same way" ) - grouped_dim = first_groupby._group.name - unique_values = first_groupby._unique_coord.values + grouped_dim = grouper.name + unique_values = grouper.unique_coord.values iterators = [] for arg in args: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1f04f506397..ed1ae078710 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6256,7 +6256,7 @@ def groupby( core.groupby.DataArrayGroupBy pandas.DataFrame.groupby """ - from xarray.core.groupby import DataArrayGroupBy + from xarray.core.groupby import DataArrayGroupBy, UniqueGrouper # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -6269,8 +6269,9 @@ def groupby( f"`squeeze` must be True or False, but {squeeze} was supplied" ) + grouper = UniqueGrouper(group) return DataArrayGroupBy( - self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims ) def groupby_bins( @@ -6341,14 +6342,22 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DataArrayGroupBy + from xarray.core.groupby import BinGrouper, DataArrayGroupBy - return DataArrayGroupBy( - self, - group, - squeeze=squeeze, + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + grouper = BinGrouper( + group=group, bins=bins, - restore_coord_dims=restore_coord_dims, cut_kwargs={ "right": right, "labels": labels, @@ -6357,6 +6366,13 @@ def groupby_bins( }, ) + return DataArrayGroupBy( + self, + grouper, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + def weighted(self, weights: DataArray) -> DataArrayWeighted: """ Weighted DataArray operations. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0bd335f3f0a..005801573ff 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8943,7 +8943,7 @@ def groupby( Dataset.resample DataArray.resample """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import DatasetGroupBy, UniqueGrouper # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -8956,8 +8956,10 @@ def groupby( f"`squeeze` must be True or False, but {squeeze} was supplied" ) + grouper = UniqueGrouper(group) + return DatasetGroupBy( - self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims ) def groupby_bins( @@ -9028,14 +9030,11 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import BinGrouper, DatasetGroupBy - return DatasetGroupBy( - self, - group, - squeeze=squeeze, + grouper = BinGrouper( + group=group, bins=bins, - restore_coord_dims=restore_coord_dims, cut_kwargs={ "right": right, "labels": labels, @@ -9044,6 +9043,10 @@ def groupby_bins( }, ) + return DatasetGroupBy( + self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + def weighted(self, weights: DataArray) -> DatasetWeighted: """ Weighted Dataset operations. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7de975c9c0a..738b7d20712 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -2,7 +2,7 @@ import datetime import warnings -from collections.abc import Hashable, Iterator, Mapping, Sequence +from collections.abc import Hashable, Iterator, Sequence from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, cast import numpy as np @@ -185,12 +185,13 @@ class _DummyGroup: Should not be user visible. """ - __slots__ = ("name", "coords", "size") + __slots__ = ("name", "coords", "size", "dataarray") def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name self.coords = coords self.size = obj.sizes[name] + self.dataarray = obj[name] @property def dims(self) -> tuple[Hashable]: @@ -208,10 +209,17 @@ def values(self) -> range: def data(self) -> range: return range(self.size) + def __array__(self) -> np.ndarray: + return np.arange(self.size) + @property def shape(self) -> tuple[int]: return (self.size,) + @property + def attrs(self) -> dict: + return {} + def __getitem__(self, key): if isinstance(key, tuple): key = key[0] @@ -250,13 +258,6 @@ def _ensure_1d( ) -def _unique_and_monotonic(group: T_Group) -> bool: - if isinstance(group, _DummyGroup): - return True - index = safe_cast_to_index(group) - return index.is_unique and index.is_monotonic_increasing - - def _apply_loffset( loffset: str | pd.DateOffset | datetime.timedelta | pd.Timedelta, result: pd.Series | pd.DataFrame, @@ -294,91 +295,261 @@ def _apply_loffset( result.index = result.index + loffset -def _get_index_and_items(index, grouper): - first_items, codes = grouper.first_items(index) - full_index = first_items.index - if first_items.isnull().any(): - first_items = first_items.dropna() - return full_index, first_items, codes - - -def _factorize_grouper( - group, grouper -) -> tuple[ - DataArray | IndexVariable | _DummyGroup, - T_GroupIndices, - np.ndarray, - pd.Index, -]: - index = safe_cast_to_index(group) - if not index.is_monotonic_increasing: - # TODO: sort instead of raising an error - raise ValueError("index must be monotonic for resampling") - full_index, first_items, codes = _get_index_and_items(index, grouper) - sbins = first_items.values.astype(np.int64) - group_indices: T_GroupIndices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) - ] + [slice(sbins[-1], None)] - unique_coord = IndexVariable(group.name, first_items.index) - return unique_coord, group_indices, codes, full_index - - -def _factorize_bins( - group, bins, cut_kwargs: Mapping | None -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray, pd.IntervalIndex, DataArray]: - from xarray.core.dataarray import DataArray +class Grouper: + def __init__(self, group): + self.group = group + self.codes = None + self.labels = None + self.group_indices = None + self.unique_coord = None + self.full_index = None + self._group_as_index = None - if cut_kwargs is None: - cut_kwargs = {} - - if duck_array_ops.isnull(bins).all(): - raise ValueError("All bin edges are NaN.") - binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) - codes = binned.codes - if (codes == -1).all(): - raise ValueError(f"None of the data falls within bins with edges {bins!r}") - full_index = binned.categories - unique_values = binned.unique().dropna() - group_indices = [g for g in _codes_to_groups(codes, len(full_index)) if g] - - if len(group_indices) == 0: - raise ValueError(f"None of the data falls within bins with edges {bins!r}") - - new_dim_name = str(group.name) + "_bins" - group_ = DataArray(binned, getattr(group, "coords", None), name=new_dim_name) - unique_coord = IndexVariable(new_dim_name, unique_values) - return unique_coord, group_indices, codes, full_index, group_ - - -def _factorize_rest( - group, -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray]: - # look through group to find the unique values - group_as_index = safe_cast_to_index(group) - sort = not isinstance(group_as_index, pd.MultiIndex) - unique_values, group_indices, codes = unique_value_groups(group_as_index, sort=sort) - if len(group_indices) == 0: - raise ValueError( - "Failed to group data. Are you grouping by a variable that is all NaN?" + @property + def name(self): + return self.group1d.name + + @property + def size(self): + return len(self) + + def __len__(self): + return len(self.full_index) + + @property + def dims(self): + return self.group1d.dims + + def factorize(self, squeeze): + raise NotImplementedError + + @property + def is_unique_and_monotonic(self) -> bool: + if isinstance(self.group, _DummyGroup): + return True + index = self.group_as_index + return index.is_unique and index.is_monotonic_increasing + + @property + def group_as_index(self) -> pd.Index: + if self._group_as_index is None: + self._group_as_index = safe_cast_to_index(self.group1d) + return self._group_as_index + + def _resolve_group(self, obj): + from xarray.core.dataarray import DataArray + + group = self.group + if not isinstance(group, (DataArray, IndexVariable)): + if not hashable(group): + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension. " + f"Received {group!r} instead." + ) + group = obj[group] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + + if group.name not in obj.coords and group.name in obj.dims: + # DummyGroups should not appear on groupby results + group = _DummyGroup(obj, group.name, group.coords) + + if getattr(group, "name", None) is None: + group.name = "group" + + self.group = group + + self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( + group, obj ) - unique_coord = IndexVariable(group.name, unique_values) - return unique_coord, group_indices, codes - - -def _factorize_dummy( - group, squeeze: bool -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray]: - # no need to factorize - group_indices: T_GroupIndices - if not squeeze: - # use slices to do views instead of fancy indexing - # equivalent to: group_indices = group_indices.reshape(-1, 1) - group_indices = [slice(i, i + 1) for i in range(group.size)] - else: - group_indices = np.arange(group.size) - codes = np.arange(group.size) - unique_coord = group - return unique_coord, group_indices, codes + + (group_dim,) = self.group1d.dims + expected_size = stacked_obj.sizes[group_dim] + if group.size != expected_size: + raise ValueError( + "the group variable's length does not " + "match the length of this variable along its " + "dimension" + ) + + return self, stacked_obj + + +class UniqueGrouper(Grouper): + def factorize(self, squeeze) -> None: + is_dimension = self.group.dims == (self.group.name,) + if is_dimension and self.is_unique_and_monotonic: + self._factorize_dummy(squeeze) + else: + self._factorize_unique() + + def _factorize_unique(self) -> None: + # look through group to find the unique values + sort = not isinstance(self.group_as_index, pd.MultiIndex) + unique_values, group_indices, codes = unique_value_groups( + self.group_as_index, sort=sort + ) + if len(group_indices) == 0: + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + self.unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs + ) + self.codes = self.group1d.copy(data=codes) + self.group_indices = group_indices + self.full_index = self.unique_coord + + def _factorize_dummy(self, squeeze) -> None: + size = self.group.size + # no need to factorize + if not squeeze: + # use slices to do views instead of fancy indexing + # equivalent to: group_indices = group_indices.reshape(-1, 1) + self.group_indices = [slice(i, i + 1) for i in range(size)] + else: + self.group_indices = np.arange(size) + codes = np.arange(size) + if isinstance(self.group, _DummyGroup): + self.codes = self.group.dataarray.copy(data=codes) + else: + self.codes = self.group.copy(data=codes) + self.unique_coord = self.group + self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) + + +class BinGrouper(Grouper): + def __init__(self, group, bins, cut_kwargs: Mapping | None): + if duck_array_ops.isnull(bins).all(): + raise ValueError("All bin edges are NaN.") + + if cut_kwargs is None: + cut_kwargs = {} + + self.group = group + self.bins = bins + self.cut_kwargs = cut_kwargs + + def factorize(self, squeeze) -> None: + from xarray.core.dataarray import DataArray + + data = self.group1d.values + binned, bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + codes = binned.codes + if (codes == -1).all(): + raise ValueError(f"None of the data falls within bins with edges {bins!r}") + + full_index = binned.categories + + unique_values = binned.unique().dropna() + group_indices = [g for g in _codes_to_groups(codes, len(full_index)) if g] + + if len(group_indices) == 0: + raise ValueError(f"None of the data falls within bins with edges {bins!r}") + + new_dim_name = str(self.group.name) + "_bins" + self.group1d = DataArray( + binned, getattr(self.group1d, "coords", None), name=new_dim_name + ) + self.unique_coord = IndexVariable( + self.group1d.name, unique_values, self.group.attrs + ) + self.codes = self.group1d.copy(data=codes) + # TODO: support IntervalIndex in IndexVariable + self.full_index = full_index + self.group_indices = group_indices + + +class TimeResampleGrouper(Grouper): + def __init__( + self, + group, + freq: str, + closed: SideOptions | None, + label: SideOptions | None, + origin: str | DatetimeLike, + offset: pd.Timedelta | datetime.timedelta | str | None, + loffset: datetime.timedelta | str | None, + ): + from xarray import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper + + self.group = group + self.freq = freq + self.closed = closed + self.label = label + self.origin = origin + self.offset = offset + self.loffset = loffset + self._group_as_index = safe_cast_to_index(group) + group_as_index = self._group_as_index + + if not group_as_index.is_monotonic_increasing: + # TODO: sort instead of raising an error + raise ValueError("index must be monotonic for resampling") + + if isinstance(group_as_index, CFTimeIndex): + self.grouper = CFTimeGrouper( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=self.offset, + loffset=self.loffset, + ) + else: + self.grouper = pd.Grouper( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=self.offset, + ) + + def _get_index_and_items(self): + first_items, codes = self.first_items() + full_index = first_items.index + if first_items.isnull().any(): + first_items = first_items.dropna() + + full_index = full_index.rename("__resample_dim__") + return full_index, first_items, codes + + def first_items(self): + from xarray import CFTimeIndex + + if isinstance(self.group_as_index, CFTimeIndex): + return self.grouper.first_items(self.group_as_index) + else: + s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) + grouped = s.groupby(self.grouper) + first_items = grouped.first() + counts = grouped.count() + # This way we generate codes for the final output index: full_index. + # So for _flox_reduce we avoid one reindex and copy by avoiding + # _maybe_restore_empty_groups + codes = np.repeat(np.arange(len(first_items)), counts) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) + return first_items, codes + + def factorize( + self, squeeze + ) -> tuple[ + DataArray | IndexVariable | _DummyGroup, + list[slice] | list[list[int]] | np.ndarray, + np.ndarray, + ]: + self.full_index, first_items, codes = self._get_index_and_items() + sbins = first_items.values.astype(np.int64) + self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ + slice(sbins[-1], None) + ] + self.unique_coord = IndexVariable( + self.group.name, first_items.index, self.group.attrs + ) + self.codes = self.group.copy(data=codes) class GroupBy(Generic[T_Xarray]): @@ -405,6 +576,7 @@ class GroupBy(Generic[T_Xarray]): "_group_dim", "_group_indices", "_groups", + "groupers", "_obj", "_restore_coord_dims", "_stacked_dim", @@ -423,12 +595,10 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - group: Hashable | DataArray | IndexVariable, + grouper: Grouper, + *, squeeze: bool = False, - grouper: pd.Grouper | None = None, - bins: ArrayLike | None = None, restore_coord_dims: bool = True, - cut_kwargs: Mapping[Any, Any] | None = None, ) -> None: """Create a GroupBy object @@ -436,96 +606,31 @@ def __init__( ---------- obj : Dataset or DataArray Object to group. - group : Hashable, DataArray or Index - Array with the group values or name of the variable. - squeeze : bool, default: False - If "group" is a coordinate of object, `squeeze` controls whether - the subarrays have a dimension of length 1 along that coordinate or - if the dimension is squeezed out. - grouper : pandas.Grouper, optional - Used for grouping values along the `group` array. - bins : array-like, optional - If `bins` is specified, the groups will be discretized into the - specified bins by `pandas.cut`. + grouper : Grouper + Grouper object restore_coord_dims : bool, default: True If True, also restore the dimension order of multi-dimensional coordinates. - cut_kwargs : dict-like, optional - Extra keyword arguments to pass to `pandas.cut` - """ - from xarray.core.dataarray import DataArray - - if grouper is not None and bins is not None: - raise TypeError("can't specify both `grouper` and `bins`") - - if not isinstance(group, (DataArray, IndexVariable)): - if not hashable(group): - raise TypeError( - "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension. " - f"Received {group!r} instead." - ) - group = obj[group] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - - if group.name not in obj.coords and group.name in obj.dims: - # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) - - if getattr(group, "name", None) is None: - group.name = "group" - self._original_obj: T_Xarray = obj - self._original_group = group - self._bins = bins - group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) - (group_dim,) = group.dims + grouper, obj = grouper._resolve_group(obj) - expected_size = obj.sizes[group_dim] - if group.size != expected_size: - raise ValueError( - "the group variable's length does not " - "match the length of this variable along its " - "dimension" - ) + self._original_group = grouper.group + self.groupers = (grouper,) - self._codes: DataArray - if grouper is not None: - unique_coord, group_indices, codes, full_index = _factorize_grouper( - group, grouper - ) - self._codes = group.copy(data=codes) - elif bins is not None: - unique_coord, group_indices, codes, full_index, group = _factorize_bins( - group, bins, cut_kwargs - ) - self._codes = group.copy(data=codes) - elif group.dims == (group.name,) and _unique_and_monotonic(group): - unique_coord, group_indices, codes = _factorize_dummy(group, squeeze) - full_index = None - self._codes = obj[group.name].copy(data=codes) - else: - unique_coord, group_indices, codes = _factorize_rest(group) - full_index = None - self._codes = group.copy(data=codes) + grouper.factorize(squeeze) # specification for the groupby operation self._obj: T_Xarray = obj - self._group = group - self._group_dim = group_dim - self._group_indices = group_indices - self._unique_coord = unique_coord - self._stacked_dim = stacked_dim - self._inserted_dims = inserted_dims - self._full_index = full_index self._restore_coord_dims = restore_coord_dims - self._bins = bins self._squeeze = squeeze - self._codes = self._maybe_unstack(self._codes) + # These should generalize to multiple groupers + self._group_indices = grouper.group_indices + self._codes = self._maybe_unstack(grouper.codes) + + (self._group_dim,) = grouper.group1d.dims # cached attributes self._groups: dict[GroupKey, slice | int | list[int]] | None = None self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None @@ -578,7 +683,8 @@ def groups(self) -> dict[GroupKey, slice | int | list[int]]: """ # provided to mimic pandas.groupby if self._groups is None: - self._groups = dict(zip(self._unique_coord.values, self._group_indices)) + (grouper,) = self.groupers + self._groups = dict(zip(grouper.unique_coord.values, self._group_indices)) return self._groups def __getitem__(self, key: GroupKey) -> T_Xarray: @@ -588,17 +694,20 @@ def __getitem__(self, key: GroupKey) -> T_Xarray: return self._obj.isel({self._group_dim: self.groups[key]}) def __len__(self) -> int: - return self._unique_coord.size + (grouper,) = self.groupers + return grouper.size def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: - return zip(self._unique_coord.values, self._iter_grouped()) + (grouper,) = self.groupers + return zip(grouper.unique_coord.data, self._iter_grouped()) def __repr__(self) -> str: + (grouper,) = self.groupers return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( self.__class__.__name__, - self._unique_coord.name, - self._unique_coord.size, - ", ".join(format_array_flat(self._unique_coord, 30).split()), + grouper.name, + grouper.full_index.size, + ", ".join(format_array_flat(grouper.full_index, 30).split()), ) def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -607,11 +716,12 @@ def _iter_grouped(self) -> Iterator[T_Xarray]: yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): + (grouper,) = self.groupers if self._group_dim in applied_example.dims: - coord = self._group + coord = grouper.group1d positions = self._group_indices else: - coord = self._unique_coord + coord = grouper.unique_coord positions = None (dim,) = coord.dims if isinstance(coord, _DummyGroup): @@ -625,19 +735,19 @@ def _binary_op(self, other, f, reflexive=False): g = f if not reflexive else lambda x, y: f(y, x) + (grouper,) = self.groupers obj = self._original_obj - group = self._original_group + group = grouper.group codes = self._codes dims = group.dims if isinstance(group, _DummyGroup): - group = obj[group.name] - coord = group + group = coord = group.dataarray else: - coord = self._unique_coord + coord = grouper.unique_coord if not isinstance(coord, DataArray): - coord = DataArray(self._unique_coord) - name = self._group.name + coord = DataArray(grouper.unique_coord) + name = grouper.name if not isinstance(other, (Dataset, DataArray)): raise TypeError( @@ -668,7 +778,7 @@ def _binary_op(self, other, f, reflexive=False): obj = obj.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) - other, _ = align(other, coord, join="outer") + other, _ = align(other, coord, join="outer", copy=False) expanded = other.isel({name: codes}) result = g(obj, expanded) @@ -688,20 +798,27 @@ def _binary_op(self, other, f, reflexive=False): return result def _maybe_restore_empty_groups(self, combined): - """Our index contained empty groups (e.g., from a resampling). If we + """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. """ - if self._full_index is not None and self._group.name in combined.dims: - indexers = {self._group.name: self._full_index} + (grouper,) = self.groupers + if ( + isinstance(grouper, (BinGrouper, TimeResampleGrouper)) + and grouper.name in combined.dims + ): + indexers = {grouper.name: grouper.full_index} combined = combined.reindex(**indexers) return combined def _maybe_unstack(self, obj): """This gets called if we are applying on an array with a multidimensional group.""" - if self._stacked_dim is not None and self._stacked_dim in obj.dims: - obj = obj.unstack(self._stacked_dim) - for dim in self._inserted_dims: + (grouper,) = self.groupers + stacked_dim = grouper.stacked_dim + inserted_dims = grouper.inserted_dims + if stacked_dim is not None and stacked_dim in obj.dims: + obj = obj.unstack(stacked_dim) + for dim in inserted_dims: if dim in obj.coords: del obj.coords[dim] obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) @@ -719,7 +836,8 @@ def _flox_reduce( from xarray.core.dataset import Dataset obj = self._original_obj - group = self._original_group + (grouper,) = self.groupers + isbin = isinstance(grouper, BinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -742,22 +860,20 @@ def _flox_reduce( # weird backcompat # reducing along a unique indexed dimension with squeeze=True # should raise an error - if ( - dim is None or dim == self._group.name - ) and self._group.name in obj.xindexes: - index = obj.indexes[self._group.name] + if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: + index = obj.indexes[grouper.name] if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") unindexed_dims: tuple[Hashable, ...] = tuple() - if isinstance(group, _DummyGroup) and self._bins is None: - unindexed_dims = (group.name,) + if isinstance(grouper.group, _DummyGroup) and not isbin: + unindexed_dims = (grouper.name,) parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): parsed_dim = (dim,) elif dim is None: - parsed_dim = group.dims + parsed_dim = grouper.group.dims elif dim is ...: parsed_dim = tuple(obj.dims) else: @@ -765,12 +881,12 @@ def _flox_reduce( # Do this so we raise the same error message whether flox is present or not. # Better to control it here than in flox. - if any(d not in group.dims and d not in obj.dims for d in parsed_dim): + if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim): raise ValueError(f"cannot reduce over dimensions {dim}.") if kwargs["func"] not in ["all", "any", "count"]: kwargs.setdefault("fill_value", np.nan) - if self._bins is not None and kwargs["func"] == "count": + if isbin and kwargs["func"] == "count": # This is an annoying hack. Xarray returns np.nan # when there are no observations in a bin, instead of 0. # We can fake that here by forcing min_count=1. @@ -779,7 +895,7 @@ def _flox_reduce( kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) - output_index = self._get_output_index() + output_index = grouper.full_index result = xarray_reduce( obj.drop_vars(non_numeric.keys()), self._codes, @@ -793,35 +909,29 @@ def _flox_reduce( # we did end up reducing over dimension(s) that are # in the grouped variable - if set(self._codes.dims).issubset(set(parsed_dim)): - result[self._unique_coord.name] = output_index + group_dims = grouper.group.dims + if set(group_dims).issubset(set(parsed_dim)): + result[grouper.name] = output_index result = result.drop_vars(unindexed_dims) # broadcast and restore non-numeric data variables (backcompat) for name, var in non_numeric.items(): if all(d not in var.dims for d in parsed_dim): result[name] = var.variable.set_dims( - (group.name,) + var.dims, (result.sizes[group.name],) + var.shape + (grouper.name,) + var.dims, + (result.sizes[grouper.name],) + var.shape, ) - if self._bins is not None: + if isbin: # Fix dimension order when binning a dimension coordinate # Needed as long as we do a separate code path for pint; # For some reason Datasets and DataArrays behave differently! - if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: - result = result.transpose(self._group.name, ...) + (group_dim,) = grouper.dims + if isinstance(self._obj, Dataset) and group_dim in self._obj.dims: + result = result.transpose(grouper.name, ...) return result - def _get_output_index(self) -> pd.Index: - """Return pandas.Index object for the output array.""" - if self._full_index is not None: - # binning and resample - return self._full_index.rename(self._unique_coord.name) - if isinstance(self._unique_coord, _DummyGroup): - return IndexVariable(self._group.name, self._unique_coord.values) - return self._unique_coord - def fillna(self, value: Any) -> T_Xarray: """Fill missing values in this object by group. @@ -975,7 +1085,8 @@ def quantile( The American Statistician, 50(4), pp. 361-365, 1996 """ if dim is None: - dim = (self._group_dim,) + (grouper,) = self.groupers + dim = grouper.group1d.dims return self.map( self._obj.__class__.quantile, @@ -1078,13 +1189,18 @@ def _concat_shortcut(self, applied, dim, positions=None): # TODO: benbovy - explicit indexes: this fast implementation doesn't # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) - reordered = _maybe_reorder(stacked, dim, positions, N=self._group.size) + + (grouper,) = self.groupers + reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size) return self._obj._replace_maybe_drop_dims(reordered) def _restore_dim_order(self, stacked: DataArray) -> DataArray: + (grouper,) = self.groupers + group = grouper.group1d + def lookup_order(dimension): - if dimension == self._group.name: - (dimension,) = self._group.dims + if dimension == group.name: + (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension) else: @@ -1169,7 +1285,8 @@ def _combine(self, applied, shortcut=False): combined = self._concat_shortcut(applied, dim, positions) else: combined = concat(applied, dim) - combined = _maybe_reorder(combined, dim, positions, N=self._group.size) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) if isinstance(combined, type(self._obj)): # only restore dimension order for arrays @@ -1325,7 +1442,8 @@ def _combine(self, applied): applied_example, applied = peek_at(applied) coord, dim, positions = self._infer_concat_args(applied_example) combined = concat(applied, dim) - combined = _maybe_reorder(combined, dim, positions, N=self._group.size) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) # assign coord when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: index, index_vars = create_default_index_implicit(coord) @@ -1412,56 +1530,3 @@ class DatasetGroupBy( # type: ignore[misc] ImplementsDatasetReduce, ): __slots__ = () - - -class TimeResampleGrouper: - def __init__( - self, - freq: str, - closed: SideOptions | None, - label: SideOptions | None, - origin: str | DatetimeLike, - offset: pd.Timedelta | datetime.timedelta | str | None, - loffset: datetime.timedelta | str | None, - ): - self.freq = freq - self.closed = closed - self.label = label - self.origin = origin - self.offset = offset - self.loffset = loffset - - def first_items(self, index): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper - - if isinstance(index, CFTimeIndex): - grouper = CFTimeGrouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, - loffset=self.loffset, - ) - return grouper.first_items(index) - else: - s = pd.Series(np.arange(index.size), index, copy=False) - grouper = pd.Grouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, - ) - - grouped = s.groupby(grouper) - first_items = grouped.first() - counts = grouped.count() - # This way we generate codes for the final output index: full_index. - # So for _flox_reduce we avoid one reindex and copy by avoiding - # _maybe_restore_empty_groups - codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) - return first_items, codes diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ad9b8379322..d78676b188e 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -84,8 +84,9 @@ def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: padded : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="pad", tolerance=tolerance + {self._dim: grouper.full_index}, method="pad", tolerance=tolerance ) ffill = pad @@ -108,8 +109,9 @@ def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray backfilled : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="backfill", tolerance=tolerance + {self._dim: grouper.full_index}, method="backfill", tolerance=tolerance ) bfill = backfill @@ -133,8 +135,9 @@ def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: upsampled : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="nearest", tolerance=tolerance + {self._dim: grouper.full_index}, method="nearest", tolerance=tolerance ) def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray: @@ -170,8 +173,9 @@ def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray: def _interpolate(self, kind="linear") -> T_Xarray: """Apply scipy.interpolate.interp1d along resampling dimension.""" obj = self._drop_coords() + (grouper,) = self.groupers return obj.interp( - coords={self._dim: self._full_index}, + coords={self._dim: grouper.full_index}, assume_sorted=True, method=kind, kwargs={"bounds_error": False}, diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ccbead9dbc4..73a5b6494a3 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -537,7 +537,7 @@ def test_groupby_drops_nans() -> None: .reset_index("id", drop=True) .assign(id=stacked.id.values) .dropna("id") - .transpose(*actual2.dims) + .transpose(*actual2.variable.dims) ) assert_identical(actual2, expected2) @@ -1684,7 +1684,7 @@ def test_upsample(self): # Nearest rs = array.resample(time="3H") actual = rs.nearest() - new_times = rs._full_index + new_times = rs.groupers[0].full_index expected = DataArray(array.reindex(time=new_times, method="nearest")) assert_identical(expected, actual) From b9500ceff63b839702619237894f6be6ba435282 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 9 Mar 2023 21:16:34 -0700 Subject: [PATCH 02/36] Remove a copy after stacking for a groupby. Upstream bug https://github.com/pydata/pandas/issues/12813 is fixed --- xarray/core/groupby.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 738b7d20712..37bfb65e7ce 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -247,9 +247,7 @@ def _ensure_1d( stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] - # the copy is necessary here, otherwise read only array raises error - # in pandas: https://github.com/pydata/pandas/issues/12813 - newgroup = group.stack({stacked_dim: orig_dims}).copy() + newgroup = group.stack({stacked_dim: orig_dims}) newobj = obj.stack({stacked_dim: orig_dims}) return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims From 44f13258188e7ec161d2119e5b0a73baacc3e5fa Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 9 Mar 2023 21:37:13 -0700 Subject: [PATCH 03/36] Fix typing --- xarray/core/groupby.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 37bfb65e7ce..52ceec0b1ed 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -294,7 +294,7 @@ def _apply_loffset( class Grouper: - def __init__(self, group): + def __init__(self, group: T_Group): self.group = group self.codes = None self.labels = None @@ -304,21 +304,21 @@ def __init__(self, group): self._group_as_index = None @property - def name(self): + def name(self) -> Hashable: return self.group1d.name @property - def size(self): + def size(self) -> int: return len(self) - def __len__(self): + def __len__(self) -> int: return len(self.full_index) @property def dims(self): return self.group1d.dims - def factorize(self, squeeze): + def factorize(self, squeeze: bool) -> None: raise NotImplementedError @property @@ -334,7 +334,7 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj): + def _resolve_group(self, obj) -> None: from xarray.core.dataarray import DataArray group = self.group @@ -429,7 +429,7 @@ def __init__(self, group, bins, cut_kwargs: Mapping | None): self.bins = bins self.cut_kwargs = cut_kwargs - def factorize(self, squeeze) -> None: + def factorize(self, squeeze: bool) -> None: from xarray.core.dataarray import DataArray data = self.group1d.values @@ -532,13 +532,7 @@ def first_items(self): _apply_loffset(self.loffset, first_items) return first_items, codes - def factorize( - self, squeeze - ) -> tuple[ - DataArray | IndexVariable | _DummyGroup, - list[slice] | list[list[int]] | np.ndarray, - np.ndarray, - ]: + def factorize(self, squeeze: bool) -> None: self.full_index, first_items, codes = self._get_index_and_items() sbins = first_items.values.astype(np.int64) self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ From 1168ab7907d35b361cd6c322394adf495c00aa3b Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 16 Mar 2023 10:22:22 -0600 Subject: [PATCH 04/36] [WIP] typing --- xarray/core/groupby.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 52ceec0b1ed..f038911ddbe 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -295,12 +295,12 @@ def _apply_loffset( class Grouper: def __init__(self, group: T_Group): - self.group = group - self.codes = None + self.group : T_Group | None = group + self.codes : np.ndarry | None = None self.labels = None - self.group_indices = None + self.group_indices : list[list[int, ...]] | None= None self.unique_coord = None - self.full_index = None + self.full_index : pd.Index | None = None self._group_as_index = None @property @@ -334,9 +334,10 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj) -> None: + def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None: from xarray.core.dataarray import DataArray + group: T_Group group = self.group if not isinstance(group, (DataArray, IndexVariable)): if not hashable(group): @@ -345,11 +346,11 @@ def _resolve_group(self, obj) -> None: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group = obj[group] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") + group_da : T_DataArray = obj[group] + if len(group_da) == 0: + raise ValueError(f"{group_da.name} must not be empty") - if group.name not in obj.coords and group.name in obj.dims: + if group_da.name not in obj.coords and group_da.name in obj.dims: # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) From c905b745adf81301094c3c83258afdfd457f3831 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 17 Mar 2023 21:26:43 -0600 Subject: [PATCH 05/36] Cleanup --- xarray/core/groupby.py | 60 ++++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index f038911ddbe..0425da2ffc7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -44,6 +44,7 @@ from xarray.core.utils import Frozen GroupKey = Any + GroupIndex = int | slice | list[int] T_GroupIndicesListInt = list[list[int]] T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] @@ -129,11 +130,11 @@ def _dummy_copy(xarray_obj): return res -def _is_one_or_none(obj): +def _is_one_or_none(obj) -> bool: return obj == 1 or obj is None -def _consolidate_slices(slices): +def _consolidate_slices(slices: list[slice]) -> list[slice]: """Consolidate adjacent slices in a list of slices.""" result = [] last_slice = slice(None) @@ -191,7 +192,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name self.coords = coords self.size = obj.sizes[name] - self.dataarray = obj[name] @property def dims(self) -> tuple[Hashable]: @@ -228,6 +228,13 @@ def __getitem__(self, key): def copy(self, deep: bool = True, data: Any = None): raise NotImplementedError + def as_dataarray(self) -> DataArray: + from xarray.core.dataarray import DataArray + + return DataArray( + data=self.data, dims=(self.name,), coords=self.coords, name=self.name + ) + T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) @@ -294,14 +301,16 @@ def _apply_loffset( class Grouper: - def __init__(self, group: T_Group): - self.group : T_Group | None = group - self.codes : np.ndarry | None = None + def __init__(self, group: T_Group | Hashable): + self.group: T_Group | Hashable = group + self.labels = None - self.group_indices : list[list[int, ...]] | None= None - self.unique_coord = None - self.full_index : pd.Index | None = None - self._group_as_index = None + self._group_as_index: pd.Index | None = None + + self.codes: DataArray + self.group_indices: list[int] | list[slice] | list[list[int]] + self.unique_coord: IndexVariable | _DummyGroup + self.full_index: pd.Index @property def name(self) -> Hashable: @@ -334,10 +343,9 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None: + def _resolve_group(self, obj: T_Xarray): from xarray.core.dataarray import DataArray - group: T_Group group = self.group if not isinstance(group, (DataArray, IndexVariable)): if not hashable(group): @@ -346,15 +354,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group_da : T_DataArray = obj[group] - if len(group_da) == 0: - raise ValueError(f"{group_da.name} must not be empty") - - if group_da.name not in obj.coords and group_da.name in obj.dims: + group = obj[group] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + if group.name not in obj._indexes and group.name in obj.dims: # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) - if getattr(group, "name", None) is None: + elif getattr(group, "name", None) is None: group.name = "group" self.group = group @@ -408,10 +415,10 @@ def _factorize_dummy(self, squeeze) -> None: # equivalent to: group_indices = group_indices.reshape(-1, 1) self.group_indices = [slice(i, i + 1) for i in range(size)] else: - self.group_indices = np.arange(size) + self.group_indices = list(range(size)) codes = np.arange(size) if isinstance(self.group, _DummyGroup): - self.codes = self.group.dataarray.copy(data=codes) + self.codes = self.group.as_dataarray().copy(data=codes) else: self.codes = self.group.copy(data=codes) self.unique_coord = self.group @@ -489,7 +496,7 @@ def __init__( raise ValueError("index must be monotonic for resampling") if isinstance(group_as_index, CFTimeIndex): - self.grouper = CFTimeGrouper( + grouper = CFTimeGrouper( freq=self.freq, closed=self.closed, label=self.label, @@ -498,15 +505,16 @@ def __init__( loffset=self.loffset, ) else: - self.grouper = pd.Grouper( + grouper = pd.Grouper( freq=self.freq, closed=self.closed, label=self.label, origin=self.origin, offset=self.offset, ) + self.grouper: CFTimeGrouper | pd.Grouper = grouper - def _get_index_and_items(self): + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() full_index = first_items.index if first_items.isnull().any(): @@ -515,7 +523,7 @@ def _get_index_and_items(self): full_index = full_index.rename("__resample_dim__") return full_index, first_items, codes - def first_items(self): + def first_items(self) -> tuple[pd.Series, np.ndarray]: from xarray import CFTimeIndex if isinstance(self.group_as_index, CFTimeIndex): @@ -670,7 +678,7 @@ def reduce( raise NotImplementedError() @property - def groups(self) -> dict[GroupKey, slice | int | list[int]]: + def groups(self) -> dict[GroupKey, GroupIndex]: """ Mapping from group labels to indices. The indices can be used to index the underlying object. """ @@ -735,7 +743,7 @@ def _binary_op(self, other, f, reflexive=False): dims = group.dims if isinstance(group, _DummyGroup): - group = coord = group.dataarray + group = coord = group.as_dataarray() else: coord = grouper.unique_coord if not isinstance(coord, DataArray): From 22ad7fa7607cb83832935533a55df1f73c65811d Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 18 Mar 2023 20:40:22 -0600 Subject: [PATCH 06/36] [WIP] --- xarray/core/common.py | 77 ++++++++++++++++++++++++++++++++++-- xarray/core/dataarray.py | 52 ++++++------------------ xarray/core/dataset.py | 43 +++++++------------- xarray/core/groupby.py | 85 ++++++++++++++++++++++++---------------- 4 files changed, 153 insertions(+), 104 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 1c6118e8d4c..cb7df60cffa 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -814,6 +814,74 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) + def _groupby(self, groupby_cls, group, squeeze: bool, restore_coord_dims): + from xarray.core.groupby import UniqueGrouper, _validate_group + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + newobj, name = _validate_group(self, group) + + grouper = UniqueGrouper() + return groupby_cls( + newobj, + {name: grouper}, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def _groupby_bins( + self, + groupby_cls, + group: Hashable | DataArray | IndexVariable, + bins: ArrayLike, + right: bool = True, + labels: ArrayLike | None = None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool = True, + restore_coord_dims: bool = False, + ): + from xarray.core.groupby import BinGrouper, _validate_group + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + newobj, name = _validate_group(self, group) + + grouper = BinGrouper( + bins=bins, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + + return groupby_cls( + newobj, + {name: grouper}, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + def _resample( self, resample_cls: type[T_Resample], @@ -1000,12 +1068,13 @@ def _resample( if base is not None: offset = _convert_base_to_offset(base, freq, index) + name = RESAMPLE_DIM group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=name ) + newobj = self.copy().assign_coords({name: group}) grouper = TimeResampleGrouper( - group=group, freq=freq, closed=closed, label=label, @@ -1015,8 +1084,8 @@ def _resample( ) return resample_cls( - self, - grouper=grouper, + newobj, + {name: grouper}, dim=dim_name, resample_dim=RESAMPLE_DIM, restore_coord_dims=restore_coord_dims, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ed1ae078710..b6bf27f0000 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6256,22 +6256,13 @@ def groupby( core.groupby.DataArrayGroupBy pandas.DataFrame.groupby """ - from xarray.core.groupby import DataArrayGroupBy, UniqueGrouper - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DataArrayGroupBy - grouper = UniqueGrouper(group) - return DataArrayGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + return self._groupby( + groupby_cls=DataArrayGroupBy, + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -6342,33 +6333,16 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import BinGrouper, DataArrayGroupBy - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DataArrayGroupBy - grouper = BinGrouper( + return self._groupby_bins( + groupby_cls=DataArrayGroupBy, group=group, bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - - return DataArrayGroupBy( - self, - grouper, + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 005801573ff..ec1d857f563 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8943,23 +8943,13 @@ def groupby( Dataset.resample DataArray.resample """ - from xarray.core.groupby import DatasetGroupBy, UniqueGrouper - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DatasetGroupBy - grouper = UniqueGrouper(group) - - return DatasetGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + return self._groupby( + groupby_cls=DatasetGroupBy, + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -9030,21 +9020,18 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import BinGrouper, DatasetGroupBy + from xarray.core.groupby import DatasetGroupBy - grouper = BinGrouper( + return self._groupby_bins( + groupby_cls=DatasetGroupBy, group=group, bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - - return DatasetGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def weighted(self, weights: DataArray) -> DatasetWeighted: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0425da2ffc7..c907ebf9e0f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -301,9 +301,7 @@ def _apply_loffset( class Grouper: - def __init__(self, group: T_Group | Hashable): - self.group: T_Group | Hashable = group - + def __init__(self): self.labels = None self._group_as_index: pd.Index | None = None @@ -343,26 +341,13 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj: T_Xarray): - from xarray.core.dataarray import DataArray - - group = self.group - if not isinstance(group, (DataArray, IndexVariable)): - if not hashable(group): - raise TypeError( - "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension. " - f"Received {group!r} instead." - ) - group = obj[group] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - if group.name not in obj._indexes and group.name in obj.dims: - # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) - - elif getattr(group, "name", None) is None: - group.name = "group" + def _resolve_group(self, obj: T_Xarray, group_name: Hashable): + group = obj[group_name] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + if group.name not in obj._indexes and group.name in obj.dims: + # DummyGroups should not appear on groupby results + group = _DummyGroup(obj, group.name, group.coords) self.group = group @@ -381,6 +366,14 @@ def _resolve_group(self, obj: T_Xarray): return self, stacked_obj + def copy(self, deep=False): + import copy + + if deep: + return copy.deepcopy(self) + else: + return copy.copy(self) + class UniqueGrouper(Grouper): def factorize(self, squeeze) -> None: @@ -433,7 +426,6 @@ def __init__(self, group, bins, cut_kwargs: Mapping | None): if cut_kwargs is None: cut_kwargs = {} - self.group = group self.bins = bins self.cut_kwargs = cut_kwargs @@ -459,7 +451,7 @@ def factorize(self, squeeze: bool) -> None: binned, getattr(self.group1d, "coords", None), name=new_dim_name ) self.unique_coord = IndexVariable( - self.group1d.name, unique_values, self.group.attrs + new_dim_name, unique_values, self.group.attrs ) self.codes = self.group1d.copy(data=codes) # TODO: support IntervalIndex in IndexVariable @@ -470,7 +462,6 @@ def factorize(self, squeeze: bool) -> None: class TimeResampleGrouper(Grouper): def __init__( self, - group, freq: str, closed: SideOptions | None, label: SideOptions | None, @@ -478,16 +469,19 @@ def __init__( offset: pd.Timedelta | datetime.timedelta | str | None, loffset: datetime.timedelta | str | None, ): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper - - self.group = group self.freq = freq self.closed = closed self.label = label self.origin = origin self.offset = offset self.loffset = loffset + + def _resolve_group(self, obj, group_name): + from xarray import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper + + group = obj[group_name] + self.group = group self._group_as_index = safe_cast_to_index(group) group_as_index = self._group_as_index @@ -514,6 +508,12 @@ def __init__( ) self.grouper: CFTimeGrouper | pd.Grouper = grouper + self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( + group, obj + ) + + return self, stacked_obj + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() full_index = first_items.index @@ -553,6 +553,25 @@ def factorize(self, squeeze: bool) -> None: self.codes = self.group.copy(data=codes) +def _validate_group(obj, group): + from xarray.core.dataarray import DataArray + + if isinstance(group, (DataArray, IndexVariable)): + name = group.name or "group" + newobj = obj.copy().assign_coords({name: group}) + else: + if not hashable(group): + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension. " + f"Received {group!r} instead." + ) + name = group + newobj = obj + + return newobj, name + + class GroupBy(Generic[T_Xarray]): """A object that implements the split-apply-combine pattern. @@ -596,8 +615,7 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - grouper: Grouper, - *, + groupers: Dict[Hashable, Grouper], squeeze: bool = False, restore_coord_dims: bool = True, ) -> None: @@ -615,7 +633,8 @@ def __init__( """ self._original_obj: T_Xarray = obj - grouper, obj = grouper._resolve_group(obj) + for group_name, grouper_ in groupers.items(): + grouper, obj = grouper_.copy()._resolve_group(obj, group_name) self._original_group = grouper.group self.groupers = (grouper,) From 22ac6deb78c21a7fcaae30f60f37c187eaf192ad Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 19 Mar 2023 22:00:54 -0600 Subject: [PATCH 07/36] group as Variable? --- xarray/core/groupby.py | 50 +++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c907ebf9e0f..2744275ddcd 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -309,10 +309,7 @@ def __init__(self): self.group_indices: list[int] | list[slice] | list[list[int]] self.unique_coord: IndexVariable | _DummyGroup self.full_index: pd.Index - - @property - def name(self) -> Hashable: - return self.group1d.name + self.name: Hashable @property def size(self) -> int: @@ -342,17 +339,23 @@ def group_as_index(self) -> pd.Index: return self._group_as_index def _resolve_group(self, obj: T_Xarray, group_name: Hashable): - group = obj[group_name] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - if group.name not in obj._indexes and group.name in obj.dims: + # handles virtual variables like time.month properly + group_da = obj[group_name] + name = group_da.name + self.name = name + + if len(group_da) == 0: + raise ValueError(f"{name} must not be empty") + if name not in obj._indexes and name in obj.dims: # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) + group = _DummyGroup(obj, name, group_da.coords) + else: + group = group_da.variable self.group = group self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( - group, obj + group_da, obj ) (group_dim,) = self.group1d.dims @@ -377,13 +380,14 @@ def copy(self, deep=False): class UniqueGrouper(Grouper): def factorize(self, squeeze) -> None: - is_dimension = self.group.dims == (self.group.name,) + is_dimension = self.group.dims == (self.name,) if is_dimension and self.is_unique_and_monotonic: self._factorize_dummy(squeeze) else: self._factorize_unique() def _factorize_unique(self) -> None: + from .dataarray import DataArray # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) unique_values, group_indices, codes = unique_value_groups( @@ -394,13 +398,14 @@ def _factorize_unique(self) -> None: "Failed to group data. Are you grouping by a variable that is all NaN?" ) self.unique_coord = IndexVariable( - self.group.name, unique_values, attrs=self.group.attrs + self.name, unique_values, attrs=self.group.attrs ) - self.codes = self.group1d.copy(data=codes) + self.codes = DataArray(self.group1d, name=self.name).copy(data=codes) self.group_indices = group_indices self.full_index = self.unique_coord def _factorize_dummy(self, squeeze) -> None: + from .dataarray import DataArray size = self.group.size # no need to factorize if not squeeze: @@ -413,13 +418,15 @@ def _factorize_dummy(self, squeeze) -> None: if isinstance(self.group, _DummyGroup): self.codes = self.group.as_dataarray().copy(data=codes) else: - self.codes = self.group.copy(data=codes) + self.codes = DataArray(self.group).copy(data=codes) + + self.codes.name = self.name self.unique_coord = self.group self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) class BinGrouper(Grouper): - def __init__(self, group, bins, cut_kwargs: Mapping | None): + def __init__(self, bins, cut_kwargs: Mapping | None): if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") @@ -446,14 +453,15 @@ def factorize(self, squeeze: bool) -> None: if len(group_indices) == 0: raise ValueError(f"None of the data falls within bins with edges {bins!r}") - new_dim_name = str(self.group.name) + "_bins" + new_dim_name = str(self.name) + "_bins" self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name ) self.unique_coord = IndexVariable( new_dim_name, unique_values, self.group.attrs ) - self.codes = self.group1d.copy(data=codes) + from .dataarray import DataArray + self.codes = DataArray(self.group1d, name=self.name).copy(data=codes) # TODO: support IntervalIndex in IndexVariable self.full_index = full_index self.group_indices = group_indices @@ -542,15 +550,16 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: return first_items, codes def factorize(self, squeeze: bool) -> None: + from .dataarray import DataArray self.full_index, first_items, codes = self._get_index_and_items() sbins = first_items.values.astype(np.int64) self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ slice(sbins[-1], None) ] self.unique_coord = IndexVariable( - self.group.name, first_items.index, self.group.attrs + self.name, first_items.index, self.group.attrs ) - self.codes = self.group.copy(data=codes) + self.codes = DataArray(self.group, name=self.name).copy(data=codes) def _validate_group(obj, group): @@ -915,6 +924,7 @@ def _flox_reduce( kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) + from .dataarray import DataArray output_index = grouper.full_index result = xarray_reduce( obj.drop_vars(non_numeric.keys()), @@ -1219,7 +1229,7 @@ def _restore_dim_order(self, stacked: DataArray) -> DataArray: group = grouper.group1d def lookup_order(dimension): - if dimension == group.name: + if dimension == grouper.name: (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension) From 912e5c5c8b1a9c10f7fba5a62e545105dcbaa081 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 19 Mar 2023 22:01:06 -0600 Subject: [PATCH 08/36] Revert "group as Variable?" This reverts commit 2a36e21a031b9e061b932682758551956f3f06d2. --- xarray/core/groupby.py | 50 +++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2744275ddcd..c907ebf9e0f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -309,7 +309,10 @@ def __init__(self): self.group_indices: list[int] | list[slice] | list[list[int]] self.unique_coord: IndexVariable | _DummyGroup self.full_index: pd.Index - self.name: Hashable + + @property + def name(self) -> Hashable: + return self.group1d.name @property def size(self) -> int: @@ -339,23 +342,17 @@ def group_as_index(self) -> pd.Index: return self._group_as_index def _resolve_group(self, obj: T_Xarray, group_name: Hashable): - # handles virtual variables like time.month properly - group_da = obj[group_name] - name = group_da.name - self.name = name - - if len(group_da) == 0: - raise ValueError(f"{name} must not be empty") - if name not in obj._indexes and name in obj.dims: + group = obj[group_name] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + if group.name not in obj._indexes and group.name in obj.dims: # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, name, group_da.coords) - else: - group = group_da.variable + group = _DummyGroup(obj, group.name, group.coords) self.group = group self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( - group_da, obj + group, obj ) (group_dim,) = self.group1d.dims @@ -380,14 +377,13 @@ def copy(self, deep=False): class UniqueGrouper(Grouper): def factorize(self, squeeze) -> None: - is_dimension = self.group.dims == (self.name,) + is_dimension = self.group.dims == (self.group.name,) if is_dimension and self.is_unique_and_monotonic: self._factorize_dummy(squeeze) else: self._factorize_unique() def _factorize_unique(self) -> None: - from .dataarray import DataArray # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) unique_values, group_indices, codes = unique_value_groups( @@ -398,14 +394,13 @@ def _factorize_unique(self) -> None: "Failed to group data. Are you grouping by a variable that is all NaN?" ) self.unique_coord = IndexVariable( - self.name, unique_values, attrs=self.group.attrs + self.group.name, unique_values, attrs=self.group.attrs ) - self.codes = DataArray(self.group1d, name=self.name).copy(data=codes) + self.codes = self.group1d.copy(data=codes) self.group_indices = group_indices self.full_index = self.unique_coord def _factorize_dummy(self, squeeze) -> None: - from .dataarray import DataArray size = self.group.size # no need to factorize if not squeeze: @@ -418,15 +413,13 @@ def _factorize_dummy(self, squeeze) -> None: if isinstance(self.group, _DummyGroup): self.codes = self.group.as_dataarray().copy(data=codes) else: - self.codes = DataArray(self.group).copy(data=codes) - - self.codes.name = self.name + self.codes = self.group.copy(data=codes) self.unique_coord = self.group self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) class BinGrouper(Grouper): - def __init__(self, bins, cut_kwargs: Mapping | None): + def __init__(self, group, bins, cut_kwargs: Mapping | None): if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") @@ -453,15 +446,14 @@ def factorize(self, squeeze: bool) -> None: if len(group_indices) == 0: raise ValueError(f"None of the data falls within bins with edges {bins!r}") - new_dim_name = str(self.name) + "_bins" + new_dim_name = str(self.group.name) + "_bins" self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name ) self.unique_coord = IndexVariable( new_dim_name, unique_values, self.group.attrs ) - from .dataarray import DataArray - self.codes = DataArray(self.group1d, name=self.name).copy(data=codes) + self.codes = self.group1d.copy(data=codes) # TODO: support IntervalIndex in IndexVariable self.full_index = full_index self.group_indices = group_indices @@ -550,16 +542,15 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: return first_items, codes def factorize(self, squeeze: bool) -> None: - from .dataarray import DataArray self.full_index, first_items, codes = self._get_index_and_items() sbins = first_items.values.astype(np.int64) self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ slice(sbins[-1], None) ] self.unique_coord = IndexVariable( - self.name, first_items.index, self.group.attrs + self.group.name, first_items.index, self.group.attrs ) - self.codes = DataArray(self.group, name=self.name).copy(data=codes) + self.codes = self.group.copy(data=codes) def _validate_group(obj, group): @@ -924,7 +915,6 @@ def _flox_reduce( kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) - from .dataarray import DataArray output_index = grouper.full_index result = xarray_reduce( obj.drop_vars(non_numeric.keys()), @@ -1229,7 +1219,7 @@ def _restore_dim_order(self, stacked: DataArray) -> DataArray: group = grouper.group1d def lookup_order(dimension): - if dimension == grouper.name: + if dimension == group.name: (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension) From 60abafe5594ebf1dbebd96cc33cea215d95c66dd Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Mar 2023 21:17:13 -0600 Subject: [PATCH 09/36] Small cleanup --- xarray/core/groupby.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c907ebf9e0f..36e99dda21c 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -2,8 +2,17 @@ import datetime import warnings -from collections.abc import Hashable, Iterator, Sequence -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, cast +from collections.abc import Hashable, Iterator, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TypeVar, + Union, + cast, +) import numpy as np import pandas as pd @@ -419,7 +428,7 @@ def _factorize_dummy(self, squeeze) -> None: class BinGrouper(Grouper): - def __init__(self, group, bins, cut_kwargs: Mapping | None): + def __init__(self, bins, cut_kwargs: Mapping | None): if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") @@ -450,9 +459,7 @@ def factorize(self, squeeze: bool) -> None: self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name ) - self.unique_coord = IndexVariable( - new_dim_name, unique_values, self.group.attrs - ) + self.unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) self.codes = self.group1d.copy(data=codes) # TODO: support IntervalIndex in IndexVariable self.full_index = full_index @@ -558,7 +565,11 @@ def _validate_group(obj, group): if isinstance(group, (DataArray, IndexVariable)): name = group.name or "group" - newobj = obj.copy().assign_coords({name: group}) + newobj = obj.copy() + if group.name in newobj: + newobj[group.name] = group + else: + newobj = newobj.assign_coords({name: group}) else: if not hashable(group): raise TypeError( @@ -615,7 +626,7 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - groupers: Dict[Hashable, Grouper], + groupers: dict[Hashable, Grouper], squeeze: bool = False, restore_coord_dims: bool = True, ) -> None: From c6bfdaaa0b61268f56d7069eec44c1c644fb56c8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Mar 2023 21:49:47 -0600 Subject: [PATCH 10/36] De-duplicate alignment check --- xarray/core/groupby.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 36e99dda21c..1f0347f2e1b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -364,15 +364,6 @@ def _resolve_group(self, obj: T_Xarray, group_name: Hashable): group, obj ) - (group_dim,) = self.group1d.dims - expected_size = stacked_obj.sizes[group_dim] - if group.size != expected_size: - raise ValueError( - "the group variable's length does not " - "match the length of this variable along its " - "dimension" - ) - return self, stacked_obj def copy(self, deep=False): @@ -569,6 +560,15 @@ def _validate_group(obj, group): if group.name in newobj: newobj[group.name] = group else: + try: + align(newobj, group, join="exact", copy=False) + except ValueError: + raise ValueError( + "the group variable's length does not " + "match the length of this variable along its " + "dimensions" + ) + newobj = newobj.assign_coords({name: group}) else: if not hashable(group): From a2290aab191535b3ef86469d562c280ac111b51d Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Mar 2023 21:50:19 -0600 Subject: [PATCH 11/36] Fix resampling --- xarray/core/groupby.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1f0347f2e1b..5a7fe9ffe50 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -476,9 +476,12 @@ def __init__( def _resolve_group(self, obj, group_name): from xarray import CFTimeIndex + from xarray.core.resample import RESAMPLE_DIM from xarray.core.resample_cftime import CFTimeGrouper - group = obj[group_name] + group = obj[group_name].reset_coords(drop=True) + obj = obj.drop_vars(RESAMPLE_DIM) + self.group = group self._group_as_index = safe_cast_to_index(group) group_as_index = self._group_as_index From e8630455890283de3e5999c7da27aee98db28e96 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Mar 2023 21:53:53 -0600 Subject: [PATCH 12/36] Bugfix --- xarray/core/groupby.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5a7fe9ffe50..abd7c2e1f8d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -556,11 +556,14 @@ def factorize(self, squeeze: bool) -> None: def _validate_group(obj, group): from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset if isinstance(group, (DataArray, IndexVariable)): name = group.name or "group" newobj = obj.copy() - if group.name in newobj: + if group.name in newobj.coords or ( + isinstance(newobj, Dataset) and group.name in newobj.data_vars + ): newobj[group.name] = group else: try: From 0d0b2cd268b4e833afed3f1d1439336a1f75b7b0 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Mar 2023 22:08:17 -0600 Subject: [PATCH 13/36] Partial reverts commit 22ad7fa7607cb83832935533a55df1f73c65811d. --- xarray/core/common.py | 68 ---------------------------------------- xarray/core/dataarray.py | 44 +++++++++++++++++++------- xarray/core/dataset.py | 45 +++++++++++++++++++------- xarray/core/groupby.py | 11 +++++++ 4 files changed, 76 insertions(+), 92 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index cb7df60cffa..aa9af73edc4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -814,74 +814,6 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) - def _groupby(self, groupby_cls, group, squeeze: bool, restore_coord_dims): - from xarray.core.groupby import UniqueGrouper, _validate_group - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) - - newobj, name = _validate_group(self, group) - - grouper = UniqueGrouper() - return groupby_cls( - newobj, - {name: grouper}, - squeeze=squeeze, - restore_coord_dims=restore_coord_dims, - ) - - def _groupby_bins( - self, - groupby_cls, - group: Hashable | DataArray | IndexVariable, - bins: ArrayLike, - right: bool = True, - labels: ArrayLike | None = None, - precision: int = 3, - include_lowest: bool = False, - squeeze: bool = True, - restore_coord_dims: bool = False, - ): - from xarray.core.groupby import BinGrouper, _validate_group - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) - - newobj, name = _validate_group(self, group) - - grouper = BinGrouper( - bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - - return groupby_cls( - newobj, - {name: grouper}, - squeeze=squeeze, - restore_coord_dims=restore_coord_dims, - ) - def _resample( self, resample_cls: type[T_Resample], diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b6bf27f0000..31c4875622f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6256,11 +6256,19 @@ def groupby( core.groupby.DataArrayGroupBy pandas.DataFrame.groupby """ - from xarray.core.groupby import DataArrayGroupBy + from xarray.core.groupby import ( + DataArrayGroupBy, + UniqueGrouper, + _validate_group, + _validate_groupby_squeeze, + ) - return self._groupby( - groupby_cls=DataArrayGroupBy, - group=group, + _validate_groupby_squeeze(squeeze) + newobj, name = _validate_group(self, group) + grouper = UniqueGrouper() + return DataArrayGroupBy( + newobj, + {name: grouper}, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -6333,16 +6341,28 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DataArrayGroupBy + from xarray.core.groupby import ( + BinGrouper, + DataArrayGroupBy, + _validate_group, + _validate_groupby_squeeze, + ) - return self._groupby_bins( - groupby_cls=DataArrayGroupBy, - group=group, + _validate_groupby_squeeze(squeeze) + newobj, name = _validate_group(self, group) + grouper = BinGrouper( bins=bins, - right=right, - labels=labels, - precision=precision, - include_lowest=include_lowest, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + + return DataArrayGroupBy( + newobj, + {name: grouper}, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ec1d857f563..87b57f3e9ea 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8943,11 +8943,20 @@ def groupby( Dataset.resample DataArray.resample """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import ( + DatasetGroupBy, + UniqueGrouper, + _validate_group, + _validate_groupby_squeeze, + ) - return self._groupby( - groupby_cls=DatasetGroupBy, - group=group, + _validate_groupby_squeeze(squeeze) + newobj, name = _validate_group(self, group) + grouper = UniqueGrouper() + + return DatasetGroupBy( + self, + {name: grouper}, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -9020,16 +9029,28 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import ( + BinGrouper, + DatasetGroupBy, + _validate_group, + _validate_groupby_squeeze, + ) - return self._groupby_bins( - groupby_cls=DatasetGroupBy, - group=group, + _validate_groupby_squeeze(squeeze) + newobj, name = _validate_group(self, group) + grouper = BinGrouper( bins=bins, - right=right, - labels=labels, - precision=precision, - include_lowest=include_lowest, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + + return DatasetGroupBy( + self, + {name: grouper}, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index abd7c2e1f8d..085e2c7a3ae 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -554,6 +554,17 @@ def factorize(self, squeeze: bool) -> None: self.codes = self.group.copy(data=codes) +def _validate_groupby_squeeze(squeeze): + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied") + + def _validate_group(obj, group): from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset From c5daa47fe7b0f0573e11ed3b6e3aad72ab65f4e2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 30 Mar 2023 20:50:46 -0600 Subject: [PATCH 14/36] fix tests --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 87b57f3e9ea..46b6ee853d9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8955,7 +8955,7 @@ def groupby( grouper = UniqueGrouper() return DatasetGroupBy( - self, + newobj, {name: grouper}, squeeze=squeeze, restore_coord_dims=restore_coord_dims, From dda40f5df2d7dc0f304315878bea49be6eae5927 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 30 Mar 2023 20:58:20 -0600 Subject: [PATCH 15/36] small cleanup --- xarray/core/groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 085e2c7a3ae..b1877ce7647 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -352,8 +352,6 @@ def group_as_index(self) -> pd.Index: def _resolve_group(self, obj: T_Xarray, group_name: Hashable): group = obj[group_name] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") if group.name not in obj._indexes and group.name in obj.dims: # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) @@ -570,6 +568,9 @@ def _validate_group(obj, group): from xarray.core.dataset import Dataset if isinstance(group, (DataArray, IndexVariable)): + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + name = group.name or "group" newobj = obj.copy() if group.name in newobj.coords or ( From eb4304346dfdfce35ceb3dc3e94e1a4ebadfcef6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 30 Mar 2023 21:25:37 -0600 Subject: [PATCH 16/36] more cleanup --- xarray/core/groupby.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b1877ce7647..2b29da3ec54 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -351,7 +351,7 @@ def group_as_index(self) -> pd.Index: return self._group_as_index def _resolve_group(self, obj: T_Xarray, group_name: Hashable): - group = obj[group_name] + group = obj[group_name].reset_coords(drop=True) if group.name not in obj._indexes and group.name in obj.dims: # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) @@ -568,15 +568,12 @@ def _validate_group(obj, group): from xarray.core.dataset import Dataset if isinstance(group, (DataArray, IndexVariable)): - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - - name = group.name or "group" + group_name = group.name or "group" newobj = obj.copy() if group.name in newobj.coords or ( isinstance(newobj, Dataset) and group.name in newobj.data_vars ): - newobj[group.name] = group + newobj[group_name] = group else: try: align(newobj, group, join="exact", copy=False) @@ -587,7 +584,7 @@ def _validate_group(obj, group): "dimensions" ) - newobj = newobj.assign_coords({name: group}) + newobj = newobj.assign_coords({group_name: group}) else: if not hashable(group): raise TypeError( @@ -595,10 +592,13 @@ def _validate_group(obj, group): "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - name = group + group_name = group newobj = obj - return newobj, name + if len(newobj[group_name]) == 0: + raise ValueError(f"{group_name} must not be empty") + + return newobj, group_name class GroupBy(Generic[T_Xarray]): @@ -825,6 +825,7 @@ def _binary_op(self, other, f, reflexive=False): mask = codes == -1 if mask.any(): obj = obj.where(~mask, drop=True) + group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) other, _ = align(other, coord, join="outer", copy=False) From 834731388d2fa5a5d39bcbc04afcffa6aaa3b4a5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 31 Mar 2023 07:00:37 +0200 Subject: [PATCH 17/36] Apply suggestions from code review --- xarray/core/groupby.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2b29da3ec54..172b20609f0 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -3,6 +3,7 @@ import datetime import warnings from collections.abc import Hashable, Iterator, Mapping, Sequence +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -53,7 +54,7 @@ from xarray.core.utils import Frozen GroupKey = Any - GroupIndex = int | slice | list[int] + GroupIndex = Union[int, slice, list[int]] T_GroupIndicesListInt = list[list[int]] T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] @@ -309,7 +310,7 @@ def _apply_loffset( result.index = result.index + loffset -class Grouper: +class Grouper(ABC): def __init__(self): self.labels = None self._group_as_index: pd.Index | None = None @@ -334,6 +335,7 @@ def __len__(self) -> int: def dims(self): return self.group1d.dims + @abstractmethod def factorize(self, squeeze: bool) -> None: raise NotImplementedError From a89ec14ecb0d14edfa13e78ab946e06de82602be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Mar 2023 05:01:16 +0000 Subject: [PATCH 18/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 172b20609f0..5e6259001af 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -2,8 +2,8 @@ import datetime import warnings -from collections.abc import Hashable, Iterator, Mapping, Sequence from abc import ABC, abstractmethod +from collections.abc import Hashable, Iterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, From fe4e0a7d84216da588dd09aea56882b45f931890 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 30 Mar 2023 21:52:11 -0600 Subject: [PATCH 19/36] Add ResolvedGrouper class --- xarray/core/groupby.py | 188 +++++++++++++++++++++-------------------- 1 file changed, 97 insertions(+), 91 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5e6259001af..44ab9478ffc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -310,8 +310,8 @@ def _apply_loffset( result.index = result.index + loffset -class Grouper(ABC): - def __init__(self): +class ResolvedGrouper(ABC): + def __init__(self, grouper: Grouper, group, obj): self.labels = None self._group_as_index: pd.Index | None = None @@ -320,6 +320,15 @@ def __init__(self): self.unique_coord: IndexVariable | _DummyGroup self.full_index: pd.Index + self.grouper = grouper + self.group = group + ( + self.group1d, + self.stacked_obj, + self.stacked_dim, + self.inserted_dims, + ) = _ensure_1d(group, obj) + @property def name(self) -> Hashable: return self.group1d.name @@ -352,30 +361,8 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj: T_Xarray, group_name: Hashable): - group = obj[group_name].reset_coords(drop=True) - if group.name not in obj._indexes and group.name in obj.dims: - # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) - - self.group = group - - self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( - group, obj - ) - - return self, stacked_obj - - def copy(self, deep=False): - import copy - - if deep: - return copy.deepcopy(self) - else: - return copy.copy(self) - -class UniqueGrouper(Grouper): +class ResolvedUniqueGrouper(ResolvedGrouper): def factorize(self, squeeze) -> None: is_dimension = self.group.dims == (self.group.name,) if is_dimension and self.is_unique_and_monotonic: @@ -418,22 +405,14 @@ def _factorize_dummy(self, squeeze) -> None: self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) -class BinGrouper(Grouper): - def __init__(self, bins, cut_kwargs: Mapping | None): - if duck_array_ops.isnull(bins).all(): - raise ValueError("All bin edges are NaN.") - - if cut_kwargs is None: - cut_kwargs = {} - - self.bins = bins - self.cut_kwargs = cut_kwargs - +class ResolvedBinGrouper(ResolvedGrouper): def factorize(self, squeeze: bool) -> None: from xarray.core.dataarray import DataArray data = self.group1d.values - binned, bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + binned, bins = pd.cut( + data, self.grouper.bins, **self.grouper.cut_kwargs, retbins=True + ) codes = binned.codes if (codes == -1).all(): raise ValueError(f"None of the data falls within bins with edges {bins!r}") @@ -457,32 +436,13 @@ def factorize(self, squeeze: bool) -> None: self.group_indices = group_indices -class TimeResampleGrouper(Grouper): - def __init__( - self, - freq: str, - closed: SideOptions | None, - label: SideOptions | None, - origin: str | DatetimeLike, - offset: pd.Timedelta | datetime.timedelta | str | None, - loffset: datetime.timedelta | str | None, - ): - self.freq = freq - self.closed = closed - self.label = label - self.origin = origin - self.offset = offset - self.loffset = loffset - - def _resolve_group(self, obj, group_name): +class ResolvedTimeResampleGrouper(ResolvedGrouper): + def __init__(self, grouper, group, obj): from xarray import CFTimeIndex - from xarray.core.resample import RESAMPLE_DIM from xarray.core.resample_cftime import CFTimeGrouper - group = obj[group_name].reset_coords(drop=True) - obj = obj.drop_vars(RESAMPLE_DIM) + super().__init__(grouper, group, obj) - self.group = group self._group_as_index = safe_cast_to_index(group) group_as_index = self._group_as_index @@ -491,29 +451,23 @@ def _resolve_group(self, obj, group_name): raise ValueError("index must be monotonic for resampling") if isinstance(group_as_index, CFTimeIndex): - grouper = CFTimeGrouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, - loffset=self.loffset, + index_grouper = CFTimeGrouper( + freq=grouper.freq, + closed=grouper.closed, + label=grouper.label, + origin=grouper.origin, + offset=grouper.offset, + loffset=grouper.loffset, ) else: - grouper = pd.Grouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, + index_grouper = pd.Grouper( + freq=grouper.freq, + closed=grouper.closed, + label=grouper.label, + origin=grouper.origin, + offset=grouper.offset, ) - self.grouper: CFTimeGrouper | pd.Grouper = grouper - - self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( - group, obj - ) - - return self, stacked_obj + self.index_grouper: CFTimeGrouper | pd.Grouper = index_grouper def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() @@ -528,18 +482,18 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: from xarray import CFTimeIndex if isinstance(self.group_as_index, CFTimeIndex): - return self.grouper.first_items(self.group_as_index) + return self.index_grouper.first_items(self.group_as_index) else: s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) - grouped = s.groupby(self.grouper) + grouped = s.groupby(self.index_grouper) first_items = grouped.first() counts = grouped.count() # This way we generate codes for the final output index: full_index. # So for _flox_reduce we avoid one reindex and copy by avoiding # _maybe_restore_empty_groups codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) + if self.grouper.loffset is not None: + _apply_loffset(self.grouper.loffset, first_items) return first_items, codes def factorize(self, squeeze: bool) -> None: @@ -554,6 +508,55 @@ def factorize(self, squeeze: bool) -> None: self.codes = self.group.copy(data=codes) +class Grouper(ABC): + pass + +class UniqueGrouper(Grouper): + _resolved_cls = ResolvedUniqueGrouper + + +class BinGrouper(Grouper): + _resolved_cls = ResolvedBinGrouper + + def __init__(self, bins, cut_kwargs: Mapping | None): + if duck_array_ops.isnull(bins).all(): + raise ValueError("All bin edges are NaN.") + + if cut_kwargs is None: + cut_kwargs = {} + + self.bins = bins + self.cut_kwargs = cut_kwargs + + +class TimeResampleGrouper(Grouper): + _resolved_cls = ResolvedTimeResampleGrouper + + def __init__( + self, + freq: str, + closed: SideOptions | None, + label: SideOptions | None, + origin: str | DatetimeLike, + offset: pd.Timedelta | datetime.timedelta | str | None, + loffset: datetime.timedelta | str | None, + ): + self.freq = freq + self.closed = closed + self.label = label + self.origin = origin + self.offset = offset + self.loffset = loffset + + def _resolve_group(self, obj, group_name) -> ResolvedGrouper: + from xarray.core.resample import RESAMPLE_DIM + + group = obj[group_name].reset_coords(drop=True) + # TODO: This is an ugly in-place modification + del obj[RESAMPLE_DIM] + return ResolvedTimeResampleGrouper(self, group, obj) + + def _validate_groupby_squeeze(squeeze): # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -662,18 +665,21 @@ def __init__( If True, also restore the dimension order of multi-dimensional coordinates. """ + self.groupers: tuple[ResolvedGrouper] = tuple( + grouper_._resolve_group(obj, group_name) + for group_name, grouper_ in groupers.items() + ) + self._original_obj: T_Xarray = obj - for group_name, grouper_ in groupers.items(): - grouper, obj = grouper_.copy()._resolve_group(obj, group_name) + for grouper_ in self.groupers: + grouper_.factorize(squeeze) + (grouper,) = self.groupers self._original_group = grouper.group - self.groupers = (grouper,) - - grouper.factorize(squeeze) # specification for the groupby operation - self._obj: T_Xarray = obj + self._obj: T_Xarray = grouper.stacked_obj self._restore_coord_dims = restore_coord_dims self._squeeze = squeeze @@ -855,7 +861,7 @@ def _maybe_restore_empty_groups(self, combined): """ (grouper,) = self.groupers if ( - isinstance(grouper, (BinGrouper, TimeResampleGrouper)) + isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampleGrouper)) and grouper.name in combined.dims ): indexers = {grouper.name: grouper.full_index} @@ -889,7 +895,7 @@ def _flox_reduce( obj = self._original_obj (grouper,) = self.groupers - isbin = isinstance(grouper, BinGrouper) + isbin = isinstance(grouper, ResolvedBinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) From 0ffc0ad7b4919f662ab4fdffabcd09d81162eb99 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 1 Apr 2023 22:11:40 -0600 Subject: [PATCH 20/36] GroupBy only handles ResolvedGrouper objects. Much cleaner! --- xarray/core/common.py | 9 +++-- xarray/core/dataarray.py | 17 +++++----- xarray/core/dataset.py | 15 ++++----- xarray/core/groupby.py | 71 ++++++++++++++++------------------------ 4 files changed, 47 insertions(+), 65 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index aa9af73edc4..a6d8a440706 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -949,7 +949,7 @@ def _resample( # TODO support non-string indexer after removing the old API. from xarray.core.dataarray import DataArray - from xarray.core.groupby import TimeResampleGrouper + from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper from xarray.core.resample import RESAMPLE_DIM if keep_attrs is not None: @@ -1004,8 +1004,6 @@ def _resample( group = DataArray( dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=name ) - newobj = self.copy().assign_coords({name: group}) - grouper = TimeResampleGrouper( freq=freq, closed=closed, @@ -1014,10 +1012,11 @@ def _resample( offset=offset, loffset=loffset, ) + rgrouper = ResolvedTimeResampleGrouper(grouper, group, self) return resample_cls( - newobj, - {name: grouper}, + self, + (rgrouper,), dim=dim_name, resample_dim=RESAMPLE_DIM, restore_coord_dims=restore_coord_dims, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 31c4875622f..f016a298374 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6258,17 +6258,16 @@ def groupby( """ from xarray.core.groupby import ( DataArrayGroupBy, + ResolvedUniqueGrouper, UniqueGrouper, - _validate_group, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - newobj, name = _validate_group(self, group) - grouper = UniqueGrouper() + rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) return DataArrayGroupBy( - newobj, - {name: grouper}, + self, + (rgrouper,), squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -6344,12 +6343,11 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DataArrayGroupBy, - _validate_group, + ResolvedBinGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - newobj, name = _validate_group(self, group) grouper = BinGrouper( bins=bins, cut_kwargs={ @@ -6359,10 +6357,11 @@ def groupby_bins( "include_lowest": include_lowest, }, ) + rgrouper = ResolvedBinGrouper(grouper, group, self) return DataArrayGroupBy( - newobj, - {name: grouper}, + self, + (rgrouper,), squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 46b6ee853d9..3cfb5a4f21f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8945,18 +8945,17 @@ def groupby( """ from xarray.core.groupby import ( DatasetGroupBy, + ResolvedUniqueGrouper, UniqueGrouper, - _validate_group, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - newobj, name = _validate_group(self, group) - grouper = UniqueGrouper() + rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) return DatasetGroupBy( - newobj, - {name: grouper}, + self, + (rgrouper,), squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) @@ -9032,12 +9031,11 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DatasetGroupBy, - _validate_group, + ResolvedBinGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - newobj, name = _validate_group(self, group) grouper = BinGrouper( bins=bins, cut_kwargs={ @@ -9047,10 +9045,11 @@ def groupby_bins( "include_lowest": include_lowest, }, ) + rgrouper = ResolvedBinGrouper(grouper, group, self) return DatasetGroupBy( self, - {name: grouper}, + (rgrouper,), squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 44ab9478ffc..8a5f219eff9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -321,13 +321,14 @@ def __init__(self, grouper: Grouper, group, obj): self.full_index: pd.Index self.grouper = grouper - self.group = group + self.group = _resolve_group(obj, group) + ( self.group1d, self.stacked_obj, self.stacked_dim, self.inserted_dims, - ) = _ensure_1d(group, obj) + ) = _ensure_1d(self.group, obj) @property def name(self) -> Hashable: @@ -511,13 +512,12 @@ def factorize(self, squeeze: bool) -> None: class Grouper(ABC): pass + class UniqueGrouper(Grouper): - _resolved_cls = ResolvedUniqueGrouper + pass class BinGrouper(Grouper): - _resolved_cls = ResolvedBinGrouper - def __init__(self, bins, cut_kwargs: Mapping | None): if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") @@ -530,8 +530,6 @@ def __init__(self, bins, cut_kwargs: Mapping | None): class TimeResampleGrouper(Grouper): - _resolved_cls = ResolvedTimeResampleGrouper - def __init__( self, freq: str, @@ -548,14 +546,6 @@ def __init__( self.offset = offset self.loffset = loffset - def _resolve_group(self, obj, group_name) -> ResolvedGrouper: - from xarray.core.resample import RESAMPLE_DIM - - group = obj[group_name].reset_coords(drop=True) - # TODO: This is an ugly in-place modification - del obj[RESAMPLE_DIM] - return ResolvedTimeResampleGrouper(self, group, obj) - def _validate_groupby_squeeze(squeeze): # While we don't generally check the type of every arg, passing @@ -568,28 +558,22 @@ def _validate_groupby_squeeze(squeeze): raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied") -def _validate_group(obj, group): +def _resolve_group(obj, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray - from xarray.core.dataset import Dataset if isinstance(group, (DataArray, IndexVariable)): - group_name = group.name or "group" - newobj = obj.copy() - if group.name in newobj.coords or ( - isinstance(newobj, Dataset) and group.name in newobj.data_vars - ): - newobj[group_name] = group - else: - try: - align(newobj, group, join="exact", copy=False) - except ValueError: - raise ValueError( - "the group variable's length does not " - "match the length of this variable along its " - "dimensions" - ) + try: + align(obj, group, join="exact", copy=False) + except ValueError: + raise ValueError( + "the group variable's length does not " + "match the length of this variable along its " + "dimensions" + ) + + newgroup = group.copy() + newgroup.name = group.name or "group" - newobj = newobj.assign_coords({group_name: group}) else: if not hashable(group): raise TypeError( @@ -597,13 +581,17 @@ def _validate_group(obj, group): "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group_name = group - newobj = obj + group = obj[group] + if group.name not in obj._indexes and group.name in obj.dims: + # DummyGroups should not appear on groupby results + newgroup = _DummyGroup(obj, group.name, group.coords) + else: + newgroup = group - if len(newobj[group_name]) == 0: - raise ValueError(f"{group_name} must not be empty") + if newgroup.size == 0: + raise ValueError(f"{newgroup.name} must not be empty") - return newobj, group_name + return newgroup class GroupBy(Generic[T_Xarray]): @@ -649,7 +637,7 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - groupers: dict[Hashable, Grouper], + groupers: tuple[ResolvedGrouper], squeeze: bool = False, restore_coord_dims: bool = True, ) -> None: @@ -665,10 +653,7 @@ def __init__( If True, also restore the dimension order of multi-dimensional coordinates. """ - self.groupers: tuple[ResolvedGrouper] = tuple( - grouper_._resolve_group(obj, group_name) - for group_name, grouper_ in groupers.items() - ) + self.groupers = groupers self._original_obj: T_Xarray = obj From 3e9479dff5fbfc789195327ded83952e50747d85 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 1 Apr 2023 22:16:52 -0600 Subject: [PATCH 21/36] review feedback --- xarray/core/groupby.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8a5f219eff9..5dd73ccdfcc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -238,7 +238,7 @@ def __getitem__(self, key): def copy(self, deep: bool = True, data: Any = None): raise NotImplementedError - def as_dataarray(self) -> DataArray: + def to_dataarray(self) -> DataArray: from xarray.core.dataarray import DataArray return DataArray( @@ -399,7 +399,7 @@ def _factorize_dummy(self, squeeze) -> None: self.group_indices = list(range(size)) codes = np.arange(size) if isinstance(self.group, _DummyGroup): - self.codes = self.group.as_dataarray().copy(data=codes) + self.codes = self.group.to_dataarray().copy(data=codes) else: self.codes = self.group.copy(data=codes) self.unique_coord = self.group @@ -784,7 +784,7 @@ def _binary_op(self, other, f, reflexive=False): dims = group.dims if isinstance(group, _DummyGroup): - group = coord = group.as_dataarray() + group = coord = group.to_dataarray() else: coord = grouper.unique_coord if not isinstance(coord, DataArray): From 81319e24df98cd5bf2da38a048283db29ccf7c4e Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 1 Apr 2023 22:18:54 -0600 Subject: [PATCH 22/36] minimize diff --- xarray/core/common.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a6d8a440706..f6abcba1ff0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -996,14 +996,10 @@ def _resample( if base is not None and offset is not None: raise ValueError("base and offset cannot be present at the same time") - index = self._indexes[dim_name].to_pandas_index() if base is not None: + index = self._indexes[dim_name].to_pandas_index() offset = _convert_base_to_offset(base, freq, index) - name = RESAMPLE_DIM - group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=name - ) grouper = TimeResampleGrouper( freq=freq, closed=closed, @@ -1012,6 +1008,11 @@ def _resample( offset=offset, loffset=loffset, ) + + group = DataArray( + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM + ) + rgrouper = ResolvedTimeResampleGrouper(grouper, group, self) return resample_cls( From f271d1bfe25c35fbb0c3ecaa3a69f8cef65a0356 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 2 Apr 2023 08:20:08 -0600 Subject: [PATCH 23/36] dataclass --- xarray/core/groupby.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5dd73ccdfcc..73ce9c3c53f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -4,6 +4,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Hashable, Iterator, Mapping, Sequence +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -513,10 +514,12 @@ class Grouper(ABC): pass +@dataclass class UniqueGrouper(Grouper): pass +@dataclass(init=False) class BinGrouper(Grouper): def __init__(self, bins, cut_kwargs: Mapping | None): if duck_array_ops.isnull(bins).all(): @@ -529,22 +532,14 @@ def __init__(self, bins, cut_kwargs: Mapping | None): self.cut_kwargs = cut_kwargs +@dataclass class TimeResampleGrouper(Grouper): - def __init__( - self, - freq: str, - closed: SideOptions | None, - label: SideOptions | None, - origin: str | DatetimeLike, - offset: pd.Timedelta | datetime.timedelta | str | None, - loffset: datetime.timedelta | str | None, - ): - self.freq = freq - self.closed = closed - self.label = label - self.origin = origin - self.offset = offset - self.loffset = loffset + freq: str + closed: SideOptions | None + label: SideOptions | None + origin: str | DatetimeLike | None + offset: pd.Timedelta | datetime.timedelta | str | None + loffset: datetime.timedelta | str | None def _validate_groupby_squeeze(squeeze): From e07ae3171ed693737df827b768bd0830f66458c6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 2 Apr 2023 12:10:20 -0600 Subject: [PATCH 24/36] moar dataclass Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/groupby.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 73ce9c3c53f..5c55ec7b600 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -4,7 +4,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Hashable, Iterator, Mapping, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, @@ -519,17 +519,14 @@ class UniqueGrouper(Grouper): pass -@dataclass(init=False) +@dataclass class BinGrouper(Grouper): - def __init__(self, bins, cut_kwargs: Mapping | None): - if duck_array_ops.isnull(bins).all(): - raise ValueError("All bin edges are NaN.") + bins: Any # TODO: What is the typing? + cut_kwargs: Mapping = field(default_factory=dict) - if cut_kwargs is None: - cut_kwargs = {} - - self.bins = bins - self.cut_kwargs = cut_kwargs + def __post_init__(self): + if duck_array_ops.isnull(self.bins).all(): + raise ValueError("All bin edges are NaN.") @dataclass From d5574186d95ccdedbfd3815404a43629ae5ea0f7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 3 Apr 2023 21:54:10 +0200 Subject: [PATCH 25/36] Add typing --- xarray/core/groupby.py | 115 +++++++++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 40 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5c55ec7b600..d121a9bcd06 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -11,6 +11,7 @@ Callable, Generic, Literal, + overload, TypeVar, Union, cast, @@ -36,7 +37,7 @@ ) from xarray.core.options import _get_keep_attrs from xarray.core.pycompat import integer_types -from xarray.core.types import Dims, QuantileMethods, T_Xarray +from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray from xarray.core.utils import ( either_dict_or_kwargs, hashable, @@ -56,9 +57,7 @@ GroupKey = Any GroupIndex = Union[int, slice, list[int]] - - T_GroupIndicesListInt = list[list[int]] - T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] + T_GroupIndices = list[GroupIndex] def check_reduce_dims(reduce_dims, dimensions): @@ -99,8 +98,8 @@ def unique_value_groups( return values, groups, inverse -def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt: - groups: T_GroupIndicesListInt = [[] for _ in range(N)] +def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices: + groups: T_GroupIndices = [[] for _ in range(N)] for n, g in enumerate(inverse): if g >= 0: groups[g].append(n) @@ -147,7 +146,7 @@ def _is_one_or_none(obj) -> bool: def _consolidate_slices(slices: list[slice]) -> list[slice]: """Consolidate adjacent slices in a list of slices.""" - result = [] + result: list[slice] = [] last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): @@ -191,7 +190,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray return newpositions[newpositions != -1] -class _DummyGroup: +class _DummyGroup(Generic[T_Xarray]): """Class for keeping track of grouped dimensions without coordinates. Should not be user visible. @@ -247,18 +246,19 @@ def to_dataarray(self) -> DataArray: ) -T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) +# T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) +T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] def _ensure_1d( group: T_Group, obj: T_Xarray -) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: +) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]: # 1D cases: do nothing - from xarray.core.dataarray import DataArray - if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] + from xarray.core.dataarray import DataArray + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims @@ -267,7 +267,7 @@ def _ensure_1d( inserted_dims = [dim for dim in group.dims if dim not in group.coords] newgroup = group.stack({stacked_dim: orig_dims}) newobj = obj.stack({stacked_dim: orig_dims}) - return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims + return newgroup, newobj, stacked_dim, inserted_dims raise TypeError( f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." @@ -311,25 +311,36 @@ def _apply_loffset( result.index = result.index + loffset -class ResolvedGrouper(ABC): - def __init__(self, grouper: Grouper, group, obj): - self.labels = None - self._group_as_index: pd.Index | None = None +@dataclass +class ResolvedGrouper(ABC, Generic[T_Xarray]): + grouper: Grouper + group: T_Group + obj: T_Xarray + + _group_as_index: pd.Index | None = field(default=None, init=False) + + # Not used here:? + labels: Any | None = field(default=None, init=False) # TODO: Typing? + codes: DataArray = field(init=False) + group_indices: T_GroupIndices = field(init=False) + unique_coord: IndexVariable | _DummyGroup = field(init=False) + full_index: pd.Index = field(init=False) - self.codes: DataArray - self.group_indices: list[int] | list[slice] | list[list[int]] - self.unique_coord: IndexVariable | _DummyGroup - self.full_index: pd.Index + # _ensure_1d: + group1d: T_Group = field(init=False) + stacked_obj: T_Xarray = field(init=False) + stacked_dim: Hashable | None = field(init=False) + inserted_dims: list[Hashable] = field(init=False) - self.grouper = grouper - self.group = _resolve_group(obj, group) + def __post_init__(self) -> None: + self.group: T_Group = _resolve_group(self.obj, self.group) ( self.group1d, self.stacked_obj, self.stacked_dim, self.inserted_dims, - ) = _ensure_1d(self.group, obj) + ) = _ensure_1d(group=self.group, obj=self.obj) @property def name(self) -> Hashable: @@ -340,7 +351,7 @@ def size(self) -> int: return len(self) def __len__(self) -> int: - return len(self.full_index) + return len(self.full_index) # TODO: full_index not def, abstractmethod? @property def dims(self): @@ -364,7 +375,10 @@ def group_as_index(self) -> pd.Index: return self._group_as_index +@dataclass class ResolvedUniqueGrouper(ResolvedGrouper): + grouper: UniqueGrouper + def factorize(self, squeeze) -> None: is_dimension = self.group.dims == (self.group.name,) if is_dimension and self.is_unique_and_monotonic: @@ -407,7 +421,10 @@ def _factorize_dummy(self, squeeze) -> None: self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) +@dataclass class ResolvedBinGrouper(ResolvedGrouper): + grouper: BinGrouper + def factorize(self, squeeze: bool) -> None: from xarray.core.dataarray import DataArray @@ -438,21 +455,26 @@ def factorize(self, squeeze: bool) -> None: self.group_indices = group_indices +@dataclass class ResolvedTimeResampleGrouper(ResolvedGrouper): - def __init__(self, grouper, group, obj): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper + grouper: TimeResampleGrouper + + def __post_init__(self) -> None: + super().__post_init__() - super().__init__(grouper, group, obj) + from xarray import CFTimeIndex - self._group_as_index = safe_cast_to_index(group) - group_as_index = self._group_as_index + group_as_index = safe_cast_to_index(self.group) + self._group_as_index = group_as_index if not group_as_index.is_monotonic_increasing: # TODO: sort instead of raising an error raise ValueError("index must be monotonic for resampling") + grouper = self.grouper if isinstance(group_as_index, CFTimeIndex): + from xarray.core.resample_cftime import CFTimeGrouper + index_grouper = CFTimeGrouper( freq=grouper.freq, closed=grouper.closed, @@ -501,9 +523,9 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: def factorize(self, squeeze: bool) -> None: self.full_index, first_items, codes = self._get_index_and_items() sbins = first_items.values.astype(np.int64) - self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ - slice(sbins[-1], None) - ] + self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + self.group_indices += [slice(sbins[-1], None)] + self.unique_coord = IndexVariable( self.group.name, first_items.index, self.group.attrs ) @@ -550,7 +572,7 @@ def _validate_groupby_squeeze(squeeze): raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied") -def _resolve_group(obj, group: T_Group | Hashable) -> T_Group: +def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray if isinstance(group, (DataArray, IndexVariable)): @@ -625,6 +647,19 @@ class GroupBy(Generic[T_Xarray]): "_codes", ) _obj: T_Xarray + groupers: tuple[ResolvedGrouper] + _squeeze: bool + _restore_coord_dims: bool + + _original_obj: T_Xarray + _original_group: T_Group + _group_indices: T_GroupIndices + _codes: DataArray + _group_dim: Hashable + + _groups: dict[GroupKey, GroupIndex] | None + _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None + _sizes: Frozen[Hashable, int] | None def __init__( self, @@ -647,7 +682,7 @@ def __init__( """ self.groupers = groupers - self._original_obj: T_Xarray = obj + self._original_obj = obj for grouper_ in self.groupers: grouper_.factorize(squeeze) @@ -656,7 +691,7 @@ def __init__( self._original_group = grouper.group # specification for the groupby operation - self._obj: T_Xarray = grouper.stacked_obj + self._obj = grouper.stacked_obj self._restore_coord_dims = restore_coord_dims self._squeeze = squeeze @@ -666,9 +701,9 @@ def __init__( (self._group_dim,) = grouper.group1d.dims # cached attributes - self._groups: dict[GroupKey, slice | int | list[int]] | None = None - self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None - self._sizes: Frozen[Hashable, int] | None = None + self._groups = None + self._dims = None + self._sizes = None @property def sizes(self) -> Frozen[Hashable, int]: From 2188a17b88eef5039df43a0316c750f007f93195 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Apr 2023 19:54:56 +0000 Subject: [PATCH 26/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d121a9bcd06..a42141c0efd 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -11,10 +11,7 @@ Callable, Generic, Literal, - overload, - TypeVar, Union, - cast, ) import numpy as np From fe0e4213b481d8a88145d265eae80833aa0305a3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 18 Apr 2023 16:22:45 -0600 Subject: [PATCH 27/36] Ignore type checking error. --- xarray/tests/test_groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ca9e0f40cc3..75783756382 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -51,8 +51,9 @@ def test_consolidate_slices() -> None: slices = [slice(2, 3), slice(5, 6)] assert _consolidate_slices(slices) == slices + # ignore type because we're checking for an error anyway with pytest.raises(ValueError): - _consolidate_slices([slice(3), 4]) + _consolidate_slices([slice(3), 4]) # type: ignore[list-item] def test_groupby_dims_property(dataset) -> None: From 0cc1ba387df672d2de981f6653012e7c5b84d201 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 26 Apr 2023 07:02:17 +0200 Subject: [PATCH 28/36] Update groupby.py --- xarray/core/groupby.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index eafeeef54af..8694d65c2d9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -243,7 +243,6 @@ def to_dataarray(self) -> DataArray: ) -# T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] @@ -316,7 +315,7 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): _group_as_index: pd.Index | None = field(default=None, init=False) - # Not used here:? + # Defined by factorize: labels: Any | None = field(default=None, init=False) # TODO: Typing? codes: DataArray = field(init=False) group_indices: T_GroupIndices = field(init=False) @@ -543,7 +542,7 @@ class BinGrouper(Grouper): bins: Any # TODO: What is the typing? cut_kwargs: Mapping = field(default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") @@ -558,7 +557,7 @@ class TimeResampleGrouper(Grouper): loffset: datetime.timedelta | str | None -def _validate_groupby_squeeze(squeeze): +def _validate_groupby_squeeze(squeeze: bool) -> None: # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the # consequences hidden enough (strings evaluate as true) to warrant From 2e10d3fdf1b4df76326f97c01317e90869de1c90 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 27 Apr 2023 22:28:13 +0200 Subject: [PATCH 29/36] Move factorize to _factorize --- xarray/core/groupby.py | 92 ++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 35 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8694d65c2d9..cd2e83c4823 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -51,10 +51,14 @@ from xarray.core.dataset import Dataset from xarray.core.types import DatetimeLike, SideOptions from xarray.core.utils import Frozen + from xarray.core.resample_cftime import CFTimeGrouper GroupKey = Any GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] + T_FactorizeOut = tuple[ + DataArray, T_GroupIndices, IndexVariable | "_DummyGroup", pd.Index + ] def check_reduce_dims(reduce_dims, dimensions): @@ -316,7 +320,6 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): _group_as_index: pd.Index | None = field(default=None, init=False) # Defined by factorize: - labels: Any | None = field(default=None, init=False) # TODO: Typing? codes: DataArray = field(init=False) group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) @@ -354,9 +357,17 @@ def dims(self): return self.group1d.dims @abstractmethod - def factorize(self, squeeze: bool) -> None: + def _factorize(self, squeeze: bool) -> T_FactorizeOut: raise NotImplementedError + def factorize(self, squeeze: bool) -> None: + ( + self.codes, + self.group_indices, + self.unique_coord, + self.full_index, + ) = self._factorize(squeeze) + @property def is_unique_and_monotonic(self) -> bool: if isinstance(self.group, _DummyGroup): @@ -375,67 +386,73 @@ def group_as_index(self) -> pd.Index: class ResolvedUniqueGrouper(ResolvedGrouper): grouper: UniqueGrouper - def factorize(self, squeeze) -> None: + def _factorize(self, squeeze) -> T_FactorizeOut: is_dimension = self.group.dims == (self.group.name,) if is_dimension and self.is_unique_and_monotonic: - self._factorize_dummy(squeeze) + return self._factorize_dummy(squeeze) else: - self._factorize_unique() + return self._factorize_unique() - def _factorize_unique(self) -> None: + def _factorize_unique(self) -> T_FactorizeOut: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) - unique_values, group_indices, codes = unique_value_groups( + unique_values, group_indices, codes_ = unique_value_groups( self.group_as_index, sort=sort ) if len(group_indices) == 0: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) - self.unique_coord = IndexVariable( + codes = self.group1d.copy(data=codes_) + group_indices = group_indices + unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs ) - self.codes = self.group1d.copy(data=codes) - self.group_indices = group_indices - self.full_index = self.unique_coord + full_index = unique_coord + + return codes, group_indices, unique_coord, full_index - def _factorize_dummy(self, squeeze) -> None: + def _factorize_dummy(self, squeeze) -> T_FactorizeOut: size = self.group.size # no need to factorize if not squeeze: # use slices to do views instead of fancy indexing # equivalent to: group_indices = group_indices.reshape(-1, 1) - self.group_indices = [slice(i, i + 1) for i in range(size)] + group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] else: - self.group_indices = list(range(size)) - codes = np.arange(size) + group_indices = list(range(size)) + size_range = np.arange(size) if isinstance(self.group, _DummyGroup): - self.codes = self.group.to_dataarray().copy(data=codes) + codes = self.group.to_dataarray().copy(data=size_range) else: - self.codes = self.group.copy(data=codes) - self.unique_coord = self.group - self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) + codes = self.group.copy(data=size_range) + unique_coord = self.group + full_index = IndexVariable(self.name, unique_coord.values, self.group.attrs) + + return codes, group_indices, unique_coord, full_index @dataclass class ResolvedBinGrouper(ResolvedGrouper): grouper: BinGrouper - def factorize(self, squeeze: bool) -> None: + def _factorize(self, squeeze: bool) -> T_FactorizeOut: from xarray.core.dataarray import DataArray data = self.group1d.values binned, bins = pd.cut( data, self.grouper.bins, **self.grouper.cut_kwargs, retbins=True ) - codes = binned.codes - if (codes == -1).all(): + binned_codes = binned.codes + if (binned_codes == -1).all(): raise ValueError(f"None of the data falls within bins with edges {bins!r}") full_index = binned.categories - uniques = np.sort(pd.unique(codes)) + uniques = np.sort(pd.unique(binned_codes)) unique_values = full_index[uniques[uniques != -1]] - group_indices = [g for g in _codes_to_groups(codes, len(full_index)) if g] + group_indices = [ + g for g in _codes_to_groups(binned_codes, len(full_index)) if g + ] if len(group_indices) == 0: raise ValueError(f"None of the data falls within bins with edges {bins!r}") @@ -444,16 +461,17 @@ def factorize(self, squeeze: bool) -> None: self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name ) - self.unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) - self.codes = self.group1d.copy(data=codes) + unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) + codes = self.group1d.copy(data=binned_codes) # TODO: support IntervalIndex in IndexVariable - self.full_index = full_index - self.group_indices = group_indices + + return codes, group_indices, unique_coord, full_index @dataclass class ResolvedTimeResampleGrouper(ResolvedGrouper): grouper: TimeResampleGrouper + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) def __post_init__(self) -> None: super().__post_init__() @@ -487,7 +505,7 @@ def __post_init__(self) -> None: origin=grouper.origin, offset=grouper.offset, ) - self.index_grouper: CFTimeGrouper | pd.Grouper = index_grouper + self.index_grouper = index_grouper def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() @@ -516,16 +534,20 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: _apply_loffset(self.grouper.loffset, first_items) return first_items, codes - def factorize(self, squeeze: bool) -> None: - self.full_index, first_items, codes = self._get_index_and_items() + def _factorize(self, squeeze: bool) -> T_FactorizeOut: + full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) - self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] - self.group_indices += [slice(sbins[-1], None)] + group_indices: T_GroupIndices = [ + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) + ] + group_indices += [slice(sbins[-1], None)] - self.unique_coord = IndexVariable( + unique_coord = IndexVariable( self.group.name, first_items.index, self.group.attrs ) - self.codes = self.group.copy(data=codes) + codes = self.group.copy(data=codes_) + + return codes, group_indices, unique_coord, full_index class Grouper(ABC): From d06bdeb2deac34183ee16094fdd3e616e62c6b97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Apr 2023 20:28:57 +0000 Subject: [PATCH 30/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cd2e83c4823..9b9c23acd67 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -49,9 +49,9 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import DatetimeLike, SideOptions from xarray.core.utils import Frozen - from xarray.core.resample_cftime import CFTimeGrouper GroupKey = Any GroupIndex = Union[int, slice, list[int]] From 867629f1de8543b295e464f6dde44c687c66da6b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 27 Apr 2023 22:39:05 +0200 Subject: [PATCH 31/36] Update groupby.py --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9b9c23acd67..121a28380e6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -57,7 +57,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, IndexVariable | "_DummyGroup", pd.Index + DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index ] From 89ab5084f2bed4435797b47c8129d3e3094daf8e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 28 Apr 2023 08:19:41 -0600 Subject: [PATCH 32/36] Update xarray/core/groupby.py --- xarray/core/groupby.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 121a28380e6..982530b3038 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -361,6 +361,9 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: raise NotImplementedError def factorize(self, squeeze: bool) -> None: + # This design makes it clear to mypy that + # codes, group_indices, unique_coord, and full_index + # are set by the factorize method on the derived class. ( self.codes, self.group_indices, From 8d7e6b83513d550cc08178e8a60c3e5a46769fe2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Apr 2023 14:22:06 +0000 Subject: [PATCH 33/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 982530b3038..9479835e81c 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -361,9 +361,9 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: raise NotImplementedError def factorize(self, squeeze: bool) -> None: - # This design makes it clear to mypy that - # codes, group_indices, unique_coord, and full_index - # are set by the factorize method on the derived class. + # This design makes it clear to mypy that + # codes, group_indices, unique_coord, and full_index + # are set by the factorize method on the derived class. ( self.codes, self.group_indices, From dde88660ffe44751309bd38644d31551c3da53c6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Mon, 24 Apr 2023 22:24:24 -0600 Subject: [PATCH 34/36] Calculate group_indices only when necessary --- xarray/core/groupby.py | 52 +++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5f987c13664..7c72bcd40ff 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -57,7 +57,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index + DataArray, T_GroupIndices | None, Union[IndexVariable, "_DummyGroup"], pd.Index ] @@ -74,7 +74,7 @@ def check_reduce_dims(reduce_dims, dimensions): def unique_value_groups( ar, sort: bool = True -) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]: +) -> tuple[np.ndarray | pd.Index, np.ndarray]: """Group an array by its unique values. Parameters @@ -95,8 +95,7 @@ def unique_value_groups( inverse, values = pd.factorize(ar, sort=sort) if isinstance(values, pd.MultiIndex): values.names = ar.names - groups = _codes_to_groups(inverse, len(values)) - return values, groups, inverse + return values, inverse def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices: @@ -318,10 +317,10 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): obj: T_Xarray _group_as_index: pd.Index | None = field(default=None, init=False) + _group_indices: T_GroupIndices | None = field(default=None, init=False) # Defined by factorize: codes: DataArray = field(init=False) - group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) full_index: pd.Index = field(init=False) @@ -366,11 +365,20 @@ def factorize(self, squeeze: bool) -> None: # are set by the factorize method on the derived class. ( self.codes, - self.group_indices, + self._group_indices, self.unique_coord, self.full_index, ) = self._factorize(squeeze) + @property + def group_indices(self) -> T_GroupIndices: + if self._group_indices is None: + self._group_indices = [ + g for g in _codes_to_groups(self.codes.data, len(self.full_index)) if g + ] + + return self._group_indices + @property def is_unique_and_monotonic(self) -> bool: if isinstance(self.group, _DummyGroup): @@ -399,21 +407,18 @@ def _factorize(self, squeeze) -> T_FactorizeOut: def _factorize_unique(self) -> T_FactorizeOut: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) - unique_values, group_indices, codes_ = unique_value_groups( - self.group_as_index, sort=sort - ) - if len(group_indices) == 0: + unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) + if (codes_ == -1).all(): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) codes = self.group1d.copy(data=codes_) - group_indices = group_indices unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs ) full_index = unique_coord - return codes, group_indices, unique_coord, full_index + return codes, None, unique_coord, full_index def _factorize_dummy(self, squeeze) -> T_FactorizeOut: size = self.group.size @@ -453,13 +458,6 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: full_index = binned.categories uniques = np.sort(pd.unique(binned_codes)) unique_values = full_index[uniques[uniques != -1]] - group_indices = [ - g for g in _codes_to_groups(binned_codes, len(full_index)) if g - ] - - if len(group_indices) == 0: - raise ValueError(f"None of the data falls within bins with edges {bins!r}") - new_dim_name = str(self.group.name) + "_bins" self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name @@ -467,8 +465,7 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) codes = self.group1d.copy(data=binned_codes) # TODO: support IntervalIndex in IndexVariable - - return codes, group_indices, unique_coord, full_index + return codes, None, unique_coord, full_index @dataclass @@ -665,7 +662,7 @@ class GroupBy(Generic[T_Xarray]): "_inserted_dims", "_group", "_group_dim", - "_group_indices", + "__group_indices", "_groups", "groupers", "_obj", @@ -688,7 +685,7 @@ class GroupBy(Generic[T_Xarray]): _original_obj: T_Xarray _original_group: T_Group - _group_indices: T_GroupIndices + __group_indices: T_GroupIndices _codes: DataArray _group_dim: Hashable @@ -731,15 +728,22 @@ def __init__( self._squeeze = squeeze # These should generalize to multiple groupers - self._group_indices = grouper.group_indices self._codes = self._maybe_unstack(grouper.codes) + self.__group_indices = None (self._group_dim,) = grouper.group1d.dims # cached attributes self._groups = None self._dims = None self._sizes = None + @property + def _group_indices(self): + if self.__group_indices is None: + (grouper,) = self.groupers + self.__group_indices = grouper.group_indices + return self.__group_indices + @property def sizes(self) -> Frozen[Hashable, int]: """Ordered mapping from dimension names to lengths. From b719976ae70feb903c4a931a721869d5f10090ff Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 2 May 2023 11:11:11 -0600 Subject: [PATCH 35/36] Revert "Calculate group_indices only when necessary" This reverts commit 917c77efb05bacffcf901e61eabb9defc9a429d7. --- xarray/core/groupby.py | 52 +++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7c72bcd40ff..5f987c13664 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -57,7 +57,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices | None, Union[IndexVariable, "_DummyGroup"], pd.Index + DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index ] @@ -74,7 +74,7 @@ def check_reduce_dims(reduce_dims, dimensions): def unique_value_groups( ar, sort: bool = True -) -> tuple[np.ndarray | pd.Index, np.ndarray]: +) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]: """Group an array by its unique values. Parameters @@ -95,7 +95,8 @@ def unique_value_groups( inverse, values = pd.factorize(ar, sort=sort) if isinstance(values, pd.MultiIndex): values.names = ar.names - return values, inverse + groups = _codes_to_groups(inverse, len(values)) + return values, groups, inverse def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices: @@ -317,10 +318,10 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): obj: T_Xarray _group_as_index: pd.Index | None = field(default=None, init=False) - _group_indices: T_GroupIndices | None = field(default=None, init=False) # Defined by factorize: codes: DataArray = field(init=False) + group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) full_index: pd.Index = field(init=False) @@ -365,20 +366,11 @@ def factorize(self, squeeze: bool) -> None: # are set by the factorize method on the derived class. ( self.codes, - self._group_indices, + self.group_indices, self.unique_coord, self.full_index, ) = self._factorize(squeeze) - @property - def group_indices(self) -> T_GroupIndices: - if self._group_indices is None: - self._group_indices = [ - g for g in _codes_to_groups(self.codes.data, len(self.full_index)) if g - ] - - return self._group_indices - @property def is_unique_and_monotonic(self) -> bool: if isinstance(self.group, _DummyGroup): @@ -407,18 +399,21 @@ def _factorize(self, squeeze) -> T_FactorizeOut: def _factorize_unique(self) -> T_FactorizeOut: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) - unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) - if (codes_ == -1).all(): + unique_values, group_indices, codes_ = unique_value_groups( + self.group_as_index, sort=sort + ) + if len(group_indices) == 0: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) codes = self.group1d.copy(data=codes_) + group_indices = group_indices unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs ) full_index = unique_coord - return codes, None, unique_coord, full_index + return codes, group_indices, unique_coord, full_index def _factorize_dummy(self, squeeze) -> T_FactorizeOut: size = self.group.size @@ -458,6 +453,13 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: full_index = binned.categories uniques = np.sort(pd.unique(binned_codes)) unique_values = full_index[uniques[uniques != -1]] + group_indices = [ + g for g in _codes_to_groups(binned_codes, len(full_index)) if g + ] + + if len(group_indices) == 0: + raise ValueError(f"None of the data falls within bins with edges {bins!r}") + new_dim_name = str(self.group.name) + "_bins" self.group1d = DataArray( binned, getattr(self.group1d, "coords", None), name=new_dim_name @@ -465,7 +467,8 @@ def _factorize(self, squeeze: bool) -> T_FactorizeOut: unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) codes = self.group1d.copy(data=binned_codes) # TODO: support IntervalIndex in IndexVariable - return codes, None, unique_coord, full_index + + return codes, group_indices, unique_coord, full_index @dataclass @@ -662,7 +665,7 @@ class GroupBy(Generic[T_Xarray]): "_inserted_dims", "_group", "_group_dim", - "__group_indices", + "_group_indices", "_groups", "groupers", "_obj", @@ -685,7 +688,7 @@ class GroupBy(Generic[T_Xarray]): _original_obj: T_Xarray _original_group: T_Group - __group_indices: T_GroupIndices + _group_indices: T_GroupIndices _codes: DataArray _group_dim: Hashable @@ -728,22 +731,15 @@ def __init__( self._squeeze = squeeze # These should generalize to multiple groupers + self._group_indices = grouper.group_indices self._codes = self._maybe_unstack(grouper.codes) - self.__group_indices = None (self._group_dim,) = grouper.group1d.dims # cached attributes self._groups = None self._dims = None self._sizes = None - @property - def _group_indices(self): - if self.__group_indices is None: - (grouper,) = self.groupers - self.__group_indices = grouper.group_indices - return self.__group_indices - @property def sizes(self) -> Frozen[Hashable, int]: """Ordered mapping from dimension names to lengths. From 265f1dd61e98cd8186bdf92e81b174efb45ff5f5 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 2 May 2023 11:40:59 -0600 Subject: [PATCH 36/36] Fix regression from deep copy --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5f987c13664..55fe103d41e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -609,7 +609,7 @@ def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: except ValueError: raise ValueError(error_msg) - newgroup = group.copy() + newgroup = group.copy(deep=False) newgroup.name = group.name or "group" elif isinstance(group, IndexVariable):