Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

groupby: remove some internal use of IndexVariable #9123

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
PandasIndex,
create_default_index_implicit,
filter_indexes_from_coords,
)
Expand Down Expand Up @@ -246,7 +248,7 @@ def to_array(self) -> DataArray:
return self.to_dataarray()


T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]
T_Group = Union["T_DataArray", _DummyGroup]


def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
Expand All @@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
list[Hashable],
]:
# 1D cases: do nothing
if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
if isinstance(group, _DummyGroup) or group.ndim == 1:
return group, obj, None, []

from xarray.core.dataarray import DataArray
Expand All @@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
newobj = obj.stack({stacked_dim: orig_dims})
return newgroup, newobj, stacked_dim, inserted_dims

raise TypeError(
f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
)
raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.")


@dataclass
Expand All @@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
codes: DataArray = field(init=False)
full_index: pd.Index = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
unique_coord: Variable | _DummyGroup = field(init=False)

# _ensure_1d:
group1d: T_Group = field(init=False)
Expand All @@ -315,7 +315,7 @@ def __post_init__(self) -> None:
# might be used multiple times.
self.grouper = copy.deepcopy(self.grouper)

self.group: T_Group = _resolve_group(self.obj, self.group)
self.group = _resolve_group(self.obj, self.group)

(
self.group1d,
Expand All @@ -328,14 +328,18 @@ def __post_init__(self) -> None:

@property
def name(self) -> Hashable:
"""Name for the grouped coordinate after reduction."""
# the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
return self.unique_coord.name
(name,) = self.unique_coord.dims
return name

@property
def size(self) -> int:
"""Number of groups."""
# Delegates to __len__, so `size` and `len(...)` always agree.
return len(self)

def __len__(self) -> int:
"""Number of groups."""
# The group count is taken from the length of `full_index`.
return len(self.full_index)

@property
Expand All @@ -358,8 +362,8 @@ def factorize(self) -> None:
]
if encoded.unique_coord is None:
unique_values = self.full_index[np.unique(encoded.codes)]
self.unique_coord = IndexVariable(
self.codes.name, unique_values, attrs=self.group.attrs
self.unique_coord = Variable(
dims=self.codes.name, data=unique_values, attrs=self.group.attrs
)
else:
self.unique_coord = encoded.unique_coord
Expand All @@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None:
)


def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group:
def _resolve_group(
obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable
) -> T_Group:
from xarray.core.dataarray import DataArray

error_msg = (
Expand Down Expand Up @@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
yield self._obj.isel({self._group_dim: indices})

def _infer_concat_args(self, applied_example):
from xarray.core.groupers import BinGrouper

(grouper,) = self.groupers
if self._group_dim in applied_example.dims:
coord = grouper.group1d
Expand All @@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example):
coord = grouper.unique_coord
positions = None
(dim,) = coord.dims
if isinstance(coord, _DummyGroup):
if isinstance(grouper.group, _DummyGroup) and not isinstance(
grouper.grouper, BinGrouper
):
# When binning we actually do set the index
coord = None
coord = getattr(coord, "variable", coord)
return coord, dim, positions
Expand All @@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False):

(grouper,) = self.groupers
obj = self._original_obj
name = grouper.name
group = grouper.group
codes = self._codes
dims = group.dims
Expand All @@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False):
group = coord = group.to_dataarray()
else:
coord = grouper.unique_coord
if not isinstance(coord, DataArray):
coord = DataArray(grouper.unique_coord)
name = grouper.name
if isinstance(coord, Variable):
assert coord.ndim == 1
(coord_dim,) = coord.dims
# TODO: explicitly create Index here
coord = DataArray(coord, coords={coord_dim: coord.data})

