From 0ab63319099efbbd8ce4c25ccacfa28dd4d6858c Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 24 Jun 2022 15:13:12 -0600 Subject: [PATCH] Fixes #4276 Pass 0d dask arrays through for indexing. --- xarray/core/indexing.py | 2 +- xarray/core/variable.py | 10 ++++++---- xarray/tests/test_indexing.py | 3 +-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index af07a7ad587..e0cc79593bd 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -366,7 +366,7 @@ def __init__(self, key): raise TypeError( f"invalid indexer array, does not have integer dtype: {k!r}" ) - if k.ndim != 1: + if k.ndim > 1: raise TypeError( f"invalid indexer array for {type(self).__name__}; must have " f"exactly 1 dimension: {k!r}" diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2516aa40c5a..1ce3c25c949 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -601,11 +601,12 @@ def _broadcast_indexes(self, key): key = self._item_key_to_tuple(key) # key is a tuple # key is a tuple of full size key = indexing.expanded_indexer(key, self.ndim) - # Convert a scalar Variable to an integer + # Convert a scalar Variable to a 0d-array key = tuple( - k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key + k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key ) - # Convert a 0d-array to an integer + # Convert a 0d numpy arrays to an integer + # dask 0d arrays are passed through key = tuple( k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key ) @@ -681,10 +682,11 @@ def _validate_indexers(self, key): ) def _broadcast_indexes_outer(self, key): + # drop dim if k is integer or if k is a 0d dask array dims = tuple( k.dims[0] if isinstance(k, Variable) else dim for k, dim in zip(key, self.dims) - if not isinstance(k, integer_types) + if (not isinstance(k, integer_types) and k.ndim > 0) ) new_key = [] diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6cccdb1fdf8..47b7a3e45f2 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -846,7 +846,6 @@ def test_indexing_dask_array(): assert_identical(actual, expected) -@pytest.mark.xfail @requires_dask def test_indexing_dask_array_scalar(): # GH4276 @@ -856,7 +855,7 @@ def test_indexing_dask_array_scalar(): da = DataArray(a, dims="x") x_selector = da.argmax(dim=...) actual = da.isel(x_selector) - expected = da.isel(x=1) + expected = da.isel(x=-1) assert_identical(actual, expected)