Fancy indexing a Dataset with dask DataArray triggers multiple computes #4663

Closed
eric-czech opened this issue Dec 8, 2020 · 8 comments · Fixed by #5873

Comments

@eric-czech

It appears that boolean arrays (or any slicing array presumably) are evaluated many more times than necessary when applied to multiple variables in a Dataset. Is this intentional? Here is an example that demonstrates this:

import dask.array as da
import numpy as np
import xarray as xr

# Use a custom array type to know when data is being evaluated
class Array():
    
    def __init__(self, x):
        self.shape = (x.shape[0],)
        self.ndim = x.ndim
        self.dtype = 'bool'
        self.x = x
        
    def __getitem__(self, idx):
        if idx[0].stop > 0:
            print('Evaluating')
        return (self.x > .5).__getitem__(idx)

# Control case -- this shows that the print statement is only reached once
da.from_array(Array(np.random.rand(100))).compute();
# Evaluating

# This usage somehow results in two evaluations of this one array?
ds = xr.Dataset(dict(
    a=('x', da.from_array(Array(np.random.rand(100))))
))
ds.sel(x=ds.a)
# Evaluating
# Evaluating
# <xarray.Dataset>
# Dimensions:  (x: 51)
# Dimensions without coordinates: x
# Data variables:
#     a        (x) bool dask.array<chunksize=(51,), meta=np.ndarray>

# The array is evaluated an extra time for each new variable
ds = xr.Dataset(dict(
    a=('x', da.from_array(Array(np.random.rand(100)))),
    b=(('x', 'y'), da.random.random((100, 10))),
    c=(('x', 'y'), da.random.random((100, 10))),
    d=(('x', 'y'), da.random.random((100, 10))),
))
ds.sel(x=ds.a)
# Evaluating
# Evaluating
# Evaluating
# Evaluating
# Evaluating
# <xarray.Dataset>
# Dimensions:  (x: 48, y: 10)
# Dimensions without coordinates: x, y
# Data variables:
#     a        (x) bool dask.array<chunksize=(48,), meta=np.ndarray>
#     b        (x, y) float64 dask.array<chunksize=(48, 10), meta=np.ndarray>
#     c        (x, y) float64 dask.array<chunksize=(48, 10), meta=np.ndarray>
#     d        (x, y) float64 dask.array<chunksize=(48, 10), meta=np.ndarray>

Given that slicing is already not lazy, why does the same predicate array need to be computed more than once?

@tomwhite originally pointed this out in https://github.com/pystatgen/sgkit/issues/299.

@dcherian
Contributor

dcherian commented Dec 8, 2020

Thanks for the great example!

This looks like a duplicate of #2801. If you agree, can we move the conversation there?

I like using our raise_if_dask_computes context since it points out where the compute is happening

import dask.array as da
import numpy as np
import xarray as xr

from xarray.tests import raise_if_dask_computes

# Use a custom array type to know when data is being evaluated
class Array():
    
    def __init__(self, x):
        self.shape = (x.shape[0],)
        self.ndim = x.ndim
        self.dtype = 'bool'
        self.x = x
        
    def __getitem__(self, idx):
        if idx[0].stop > 0:
            print('Evaluating')
        return (self.x > .5).__getitem__(idx)
    

with raise_if_dask_computes(max_computes=1):
    ds = xr.Dataset(dict(
        a=('x', da.from_array(Array(np.random.rand(100)))),
        b=(('x', 'y'), da.random.random((100, 10))),
        c=(('x', 'y'), da.random.random((100, 10))),
        d=(('x', 'y'), da.random.random((100, 10))),
    ))
    ds.sel(x=ds.a)
--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-76-8efd3a1c3fe5> in <module>
     26         d=(('x', 'y'), da.random.random((100, 10))),
     27     ))
---> 28     ds.sel(x=ds.a)

/project/mrgoodbar/dcherian/python/xarray/xarray/core/dataset.py in sel(self, indexers, method, tolerance, drop, **indexers_kwargs)
   2211             self, indexers=indexers, method=method, tolerance=tolerance
   2212         )
-> 2213         result = self.isel(indexers=pos_indexers, drop=drop)
   2214         return result._overwrite_indexes(new_indexes)
   2215 

/project/mrgoodbar/dcherian/python/xarray/xarray/core/dataset.py in isel(self, indexers, drop, missing_dims, **indexers_kwargs)
   2058         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
   2059         if any(is_fancy_indexer(idx) for idx in indexers.values()):
-> 2060             return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)
   2061 
   2062         # Much faster algorithm for when all indexers are ints, slices, one-dimensional

/project/mrgoodbar/dcherian/python/xarray/xarray/core/dataset.py in _isel_fancy(self, indexers, drop, missing_dims)
   2122                     indexes[name] = new_index
   2123             elif var_indexers:
-> 2124                 new_var = var.isel(indexers=var_indexers)
   2125             else:
   2126                 new_var = var.copy(deep=False)

/project/mrgoodbar/dcherian/python/xarray/xarray/core/variable.py in isel(self, indexers, missing_dims, **indexers_kwargs)
   1118 
   1119         key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
-> 1120         return self[key]
   1121 
   1122     def squeeze(self, dim=None):

/project/mrgoodbar/dcherian/python/xarray/xarray/core/variable.py in __getitem__(self, key)
    766         array `x.values` directly.
    767         """
--> 768         dims, indexer, new_order = self._broadcast_indexes(key)
    769         data = as_indexable(self._data)[indexer]
    770         if new_order:

/project/mrgoodbar/dcherian/python/xarray/xarray/core/variable.py in _broadcast_indexes(self, key)
    625                 dims.append(d)
    626         if len(set(dims)) == len(dims):
--> 627             return self._broadcast_indexes_outer(key)
    628 
    629         return self._broadcast_indexes_vectorized(key)

/project/mrgoodbar/dcherian/python/xarray/xarray/core/variable.py in _broadcast_indexes_outer(self, key)
    680                 k = k.data
    681             if not isinstance(k, BASIC_INDEXING_TYPES):
--> 682                 k = np.asarray(k)
    683                 if k.size == 0:
    684                     # Slice by empty list; numpy could not infer the dtype

~/miniconda3/envs/dcpy_old_dask/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

~/miniconda3/envs/dcpy_old_dask/lib/python3.7/site-packages/dask/array/core.py in __array__(self, dtype, **kwargs)
   1374 
   1375     def __array__(self, dtype=None, **kwargs):
-> 1376         x = self.compute()
   1377         if dtype and x.dtype != dtype:
   1378             x = x.astype(dtype)

~/miniconda3/envs/dcpy_old_dask/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    165         dask.base.compute
    166         """
--> 167         (result,) = compute(self, traverse=False, **kwargs)
    168         return result
    169 

~/miniconda3/envs/dcpy_old_dask/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    450         postcomputes.append(x.__dask_postcompute__())
    451 
--> 452     results = schedule(dsk, keys, **kwargs)
    453     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    454 

/project/mrgoodbar/dcherian/python/xarray/xarray/tests/__init__.py in __call__(self, dsk, keys, **kwargs)
    112             raise RuntimeError(
    113                 "Too many computes. Total: %d > max: %d."
--> 114                 % (self.total_computes, self.max_computes)
    115             )
    116         return dask.get(dsk, keys, **kwargs)

RuntimeError: Too many computes. Total: 2 > max: 1.

It looks like we don't support indexing by dask arrays, so as we loop through the dataset's variables the indexer array is computed again for each one. As you say, this should be avoided.
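For readers who want the same check without importing from xarray.tests, here is a minimal sketch of a compute-counting context, modelled on the fragments visible at the bottom of the traceback (the `__call__(self, dsk, keys, **kwargs)` signature and the `dask.get` delegation). The class name `CountingScheduler` and the toy array are assumptions made for illustration, not a copy of xarray's test helper.

```python
import dask
import dask.array as da


class CountingScheduler:
    """Scheduler wrapper that counts compute calls and raises past a limit.

    Illustrative sketch only; the name and layout are assumptions.
    """

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many computes. Total: %d > max: %d."
                % (self.total_computes, self.max_computes)
            )
        # Delegate the actual work to dask's synchronous scheduler
        return dask.get(dsk, keys, **kwargs)


counter = CountingScheduler(max_computes=1)
x = da.ones(10, chunks=5)

with dask.config.set(scheduler=counter):
    total = x.sum().compute()  # first compute: allowed
    try:
        x.sum().compute()  # second compute: exceeds max_computes
    except RuntimeError as err:
        message = str(err)
```

Because the exception is raised from inside the scheduler call, the traceback points at whichever line triggered the extra compute, which is what makes this approach useful for locating the offending np.asarray call.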

@dcherian
Contributor

dcherian commented Dec 8, 2020

I commented too soon.

ds.sel(x=ds.a.data)

only computes once (!), so there is possibly something else going on.

@dcherian
Contributor

dcherian commented Dec 8, 2020

I think the solution is to handle this case (dask-backed DataArray) in _isel_fancy in dataset.py.

@dcherian changed the title from "Filtering data triggers multiple evaluations of slice" to "Fancy indexing a Dataset with dask DataArray triggers multiple computes" on Dec 8, 2020
@eric-czech
Author

eric-czech commented Dec 8, 2020

> I like using our raise_if_dask_computes context since it points out where the compute is happening

Oo nice, great to know about that.

> This looks like a duplicate of #2801. If you agree, can we move the conversation there?

Defining a general strategy for handling unknown chunk sizes seems like a good umbrella for it. I would certainly mention the multiple executions though, that seems somewhat orthogonal.

Have there been prior discussions about the fact that dask doesn't support consecutive slicing operations well (i.e. applying filters one after the other)? I am wondering how far off that support is in dask, compared with simply trying to support the current behavior well. In other words, forcing evaluation of indexer arrays may be the practical solution for the foreseeable future, as long as xarray doesn't do so more than once.
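As a rough illustration of "force evaluation, but only once", the predicate can be materialized by hand before indexing, so that every variable is indexed with the same in-memory array. This is a sketch of a user-side workaround under assumed toy shapes and chunk sizes, not a description of how xarray handles the case internally.

```python
import dask.array as da
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)

ds = xr.Dataset(dict(
    a=("x", da.from_array(rng.random(100) > 0.5, chunks=50)),
    b=(("x", "y"), da.random.random((100, 10), chunks=(50, 10))),
    c=(("x", "y"), da.random.random((100, 10), chunks=(50, 10))),
))

# Materialize the boolean predicate a single time, as a NumPy array
mask = ds.a.values

# Index every variable along x with the in-memory mask
subset = ds.isel(x=mask)
```

The cost is that the whole mask has to fit in memory, which is the trade-off discussed later in the thread for very long indexers.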

@dcherian
Contributor

dcherian commented Mar 18, 2021

From @alimanfoo in #5054


I have a dataset comprising several variables. All variables are dask arrays (e.g., backed by zarr). I would like to use one of these variables, which is a 1d boolean array, to index the other variables along a single large dimension. The boolean indexing array is about 40 million items long, with about 20 million true values.

If I do this all via dask (i.e., not using xarray) then I can index one dask array with another dask array via fancy indexing. The indexing array is not loaded into memory or computed. If I need to know the shape and chunks of the resulting arrays I can call compute_chunk_sizes(), but still very little memory is required.

If I do this via xarray.Dataset.isel() then a substantial amount of memory (several GB) is allocated during isel() and retained. This is problematic as in a real-world use case there are many arrays to be indexed and memory runs out on standard systems.

There is a follow-on issue: if I then run a computation over one of the indexed arrays, and the indexing was done via xarray, memory usage blows up by a further multiple GB when using a dask distributed cluster.

I think the underlying issue here is that the indexing array is loaded into memory, and then gets copied multiple times when the dask graph is constructed. If using a distributed scheduler, further copies get made during scheduling of any subsequent computation.

I made a notebook which illustrates the increased memory usage during Dataset.isel() here:
colab.research.google.com/drive/1bn7Sj0An7TehwltWizU8j_l2OvPeoJyo?usp=sharing

This is possibly the same underlying issue (and use case) as raised by @eric-czech in #4663, so feel free to close this if you think it's a duplicate.
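The pure-dask path described in the quoted report can be sketched as follows. The array sizes, chunking, and the alternating mask are made up for illustration; only the fancy-indexing step and compute_chunk_sizes() come from the report.

```python
import dask.array as da
import numpy as np

data = da.random.random((1000, 10), chunks=(100, 10))

# 1-d boolean indexer held as a dask array, chunked like data along axis 0
mask = da.from_array(np.arange(1000) % 2 == 0, chunks=100)

# Lazy fancy indexing: the mask is not computed here, so the length of
# axis 0 in the result is unknown (reported as nan)
selected = data[mask]
unknown_length = selected.shape[0]

# Resolve the unknown chunk sizes when the concrete shape is needed
selected = selected.compute_chunk_sizes()
known_shape = selected.shape
```

Until compute_chunk_sizes() is called the result carries unknown chunk sizes, which is the property that makes this path awkward for xarray, since it needs concrete dimension sizes.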

@dcherian
Contributor

I would start by trying to fix

import dask.array as da
import numpy as np
import xarray as xr
from xarray.tests import raise_if_dask_computes

with raise_if_dask_computes(max_computes=0):
    ds = xr.Dataset(
        dict(
            a=("x", da.from_array(np.random.randint(0, 100, 100))),
            b=(("x", "y"), da.random.random((100, 10))),
        )
    )
    ds.b.sel(x=ds.a.data)

specifically this np.asarray call

~/work/python/xarray/xarray/core/dataset.py in _validate_indexers(self, indexers, missing_dims)
   1960                 yield k, np.empty((0,), dtype="int64")
   1961             else:
-> 1962                 v = np.asarray(v)
   1963 
   1964                 if v.dtype.kind in "US":

Then the next issue is the multiple computes that happen when we pass a DataArray instead of a dask.array (ds.b.sel(x=ds.a)).
In general, I guess we can only support this for unindexed dimensions since we wouldn't know what index labels to use without computing the indexer array. This seems to be the case in that colab notebook.
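To make the np.asarray point concrete: calling np.asarray on a dask array goes through its __array__ method, which computes it (as the earlier traceback shows). A hypothetical guard that leaves dask indexers lazy might look like the sketch below. The helper name `normalize_indexer` is invented for illustration and is not the change that was eventually merged.

```python
import dask.array as da
import numpy as np


def normalize_indexer(v):
    """Hypothetical helper: coerce an indexer to an array without computing dask arrays."""
    if isinstance(v, da.Array):
        # np.asarray(v) would call v.__array__(), which triggers v.compute()
        return v
    return np.asarray(v)


lazy = normalize_indexer(da.from_array(np.arange(5), chunks=5))
eager = normalize_indexer([0, 1, 2])
```

Any later check that inspects the indexer's values (rather than just its dtype or ndim) would still need separate handling for the lazy case.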

@alimanfoo
Contributor

Thanks @dcherian.

Just to add that if we make progress with supporting indexing with dask arrays, then at some point I think we'll hit a separate issue: xarray will require that the chunk sizes of the indexed arrays are computed, but calling the dask array method compute_chunk_sizes() is currently inefficient for n-d arrays. Raised here: dask/dask#7416

In case anyone needs a workaround for indexing a dataset with a 1d boolean dask array, I'm currently using this hacked implementation of a compress() style function that operates on an xarray dataset, which includes more efficient computation of chunk sizes.

@dcherian
Contributor

> currently calling the dask array method compute_chunk_sizes() is inefficient for n-d arrays. Raised here: dask/dask#7416

ouch. thanks for raising that issue.

> I'm currently using this hacked implementation of a compress() style function that operates on an xarray dataset, which includes more efficient computation of chunk sizes.

I think we'd be open to adding a .compress for Dataset, DataArray, & Variable without the chunks hack :)