if not isinstance(other, (Dataset, DataArray)):
raise TypeError(
Expand Down Expand Up @@ -766,6 +780,7 @@ def _flox_reduce(

obj = self._original_obj
(grouper,) = self.groupers
name = grouper.name
isbin = isinstance(grouper.grouper, BinGrouper)

if keep_attrs is None:
Expand Down Expand Up @@ -797,14 +812,14 @@ def _flox_reduce(
# weird backcompat
# reducing along a unique indexed dimension with squeeze=True
# should raise an error
if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes:
index = obj.indexes[grouper.name]
if (dim is None or dim == name) and name in obj.xindexes:
index = obj.indexes[name]
if index.is_unique and self._squeeze:
raise ValueError(f"cannot reduce over dimensions {grouper.name!r}")
raise ValueError(f"cannot reduce over dimensions {name!r}")

unindexed_dims: tuple[Hashable, ...] = tuple()
if isinstance(grouper.group, _DummyGroup) and not isbin:
unindexed_dims = (grouper.name,)
unindexed_dims = (name,)

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
Expand Down Expand Up @@ -848,15 +863,19 @@ def _flox_reduce(
# in the grouped variable
group_dims = grouper.group.dims
if set(group_dims).issubset(set(parsed_dim)):
result[grouper.name] = output_index
result = result.assign_coords(
Coordinates(
coords={name: (name, np.array(output_index))},
indexes={name: PandasIndex(output_index, dim=name)},
)
)
result = result.drop_vars(unindexed_dims)

# broadcast and restore non-numeric data variables (backcompat)
for name, var in non_numeric.items():
if all(d not in var.dims for d in parsed_dim):
result[name] = var.variable.set_dims(
(grouper.name,) + var.dims,
(result.sizes[grouper.name],) + var.shape,
(name,) + var.dims, (result.sizes[name],) + var.shape
)

if not isinstance(result, Dataset):
Expand Down
37 changes: 26 additions & 11 deletions xarray/core/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from xarray.core.resample_cftime import CFTimeGrouper
from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import IndexVariable
from xarray.core.variable import Variable

__all__ = [
"EncodedGroups",
Expand Down Expand Up @@ -55,7 +55,17 @@ class EncodedGroups:
codes: DataArray
full_index: pd.Index
group_indices: T_GroupIndices | None = field(default=None)
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
unique_coord: Variable | _DummyGroup | None = field(default=None)

def __post_init__(self):
# Sanity-check the fields once the instance has been populated.
# NOTE(review): presumably invoked by the dataclass machinery after
# __init__ -- the class decorator is not visible in this hunk; confirm.
# NOTE: the `assert` checks below are skipped when Python runs with -O;
# only the ValueError for a missing name is unconditional.

# `codes` must be a DataArray.
assert isinstance(self.codes, DataArray)
# A name on `codes` is required; its absence is reported to the caller as a
# ValueError with an actionable message rather than as an assertion failure.
if self.codes.name is None:
raise ValueError("Please set a name on the array you are grouping by.")
# `full_index` must be a pandas Index.
assert isinstance(self.full_index, pd.Index)
# `unique_coord` is optional: a Variable, a _DummyGroup, or None.
assert (
isinstance(self.unique_coord, (Variable, _DummyGroup))
or self.unique_coord is None
)


class Grouper(ABC):
Expand Down Expand Up @@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
codes = self.group.copy(data=codes_)
unique_coord = IndexVariable(
self.group.name, unique_values, attrs=self.group.attrs
unique_coord = Variable(
dims=codes.name, data=unique_values, attrs=self.group.attrs
)
full_index = unique_coord
full_index = pd.Index(unique_values)

return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
Expand All @@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
size_range = np.arange(size)
if isinstance(self.group, _DummyGroup):
codes = self.group.to_dataarray().copy(data=size_range)
unique_coord = self.group
full_index = pd.RangeIndex(self.group.size)
else:
codes = self.group.copy(data=size_range)
unique_coord = self.group
full_index = IndexVariable(
self.group.name, unique_coord.values, self.group.attrs
)
unique_coord = self.group.variable.to_base_variable()
full_index = pd.Index(unique_coord.data)

return EncodedGroups(
codes=codes,
group_indices=group_indices,
Expand Down Expand Up @@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
codes = DataArray(
binned_codes, getattr(group, "coords", None), name=new_dim_name
)
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
unique_coord = Variable(
dims=new_dim_name, data=unique_values, attrs=group.attrs
)
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)
Expand Down Expand Up @@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
]
group_indices += [slice(sbins[-1], None)]

unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
unique_coord = Variable(
dims=group.name, data=first_items.index, attrs=group.attrs
)
codes = group.copy(data=codes_)

return EncodedGroups(
Expand Down
Loading