Skip to content

Commit

Permalink
Automatically chunk other in GroupBy binary ops. (pydata#7684)
Browse files Browse the repository at this point in the history
* Automatically chunk `other` in GroupBy binary ops.

Closes pydata#7683

* Update xarray/core/groupby.py

* Add test

* Update xarray/core/groupby.py
  • Loading branch information
dcherian authored Jul 27, 2023
1 parent db12b0d commit 52f5cf1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
15 changes: 15 additions & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,21 @@ def _binary_op(self, other, f, reflexive=False):
group = group.where(~mask, drop=True)
codes = codes.where(~mask, drop=True).astype(int)

# if other is dask-backed, that's a hint that the
# "expanded" dataset is too big to hold in memory.
# this can be the case when `other` was read from disk
# and contains our lazy indexing classes
# We need to check for dask-backed Datasets
# so utils.is_duck_dask_array does not work for this check
if obj.chunks and not other.chunks:
# TODO: What about datasets with some dask vars, and others not?
# This handles dims other than `name``
chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
# a chunk size of 1 seems reasonable since we expect individual elements of
# other to be repeated multiple times across the reduced dimension(s)
chunks[name] = 1
other = other.chunk(chunks)

# codes are defined for coord, so we align `other` with `coord`
# before indexing
other, _ = align(other, coord, join="right", copy=False)
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray import DataArray, Dataset, Variable
from xarray.core.groupby import _consolidate_slices
from xarray.tests import (
InaccessibleArray,
assert_allclose,
assert_array_equal,
assert_equal,
Expand Down Expand Up @@ -2392,3 +2393,17 @@ def test_min_count_error(use_flox: bool) -> None:
with xr.set_options(use_flox=use_flox):
with pytest.raises(TypeError):
da.groupby("labels").mean(min_count=1)


@requires_dask
def test_groupby_math_auto_chunk():
da = xr.DataArray(
[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
dims=("y", "x"),
coords={"label": ("x", [2, 2, 1])},
)
sub = xr.DataArray(
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
)
actual = da.chunk(x=1, y=2).groupby("label") - sub
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}

0 comments on commit 52f5cf1

Please sign in to comment.