Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow indexing unindexed dimensions using dask arrays #5873

Merged
merged 27 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bc4271c
Attempt to fix indexing for Dask
bzah Oct 15, 2021
73696e9
Works now.
dcherian Dec 14, 2021
fad4348
avoid importorskip
dcherian Dec 14, 2021
7dadbf2
More tests and fixes
dcherian Dec 14, 2021
b7c382b
Merge branch 'main' into fix/dask_indexing
bzah Feb 15, 2022
46a4b16
Merge branch 'main' into fix/dask_indexing
dcherian Mar 18, 2022
ec4d6ee
Raise nicer error when indexing with boolean dask array
dcherian Mar 18, 2022
944dbac
Annotate tests
dcherian Mar 18, 2022
fb5b01e
Merge branch 'main' into fix/dask_indexing
bzah Mar 24, 2022
335b5da
edit query tests
dcherian Apr 12, 2022
a11be00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2022
d5e7646
Merge branch 'main' into fix/dask_indexing
dcherian Jun 24, 2022
9cde88d
Fixes #4276
dcherian Jun 24, 2022
8df0c2a
Add xfail notes.
dcherian Jun 24, 2022
9f5e31b
backcompat: vendor np.broadcast_shapes
dcherian Jun 24, 2022
3306329
Small improvement
dcherian Jun 24, 2022
32b73c3
fix: Handle scalars properly.
dcherian Jun 24, 2022
d6170ce
fix bad test
dcherian Jun 25, 2022
aa1df48
Check computes with setitem
dcherian Jun 25, 2022
c93b297
Merge branch 'main' into fix/dask_indexing
dcherian Jun 25, 2022
97fa188
Merge remote-tracking branch 'upstream/main' into fix/dask_indexing
dcherian Feb 28, 2023
3f008c8
Merge branch 'main' into fix/dask_indexing
dcherian Mar 3, 2023
ff42585
Better error
dcherian Feb 28, 2023
d15c7fe
Cleanup
dcherian Mar 6, 2023
220edc8
Raise nice error with VectorizedIndexer and dask.
dcherian Mar 6, 2023
8445120
Add whats-new
dcherian Mar 6, 2023
75a6299
Merge branch 'main' into fix/dask_indexing
dcherian Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from .missing import get_clean_interp_index
from .npcompat import QUANTILE_METHODS, ArrayLike
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array, sparse_array_type
from .pycompat import is_duck_array, is_duck_dask_array, sparse_array_type
from .types import T_Dataset
from .utils import (
Default,
Expand Down Expand Up @@ -2257,7 +2257,8 @@ def _validate_indexers(
elif isinstance(v, Sequence) and len(v) == 0:
yield k, np.empty((0,), dtype="int64")
else:
v = np.asarray(v)
if not is_duck_array(v):
v = np.asarray(v)

if v.dtype.kind in "US":
index = self._indexes[k].to_pandas_index()
Expand Down
17 changes: 9 additions & 8 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .options import OPTIONS
from .pycompat import dask_version, integer_types, is_duck_dask_array, sparse_array_type
from .types import T_Xarray
from .utils import either_dict_or_kwargs, get_valid_numpy_dtype
from .utils import either_dict_or_kwargs, get_valid_numpy_dtype, is_duck_array

if TYPE_CHECKING:
from .indexes import Index
Expand Down Expand Up @@ -361,17 +361,17 @@ def __init__(self, key):
k = int(k)
elif isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
)
if k.ndim != 1:
if k.ndim > 1:
Copy link
Contributor

@dcherian dcherian Jun 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k might be a 0d dask array (see below and #4276)

raise TypeError(
f"invalid indexer array for {type(self).__name__}; must have "
f"exactly 1 dimension: {k!r}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down Expand Up @@ -402,7 +402,7 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -415,7 +415,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down Expand Up @@ -1301,8 +1301,9 @@ def __getitem__(self, key):
rewritten_indexer = False
new_indexer = []
for idim, k in enumerate(key.tuple):
if isinstance(k, Iterable) and duck_array_ops.array_equiv(
k, np.arange(self.array.shape[idim])
if isinstance(k, Iterable) and (
not is_duck_dask_array(k)
and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
):
new_indexer.append(slice(None))
rewritten_indexer = True
Expand Down
63 changes: 63 additions & 0 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,66 @@ def sliding_window_view(
"midpoint",
"nearest",
]


if Version(np.__version__) < Version("1.20"):

def _broadcast_shape(*args):
"""Returns the shape of the arrays that would result from broadcasting the
supplied arrays against each other.
"""
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])
# unfortunately, it cannot handle 32 or more arguments directly
for pos in range(32, len(args), 31):
# ironically, np.broadcast does not properly handle np.broadcast
# objects (it treats them as scalars)
# use broadcasting to avoid allocating the full array
b = np.broadcast_to(0, b.shape)
b = np.broadcast(b, *args[pos : (pos + 31)])
return b.shape

def broadcast_shapes(*args):
    """
    Combine several shapes into the single shape they broadcast to.

    :ref:`Learn more about broadcasting here <basics.broadcasting>`.

    .. versionadded:: 1.20.0

    Parameters
    ----------
    `*args` : tuples of ints, or ints
        Shapes that should be broadcast against one another.

    Returns
    -------
    tuple
        The resulting broadcast shape.

    Raises
    ------
    ValueError
        When the given shapes are incompatible under NumPy's broadcasting
        rules.

    See Also
    --------
    broadcast
    broadcast_arrays
    broadcast_to

    Examples
    --------
    >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
    (3, 2)

    >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
    (5, 6, 7)
    """
    # Build one placeholder array per requested shape and let the array-based
    # helper work out the broadcast shape.
    # NOTE(review): dtype=[] is an empty structured dtype, presumably chosen so
    # the placeholders occupy no memory - confirm against NumPy's docs.
    return _broadcast_shape(*(np.empty(shape, dtype=[]) for shape in args))

else:
from numpy import broadcast_shapes # noqa
8 changes: 7 additions & 1 deletion xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

from .npcompat import broadcast_shapes
from .options import OPTIONS
from .pycompat import is_duck_array

try:
import bottleneck as bn
Expand Down Expand Up @@ -109,7 +111,11 @@ def _advanced_indexer_subspaces(key):
return (), ()

non_slices = [k for k in key if not isinstance(k, slice)]
ndim = len(np.broadcast(*non_slices).shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is just an optimization. Happy to add it in a different PR.

ndim = len(
broadcast_shapes(
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
)
)
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
vindex_positions = np.arange(ndim)
return mixed_positions, vindex_positions
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from packaging.version import Version

from .utils import is_duck_array
from .utils import is_duck_array, is_scalar

integer_types = (int, np.integer)

Expand Down Expand Up @@ -67,3 +67,7 @@ def is_dask_collection(x):

def is_duck_dask_array(x):
    """Report whether ``x`` is both a duck array and a dask collection."""
    duck = is_duck_array(x)
    # The dask check is only evaluated when the duck-array check is truthy,
    # matching the short-circuit behaviour of a plain ``and`` expression.
    return duck and is_dask_collection(x)


def is_0d_dask_array(x):
    """Report whether ``x`` is a dask duck array that ``is_scalar`` accepts."""
    dask_backed = is_duck_dask_array(x)
    # ``is_scalar`` is only consulted for dask-backed duck arrays.
    return dask_backed and is_scalar(x)
24 changes: 18 additions & 6 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
cupy_array_type,
dask_array_type,
integer_types,
is_0d_dask_array,
is_duck_dask_array,
sparse_array_type,
)
Expand Down Expand Up @@ -601,11 +602,12 @@ def _broadcast_indexes(self, key):
key = self._item_key_to_tuple(key) # key is a tuple
# key is a tuple of full size
key = indexing.expanded_indexer(key, self.ndim)
# Convert a scalar Variable to an integer
# Convert a scalar Variable to a 0d-array
key = tuple(
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key
Copy link
Contributor

@dcherian dcherian Jun 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.item does not work when k is a dask 0d array, so pass that along.

This is the major change here. Everything else is about avoiding calls to np.asarray unless we really have to.

k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Necessary for #4276

)
# Convert a 0d-array to an integer
# Convert a 0d numpy arrays to an integer
# dask 0d arrays are passed through
key = tuple(
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
)
Expand Down Expand Up @@ -646,7 +648,8 @@ def _validate_indexers(self, key):
for dim, k in zip(self.dims, key):
if not isinstance(k, BASIC_INDEXING_TYPES):
if not isinstance(k, Variable):
k = np.asarray(k)
if not is_duck_array(k):
k = np.asarray(k)
if k.ndim > 1:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
Expand All @@ -663,6 +666,13 @@ def _validate_indexers(self, key):
"{}-dimensional boolean indexing is "
"not supported. ".format(k.ndim)
)
if is_duck_dask_array(k.data):
raise KeyError(
"Indexing with a boolean dask array is not allowed. "
"This will result in a dask array of unknown shape. "
"Such arrays are unsupported by Xarray."
"Please compute the indexer first using .compute()"
)
if getattr(k, "dims", (dim,)) != (dim,):
raise IndexError(
"Boolean indexer should be unlabeled or on the "
Expand All @@ -673,18 +683,20 @@ def _validate_indexers(self, key):
)

def _broadcast_indexes_outer(self, key):
# drop dim if k is integer or if k is a 0d dask array
dims = tuple(
k.dims[0] if isinstance(k, Variable) else dim
for k, dim in zip(key, self.dims)
if not isinstance(k, integer_types)
if (not isinstance(k, integer_types) and not is_0d_dask_array(k))
)

new_key = []
for k in key:
if isinstance(k, Variable):
k = k.data
if not isinstance(k, BASIC_INDEXING_TYPES):
k = np.asarray(k)
if not is_duck_array(k):
k = np.asarray(k)
if k.size == 0:
# Slice by empty list; numpy could not infer the dtype
k = k.astype(int)
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def test_indexing(self):
(da.array([99, 99, 3, 99]), [0, -1, 1]),
(da.array([99, 99, 99, 4]), np.arange(3)),
(da.array([1, 99, 99, 99]), [False, True, True, True]),
(da.array([1, 99, 99, 99]), np.arange(4) > 0),
(da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0),
(da.array([1, 99, 99, 99]), np.array([False, True, True, True])),
(da.array([99, 99, 99, 99]), Variable(("x"), np.array([True] * 4))),
Copy link
Contributor

@dcherian dcherian Jun 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last test was silently computing earlier.

],
)
def test_setitem_dask_array(self, expected_data, index):
Expand Down
34 changes: 19 additions & 15 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4063,45 +4063,49 @@ def test_query(
d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype(
object
)
if backend == "numpy":
aa = DataArray(data=a, dims=["x"], name="a")
bb = DataArray(data=b, dims=["x"], name="b")
cc = DataArray(data=c, dims=["y"], name="c")
dd = DataArray(data=d, dims=["z"], name="d")
aa = DataArray(data=a, dims=["x"], name="a", coords={"a2": ("x", a)})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the tests to actually be lazy.

bb = DataArray(data=b, dims=["x"], name="b", coords={"b2": ("x", b)})
cc = DataArray(data=c, dims=["y"], name="c", coords={"c2": ("y", c)})
dd = DataArray(data=d, dims=["z"], name="d", coords={"d2": ("z", d)})

elif backend == "dask":
if backend == "dask":
import dask.array as da

aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a")
bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b")
cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c")
dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d")
aa = aa.copy(data=da.from_array(a, chunks=3))
bb = bb.copy(data=da.from_array(b, chunks=3))
cc = cc.copy(data=da.from_array(c, chunks=7))
dd = dd.copy(data=da.from_array(d, chunks=12))

# query single dim, single variable
actual = aa.query(x="a > 5", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = aa.query(x="a2 > 5", engine=engine, parser=parser)
expect = aa.isel(x=(a > 5))
assert_identical(expect, actual)

# query single dim, single variable, via dict
actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser)
with raise_if_dask_computes():
actual = aa.query(dict(x="a2 > 5"), engine=engine, parser=parser)
expect = aa.isel(dict(x=(a > 5)))
assert_identical(expect, actual)

# query single dim, single variable
actual = bb.query(x="b > 50", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = bb.query(x="b2 > 50", engine=engine, parser=parser)
expect = bb.isel(x=(b > 50))
assert_identical(expect, actual)

# query single dim, single variable
actual = cc.query(y="c < .5", engine=engine, parser=parser)
with raise_if_dask_computes():
actual = cc.query(y="c2 < .5", engine=engine, parser=parser)
expect = cc.isel(y=(c < 0.5))
assert_identical(expect, actual)

# query single dim, single string variable
if parser == "pandas":
# N.B., this query currently only works with the pandas parser
# xref https://github.com/pandas-dev/pandas/issues/40436
actual = dd.query(z='d == "bar"', engine=engine, parser=parser)
with raise_if_dask_computes():
actual = dd.query(z='d2 == "bar"', engine=engine, parser=parser)
expect = dd.isel(z=(d == "bar"))
assert_identical(expect, actual)

Expand Down
Loading