Skip to content

Commit

Permalink
Allow indexing unindexed dimensions using dask arrays (#5873)
Browse files Browse the repository at this point in the history
* Attempt to fix indexing for Dask

This is a naive attempt to make `isel` work with Dask

Known limitation: it triggers the computation.

* Works now.

* avoid importorskip

* More tests and fixes

* Raise nicer error when indexing with boolean dask array

* Annotate tests

* edit query tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes #4276

Pass 0d dask arrays through for indexing.

* Add xfail notes.

* backcompat: vendor np.broadcast_shapes

* Small improvement

* fix: Handle scalars properly.

* fix bad test

* Check computes with setitem

* Better error

* Cleanup

* Raise nice error with VectorizedIndexer and dask.

* Add whats-new

---------

Co-authored-by: dcherian <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
4 people authored Mar 15, 2023
1 parent 5043223 commit 83e159e
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 60 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ New Features

- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex-valued arrays (:issue:`7340`, :pull:`7392`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Allow indexing along unindexed dimensions with dask arrays
(:issue:`2511`, :issue:`4276`, :issue:`4663`, :pull:`5873`).
By `Abel Aoun <https://github.com/bzah>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
- Support dask arrays in ``first`` and ``last`` reductions.
By `Deepak Cherian <https://github.com/dcherian>`_.

Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
)
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.pycompat import array_type, is_duck_dask_array
from xarray.core.pycompat import array_type, is_duck_array, is_duck_dask_array
from xarray.core.types import QuantileMethods, T_Dataset
from xarray.core.utils import (
Default,
Expand Down Expand Up @@ -2292,7 +2292,8 @@ def _validate_indexers(
elif isinstance(v, Sequence) and len(v) == 0:
yield k, np.empty((0,), dtype="int64")
else:
v = np.asarray(v)
if not is_duck_array(v):
v = np.asarray(v)

if v.dtype.kind in "US":
index = self._indexes[k].to_pandas_index()
Expand Down
32 changes: 22 additions & 10 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from xarray.core import duck_array_ops
from xarray.core.nputils import NumpyVIndexAdapter
from xarray.core.options import OPTIONS
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
from xarray.core.pycompat import (
array_type,
integer_types,
is_duck_array,
is_duck_dask_array,
)
from xarray.core.types import T_Xarray
from xarray.core.utils import (
NDArrayMixin,
Expand Down Expand Up @@ -368,17 +373,17 @@ def __init__(self, key):
k = int(k)
elif isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
)
if k.ndim != 1:
if k.ndim > 1:
raise TypeError(
f"invalid indexer array for {type(self).__name__}; must have "
f"exactly 1 dimension: {k!r}"
f"invalid indexer array for {type(self).__name__}; must be scalar "
f"or have 1 dimension: {k!r}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down Expand Up @@ -409,7 +414,13 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_dask_array(k):
raise ValueError(
"Vectorized indexing with Dask arrays is not supported. "
"Please pass a numpy array by calling ``.compute``. "
"See https://github.com/dask/dask/issues/8958."
)
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -422,7 +433,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down Expand Up @@ -1351,8 +1362,9 @@ def __getitem__(self, key):
rewritten_indexer = False
new_indexer = []
for idim, k in enumerate(key.tuple):
if isinstance(k, Iterable) and duck_array_ops.array_equiv(
k, np.arange(self.array.shape[idim])
if isinstance(k, Iterable) and (
not is_duck_dask_array(k)
and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
):
new_indexer.append(slice(None))
rewritten_indexer = True
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

from xarray.core.options import OPTIONS
from xarray.core.pycompat import is_duck_array

try:
import bottleneck as bn
Expand Down Expand Up @@ -121,7 +122,10 @@ def _advanced_indexer_subspaces(key):
return (), ()

non_slices = [k for k in key if not isinstance(k, slice)]
ndim = len(np.broadcast(*non_slices).shape)
broadcasted_shape = np.broadcast_shapes(
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
)
ndim = len(broadcasted_shape)
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
vindex_positions = np.arange(ndim)
return mixed_positions, vindex_positions
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from packaging.version import Version

from xarray.core.utils import is_duck_array, module_available
from xarray.core.utils import is_duck_array, is_scalar, module_available

integer_types = (int, np.integer)

Expand Down Expand Up @@ -79,3 +79,7 @@ def is_dask_collection(x):

def is_duck_dask_array(x):
    """Tell whether ``x`` passes both ``is_duck_array`` and ``is_dask_collection``.

    ``is_dask_collection`` is only consulted when ``is_duck_array(x)`` is
    truthy; otherwise the (falsy) result of ``is_duck_array`` is returned.
    """
    duck = is_duck_array(x)
    return duck and is_dask_collection(x)


def is_0d_dask_array(x):
    """Tell whether ``x`` is a duck dask array that ``is_scalar`` also accepts.

    ``is_scalar`` is only evaluated once ``is_duck_dask_array(x)`` is truthy,
    so non-dask inputs short-circuit without the scalar check.
    """
    dask_backed = is_duck_dask_array(x)
    return dask_backed and is_scalar(x)
30 changes: 23 additions & 7 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
as_indexable,
)
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
from xarray.core.pycompat import (
array_type,
integer_types,
is_0d_dask_array,
is_duck_dask_array,
)
from xarray.core.utils import (
Frozen,
NdimSizeLenMixin,
Expand Down Expand Up @@ -687,11 +692,12 @@ def _broadcast_indexes(self, key):
key = self._item_key_to_tuple(key) # key is a tuple
# key is a tuple of full size
key = indexing.expanded_indexer(key, self.ndim)
# Convert a scalar Variable to an integer
# Convert a scalar Variable to a 0d-array
key = tuple(
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key
k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key
)
# Convert a 0d-array to an integer
# Convert 0d numpy arrays to integers
# dask 0d arrays are passed through
key = tuple(
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
)
Expand Down Expand Up @@ -732,7 +738,8 @@ def _validate_indexers(self, key):
for dim, k in zip(self.dims, key):
if not isinstance(k, BASIC_INDEXING_TYPES):
if not isinstance(k, Variable):
k = np.asarray(k)
if not is_duck_array(k):
k = np.asarray(k)
if k.ndim > 1:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
Expand All @@ -749,6 +756,13 @@ def _validate_indexers(self, key):
"{}-dimensional boolean indexing is "
"not supported. ".format(k.ndim)
)
if is_duck_dask_array(k.data):
raise KeyError(
"Indexing with a boolean dask array is not allowed. "
"This will result in a dask array of unknown shape. "
"Such arrays are unsupported by Xarray."
"Please compute the indexer first using .compute()"
)
if getattr(k, "dims", (dim,)) != (dim,):
raise IndexError(
"Boolean indexer should be unlabeled or on the "
Expand All @@ -759,18 +773,20 @@ def _validate_indexers(self, key):
)

def _broadcast_indexes_outer(self, key):
# drop dim if k is integer or if k is a 0d dask array
dims = tuple(
k.dims[0] if isinstance(k, Variable) else dim
for k, dim in zip(key, self.dims)
if not isinstance(k, integer_types)
if (not isinstance(k, integer_types) and not is_0d_dask_array(k))
)

new_key = []
for k in key:
if isinstance(k, Variable):
k = k.data
if not isinstance(k, BASIC_INDEXING_TYPES):
k = np.asarray(k)
if not is_duck_array(k):
k = np.asarray(k)
if k.size == 0:
# Slice by empty list; numpy could not infer the dtype
k = k.astype(int)
Expand Down
7 changes: 4 additions & 3 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,15 @@ def test_indexing(self):
(da.array([99, 99, 3, 99]), [0, -1, 1]),
(da.array([99, 99, 99, 4]), np.arange(3)),
(da.array([1, 99, 99, 99]), [False, True, True, True]),
(da.array([1, 99, 99, 99]), np.arange(4) > 0),
(da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0),
(da.array([1, 99, 99, 99]), np.array([False, True, True, True])),
(da.array([99, 99, 99, 99]), Variable(("x"), np.array([True] * 4))),
],
)
def test_setitem_dask_array(self, expected_data, index):
arr = Variable(("x"), da.array([1, 2, 3, 4]))
expected = Variable(("x"), expected_data)
arr[index] = 99
with raise_if_dask_computes():
arr[index] = 99
assert_identical(arr, expected)

def test_squeeze(self):
Expand Down
34 changes: 19 additions & 15 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4216,45 +4216,49 @@ def test_query(
d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype(
object
)
if backend == "numpy":
aa = DataArray(data=a, dims=["x"], name="a")
bb = DataArray(data=b, dims=["x"], name="b")
cc = DataArray(data=c, dims=["y"], name="c")
dd = DataArray(data=d, dims=["z"], name="d")
aa = DataArray(data=a, dims=["x"], name="a", coords={"a2": ("x", a)})
bb = DataArray(data=b, dims=["x"], name="b", coords={"b2": ("x", b)})
cc = DataArray(data=c, dims=["y"], name="c", coords={"c2": ("y", c)})
dd = DataArray(data=d, dims=["z"], name="d", coords={"d2": ("z", d)})

elif backend == "dask":
if backend == "dask":
import dask.array as da

aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a")
bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b")
cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c")
dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d")
aa = aa.copy(data=da.from_array(a, chunks=3))
bb = bb.copy(data=da.from_array(b, chunks=3))
cc = cc.copy(data=da.from_array(c, chunks=7))
dd = dd.copy(data=da.from_array(d, chunks=12))

# query single dim, single variable
actual = aa.query(x="a > 5", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = aa.query(x="a2 > 5", engine=engine, parser=parser)
expect = aa.isel(x=(a > 5))
assert_identical(expect, actual)

# query single dim, single variable, via dict
actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser)
with raise_if_dask_computes():
actual = aa.query(dict(x="a2 > 5"), engine=engine, parser=parser)
expect = aa.isel(dict(x=(a > 5)))
assert_identical(expect, actual)

# query single dim, single variable
actual = bb.query(x="b > 50", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = bb.query(x="b2 > 50", engine=engine, parser=parser)
expect = bb.isel(x=(b > 50))
assert_identical(expect, actual)

# query single dim, single variable
actual = cc.query(y="c < .5", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = cc.query(y="c2 < .5", engine=engine, parser=parser)
expect = cc.isel(y=(c < 0.5))
assert_identical(expect, actual)

# query single dim, single string variable
if parser == "pandas":
# N.B., this query currently only works with the pandas parser
# xref https://github.com/pandas-dev/pandas/issues/40436
actual = dd.query(z='d == "bar"', engine=engine, parser=parser)
with raise_if_dask_computes():
actual = dd.query(z='d2 == "bar"', engine=engine, parser=parser)
expect = dd.isel(z=(d == "bar"))
assert_identical(expect, actual)

Expand Down
Loading

0 comments on commit 83e159e

Please sign in to comment.