
Commit

finish
dcherian committed Dec 6, 2024
1 parent 9cf0157 commit 01d9951
Showing 2 changed files with 35 additions and 11 deletions.
18 changes: 12 additions & 6 deletions xarray/core/indexing.py
@@ -473,12 +473,6 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
        for k in key:
            if isinstance(k, slice):
                k = as_integer_slice(k)
-            # elif is_duck_dask_array(k):
-            #     raise ValueError(
-            #         "Vectorized indexing with Dask arrays is not supported. "
-            #         "Please pass a numpy array by calling ``.compute``. "
-            #         "See https://github.com/dask/dask/issues/8958."
-            #     )
            elif is_duck_array(k):
                if not np.issubdtype(k.dtype, np.integer):
                    raise TypeError(
@@ -1509,6 +1503,7 @@ def _oindex_get(self, indexer: OuterIndexer):
        return self.array[key]

    def _vindex_get(self, indexer: VectorizedIndexer):
+        _assert_not_chunked_indexer(indexer.tuple)
        array = NumpyVIndexAdapter(self.array)
        return array[indexer.tuple]

@@ -1620,6 +1615,16 @@ def _apply_vectorized_indexer_dask_wrapper(indices, coord):
    )


+def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None:
+    if any(is_chunked_array(i) for i in idxr):
+        raise ValueError(
+            "Cannot index with a chunked array indexer. "
+            "Please chunk the array you are indexing first, "
+            "and drop any indexed dimension coordinate variables. "
+            "Alternatively, call `.compute()` on any chunked arrays in the indexer."
+        )
+
+
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
    """Wrap a dask array to support explicit indexing."""

@@ -1811,6 +1816,7 @@ def _vindex_get(
        | np.datetime64
        | np.timedelta64
    ):
+        _assert_not_chunked_indexer(indexer.tuple)
        key = self._prepare_key(indexer.tuple)

        if getattr(key, "ndim", 0) > 1:  # Return np-array if multidimensional
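
Not part of the diff: a minimal sketch of the check the new helper performs, assuming dask is installed and that the private helper can be imported from xarray.core.indexing as it appears above.

# Hypothetical usage of the new private helper (illustration only, not from the commit).
import dask.array as da
import numpy as np
from xarray.core.indexing import _assert_not_chunked_indexer

# A tuple containing only numpy arrays and slices passes silently.
_assert_not_chunked_indexer((np.arange(3), slice(None)))

# Any chunked (e.g. dask) array in the indexer tuple raises the new error.
try:
    _assert_not_chunked_indexer((da.arange(3, chunks=2),))
except ValueError as err:
    print(err)  # "Cannot index with a chunked array indexer. ..."
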
28 changes: 23 additions & 5 deletions xarray/tests/test_indexing.py
@@ -974,7 +974,7 @@ def test_indexing_1d_object_array() -> None:


@requires_dask
-def test_indexing_dask_array():
+def test_indexing_dask_array() -> None:
    import dask.array

    da = DataArray(
@@ -988,7 +988,7 @@ def test_indexing_dask_array():


@requires_dask
-def test_indexing_dask_array_scalar():
+def test_indexing_dask_array_scalar() -> None:
    # GH4276
    import dask.array

@@ -1002,19 +1002,37 @@ def test_indexing_dask_array_scalar():


@requires_dask
-def test_vectorized_indexing_dask_array():
+def test_vectorized_indexing_dask_array() -> None:
    # https://github.com/pydata/xarray/issues/2511#issuecomment-563330352
    darr = DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
    indexer = DataArray(
        data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
        coords={"y": range(4), "x": range(2)},
        dims=("y", "x"),
    )
-    darr[indexer.chunk({"y": 2})]
+    expected = darr[indexer]
+
+    # fails because we can't index pd.Index lazily (yet)
+    with pytest.raises(ValueError, match="Cannot index with"):
+        with raise_if_dask_computes():
+            darr.chunk()[indexer.chunk({"y": 2})]
+
+    # fails because we can't index pd.Index lazily (yet)
+    with pytest.raises(ValueError, match="Cannot index with"):
+        with raise_if_dask_computes():
+            actual = darr[indexer.chunk({"y": 2})]
+
+    with raise_if_dask_computes():
+        actual = darr.drop_vars("z").chunk()[indexer.chunk({"y": 2})]
+    assert_identical(actual, expected.drop_vars("z"))
+
+    with raise_if_dask_computes():
+        actual = darr.variable.chunk()[indexer.variable.chunk({"y": 2})]
+    assert_identical(actual, expected.variable)


@requires_dask
-def test_advanced_indexing_dask_array():
+def test_advanced_indexing_dask_array() -> None:
    # GH4663
    import dask.array as da

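
Not part of the diff: a minimal sketch of the user-facing behaviour exercised by the test above, assuming an xarray build that includes this change and dask installed; the variable names mirror the test.

import numpy as np
import pytest
import xarray as xr

darr = xr.DataArray([0.2, 0.4, 0.6], coords={"z": range(3)}, dims="z")
indexer = xr.DataArray(
    np.random.randint(0, 3, 8).reshape(4, 2),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x"),
)

# Indexing with a chunked indexer while "z" is still a pandas-backed dimension
# coordinate now raises instead of triggering a silent compute.
with pytest.raises(ValueError, match="Cannot index with"):
    darr[indexer.chunk({"y": 2})]

# Workarounds suggested by the error message: drop the indexed coordinate and
# chunk the array being indexed (stays lazy), or compute the indexer (eager).
lazy = darr.drop_vars("z").chunk()[indexer.chunk({"y": 2})]
eager = darr[indexer.chunk({"y": 2}).compute()]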
