codes is always a DataArray.
dcherian committed Nov 2, 2022
1 parent 13f350e commit b64df5b
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions xarray/core/groupby.py
@@ -490,16 +490,20 @@ def __init__(
             unique_coord, group_indices, codes, full_index = _factorize_grouper(
                 group, grouper
             )
+            self._codes = group.copy(data=codes)
         elif bins is not None:
             unique_coord, group_indices, codes, full_index, group = _factorize_bins(
                 group, bins, cut_kwargs
             )
+            self._codes = group.copy(data=codes)
         elif group.dims == (group.name,) and _unique_and_monotonic(group):
             unique_coord, group_indices, codes = _factorize_dummy(group, squeeze)
             full_index = None
+            self._codes = obj[group.name].copy(data=codes)
         else:
             unique_coord, group_indices, codes = _factorize_rest(group)
             full_index = None
+            self._codes = group.copy(data=codes)

         # specification for the groupby operation
         self._obj: T_Xarray = obj
@@ -513,7 +517,7 @@ def __init__(
         self._restore_coord_dims = restore_coord_dims
         self._bins = bins
         self._squeeze = squeeze
-        self._codes = codes
+        self._codes = self._maybe_unstack(self._codes)

         # cached attributes
         self._groups: dict[GroupKey, slice | int | list[int]] | None = None
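
For context (a minimal sketch, not part of the diff; `group`, `codes_da`, and the sample values are illustrative): `group.copy(data=codes)` wraps the integer codes produced by factorization in a DataArray that keeps the group's dims and coords, which is what lets `self._codes` flow through the later steps as an ordinary DataArray.

    import pandas as pd
    import xarray as xr

    group = xr.DataArray(["a", "b", "a", "c"], dims="x", name="letters")
    codes, uniques = pd.factorize(group.values)   # codes: [0, 1, 0, 2]
    codes_da = group.copy(data=codes)             # same dims/coords as `group`, integer data
    print(codes_da.dims, codes_da.values)         # ('x',) [0 1 0 2]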
@@ -616,6 +620,7 @@ def _binary_op(self, other, f, reflexive=False):

         obj = self._original_obj
         group = self._original_group
+        codes = self._codes
         dims = group.dims

         if isinstance(group, _DummyGroup):
@@ -650,16 +655,15 @@ def _binary_op(self, other, f, reflexive=False):
                 other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
             )

-        if (self._codes == -1).any():
-            # need to handle NaNs in group or
-            # elements that don't belong to any bins
-            # for nD group, we need to work with the stacked versions
-            mask = self._group.notnull()
-            obj = self._maybe_unstack(self._obj.where(mask, drop=True))
-            group = self._maybe_unstack(self._group.dropna(self._group_dim))
+        # need to handle NaNs in group or
+        # elements that don't belong to any bins
+        mask = self._codes == -1
+        if mask.any():
+            obj = self._original_obj.where(~mask, drop=True)
+            codes = self._codes.where(~mask, drop=True).astype(int)

         other, _ = align(other, coord, join="outer")
-        expanded = other.sel({name: group})
+        expanded = other.isel({name: codes})

         result = g(obj, expanded)
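
For context (a sketch outside the diff; `per_group` and `codes_da` are illustrative names): `other.isel({name: codes})` uses vectorized positional indexing, so each element picks up its group's value by integer code rather than by label as `other.sel({name: group})` did. Codes of -1 (NaNs, or elements outside all bins) must be masked out first, since -1 would otherwise silently select the last group.

    import xarray as xr

    per_group = xr.DataArray([10.0, 20.0, 30.0], dims="letters")  # one value per group
    codes_da = xr.DataArray([0, 1, 0, 2], dims="x")               # code of each element
    expanded = per_group.isel(letters=codes_da)                   # broadcast back along "x"
    print(expanded.values)                                        # [10. 20. 10. 30.]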

@@ -778,14 +782,10 @@ def _flox_reduce(
             # as a kwarg for count, so this should be OK
             kwargs["min_count"] = 1

-        # rename to handle binning where name has "_bins" added
-        group_name = self._group.name
-        codes = group.copy(data=self._codes.reshape(group.shape)).rename(group_name)
-
         output_index = self._get_output_index()
         result = xarray_reduce(
             obj.drop_vars(non_numeric.keys()),
-            codes,
+            self._codes,
             dim=parsed_dim,
             # pass RangeIndex as a hint to flox that `by` is already factorized
             expected_groups=(pd.RangeIndex(len(output_index)),),
@@ -796,7 +796,7 @@ def _flox_reduce(

         # we did end up reducing over dimension(s) that are
         # in the grouped variable
-        if set(codes.dims).issubset(set(parsed_dim)):
+        if set(self._codes.dims).issubset(set(parsed_dim)):
             result[self._unique_coord.name] = output_index

         # Ignore error when the groupby reduction is effectively
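
For context (a hedged sketch outside the diff; it assumes flox is installed, and `obj` and `codes_da` are illustrative names): because `self._codes` is already factorized, it can be handed straight to flox as `by`, and a `pd.RangeIndex` as `expected_groups` tells flox to skip its own factorization pass.

    import pandas as pd
    import xarray as xr
    from flox.xarray import xarray_reduce

    obj = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x")
    codes_da = xr.DataArray([0, 1, 0, 2], dims="x", name="letters")
    result = xarray_reduce(
        obj,
        codes_da,
        func="mean",
        # RangeIndex hints that `by` is already factorized into 3 groups
        expected_groups=(pd.RangeIndex(3),),
    )
    print(result.values)  # expect [2. 2. 4.]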
