as_shared_dtype converts scalars to 0d numpy arrays if chunked cupy is involved #7721

Open
keewis opened this issue Apr 5, 2023 · 9 comments · May be fixed by #9212
Labels
topic-arrays related to flexible array support

Comments

@keewis
Collaborator

keewis commented Apr 5, 2023

I tried to run where with chunked cupy arrays:

In [1]: import xarray as xr
   ...: import cupy
   ...: import dask.array as da
   ...: 
   ...: arr = xr.DataArray(cupy.arange(4), dims="x")
   ...: mask = xr.DataArray(cupy.array([False, True, True, False]), dims="x")

this works:

In [2]: arr.where(mask)
Out[2]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

this fails:

In [4]: arr.chunk().where(mask).compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 arr.chunk().where(mask).compute()

File ~/repos/xarray/xarray/core/dataarray.py:1095, in DataArray.compute(self, **kwargs)
   1076 """Manually trigger loading of this array's data from disk or a
   1077 remote source into memory and return a new array. The original is
   1078 left unaltered.
   (...)
   1092 dask.compute
   1093 """
   1094 new = self.copy(deep=False)
-> 1095 return new.load(**kwargs)

File ~/repos/xarray/xarray/core/dataarray.py:1069, in DataArray.load(self, **kwargs)
   1051 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1052     """Manually trigger loading of this array's data from disk or a
   1053     remote source into memory and return this array.
   1054 
   (...)
   1067     dask.compute
   1068     """
-> 1069     ds = self._to_temp_dataset().load(**kwargs)
   1070     new = self._from_temp_dataset(ds)
   1071     self._variable = new._variable

File ~/repos/xarray/xarray/core/dataset.py:752, in Dataset.load(self, **kwargs)
    749 import dask.array as da
    751 # evaluate all the dask arrays simultaneously
--> 752 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    754 for k, data in zip(lazy_data, evaluated_data):
    755     self.variables[k].data = data

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    597     keys.append(x.__dask_keys__())
    598     postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
    601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File <__array_function__ internals>:180, in where(*args, **kwargs)

File cupy/_core/core.pyx:1723, in cupy._core.core._ndarray_base.__array_function__()

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
    209 if fusion._is_fusing():
    210     return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)

File cupy/_core/_kernel.pyx:1287, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:160, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:146, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'numpy.ndarray'>

this works again:

In [7]: arr.chunk().where(mask.chunk(), cupy.array(cupy.nan)).compute()
Out[7]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

And other methods like fillna show similar behavior.

I think the reason is that this:

if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
is not sufficient to detect cupy beneath other layers of duck arrays (most commonly dask, pint, or both). In this specific case we could extend the condition to also match chunked cupy arrays (like arr.cupy.is_cupy does, but using is_duck_dask_array), but that would still break for other duck-array layers or when dask is not involved, and we're also in the process of moving away from special-casing dask. So, short of asking cupy to treat 0-d arrays like scalars, I'm not sure how to fix this.
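
To make the failure mode concrete, here is a minimal sketch of why the isinstance check cannot see cupy once it is wrapped by other duck arrays (assuming cupy, dask, and pint are installed; the exact wrapping details are illustrative):

import cupy
import dask.array as da
import pint

data = cupy.arange(4)
nested = pint.Quantity(da.from_array(data, chunks=2), "m")  # pint(dask(cupy))

# the outer layer is a pint.Quantity, so the cupy isinstance check never fires
print(isinstance(nested, cupy.ndarray))  # False
# cupy only shows up at the very bottom of the stack
print(type(nested.magnitude._meta))      # <class 'cupy.ndarray'>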

cc @jacobtomlinson

@keewis added the needs triage (Issue that has not been reviewed by xarray team member) and topic-arrays (related to flexible array support) labels, then removed the needs triage label Apr 5, 2023
@keewis changed the title from "where on a chunked cupy array raises a TypeError" to "as_shared_dtype converts scalars to 0d numpy arrays if chunked cupy is involved" Apr 5, 2023
@jacobtomlinson
Contributor

Ping @leofang in case you have thoughts?

@leofang

leofang commented Apr 16, 2023

Sorry that I missed the ping, Jacob, but I'd need more context to make any suggestions 😅 Is the question about why CuPy wouldn't return scalars?

@keewis
Collaborator Author

keewis commented Apr 16, 2023

The issue is that here:

def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
    if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
        import cupy as cp

        arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
    else:
        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
    # Pass arrays directly instead of dtypes to result_type so scalars
    # get handled properly.
    # Note that result_type() safely gets the dtype from dask arrays without
    # evaluating them.
    out_type = dtypes.result_type(*arrays)
    return [astype(x, out_type, copy=False) for x in arrays]

we try to convert everything to the same dtype, casting numpy and Python scalars to arrays. The latter is important, because e.g. numpy.array_api.where only accepts arrays as input.

However, detecting cupy beneath (multiple) layers of duck arrays is not easy, which means that, for example, passing a pint(dask(cupy)) array together with scalars will currently cast the scalars to 0-d numpy arrays, while passing a bare cupy array instead results in 0-d cupy arrays.

My naive suggestion was to treat np.int64(0) and np.array(0, dtype="int64") the same, since at the moment the latter fails for the same reason as np.array([0], dtype="int64").
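
For illustration, a small sketch of the asymmetry described above (assuming a working cupy installation; the second call is expected to raise the same TypeError as in the traceback at the top of this issue, while the first should be accepted since numpy scalars are treated as scalars):

import numpy as np
import cupy

cond = cupy.array([True, False, True])
x = cupy.arange(3.0)

cupy.where(cond, x, np.float64(np.nan))  # numpy scalar: handled like a scalar
cupy.where(cond, x, np.array(np.nan))    # 0-d numpy array: TypeError ("Unsupported type")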

@leofang

leofang commented Apr 20, 2023

Thanks, Justus, for expanding on this. It sounds to me like the question is "how do we cast dtypes when multiple array libraries are participating in the same computation?", and I am not sure I am knowledgeable enough to comment.

From the array API point of view, long long ago we decided that this is UB (undefined behavior), meaning it's completely up to each library to decide what to do. You can raise, or come up with a special rule that makes sense to you.

It sounds like Xarray has some machinery to deal with this situation, but you'd rather not keep special-casing a certain array library? Am I understanding it right?

@keewis
Collaborator Author

keewis commented Apr 20, 2023

There are two things that happen in as_shared_dtype (which may not be good design; we should probably consider splitting it into as_shared_dtype and as_compatible_arrays or something): first we cast everything to an array, then we decide on a common dtype and cast everything to that.

The latter could easily be done using numpy scalars, which as far as I can tell are supported by most array libraries, including cupy. However, the reason we need to cast to arrays is that the array API (i.e. __array_namespace__) does not allow passing scalars of any kind to functions like np.array_api.where (this matters for libraries that don't implement __array_ufunc__ / __array_function__). To clarify, what we're trying to support is something like

import numpy.array_api as np
np.where(cond, cupy_array, python_scalar)

which (intentionally?) does not work.

At the moment, as_shared_dtype (or, really, the hypothetical as_compatible_arrays) correctly casts python_scalar to a 0-d cupy.array for the example above, but if we were to replace cupy_array with chunked_cupy_array or chunked_cupy_array_with_units, the special casing for cupy stops working and the scalars are cast to 0-d numpy.array instead. Conceptually, I tend to think of 0-d arrays as equivalent to scalars, hence the suggestion to have cupy treat numpy scalars and 0-d numpy.array the same way (I don't follow the array API closely enough to know whether that was already discussed and rejected).

So really, my question is: how do we support Python scalars for libraries that only implement __array_namespace__, given that dropping that support would be a major breaking change?

Of course, I would prefer removing the special casing for specific libraries, but I wouldn't be opposed to keeping the existing one. I guess as a short-term fix we could just pull _meta out of duck dask arrays and determine the common array type from that (a sketch follows below); the downside is that we'd add another special case for dask, which another PR is actually trying to remove.

As a long-term fix I guess we'd need to revive the stalled nested duck array discussion.
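
For reference, a rough sketch (not actual xarray code) of what that _meta-based short-term fix could look like, assuming xarray's existing is_duck_dask_array helper:

def _underlying_array_type(x):
    # dask stores an empty array of the chunk type in `_meta`,
    # so peeking at it reveals cupy hidden below the dask layer
    if is_duck_dask_array(x):
        return type(x._meta)
    return type(x)

# as_shared_dtype could then check the leaf type instead of the outer type, e.g.
# if any(_underlying_array_type(x) is cp.ndarray for x in scalars_or_arrays): ...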

@rgommers

So really, my question is: how do we support python scalars for libraries that only implement __array_namespace__, given that stopping to do so would be a major breaking change?

I was considering this question for SciPy (xref scipy#18286) this week, and I think I'm happy with this strategy:

  1. Cast all "array-like" inputs like Python scalars, lists/sequences, and generators, to numpy.ndarray.
  2. Require "same array type" input, forbid mixing numpy-cupy, numpy-pytorch, cupy-pytorch, etc. - this will raise an exception
  3. As a result, cupy-pyscalar and pytorch-pyscalar will also raise an exception.

What that results in is an API that's backwards-compatible for numpy and array-like usage, and much stricter when using other array libraries. That strictness to me is a good thing, because:

  • that's what CuPy, PyTorch & co themselves do, and it works well there
  • it avoids the complexity raised by arbitrary mixing, which results in questions like the one raised in this issue.
    • in case you do need to use a scalar from within a function inside your own library, just convert it explicitly to the desired array type with xp.asarray(a_scalar), which gives you a 0-D array of the correct type (add dtype=x.dtype to make sure dtypes match if that matters); a sketch of this strategy follows below
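
A minimal sketch of the three-step strategy above, using only isinstance checks and the standard __array_namespace__ attribute; the helper name and structure are hypothetical, not SciPy's actual implementation:

import numpy as np

def coerce_inputs(*args):
    arrays = [
        a for a in args
        if isinstance(a, np.ndarray) or hasattr(a, "__array_namespace__")
    ]
    if not arrays:
        # 1. pure array-like input (Python scalars, lists, generators) -> numpy
        return [np.asarray(a) for a in args]

    first_type = type(arrays[0])
    if any(not isinstance(a, first_type) for a in arrays[1:]):
        # 2. mixing numpy-cupy, numpy-pytorch, cupy-pytorch, ... raises
        raise TypeError("mixing arrays from different libraries is not supported")

    if isinstance(arrays[0], np.ndarray):
        # numpy keeps its permissive, backwards-compatible behaviour
        return [np.asarray(a) for a in args]

    if len(arrays) != len(args):
        # 3. cupy-pyscalar, pytorch-pyscalar, ... raise as well; callers convert
        #    explicitly with xp.asarray(scalar, dtype=arr.dtype) when needed
        raise TypeError("Python scalars are not accepted together with non-numpy arrays")
    return list(args)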

@keewis
Collaborator Author

keewis commented Dec 4, 2023

So, after thinking about this for (quite) some time, it appears that one way or another we need to figure out the appropriate base array type of the nested array (regardless of whether or not we disallow passing python scalars to the xarray API... though since it is a breaking change I don't think we will do that).

I've come up with a (recursive) way of extracting the nesting structure in keewis/nested-duck-arrays, which we should be able to use to figure out the leaf array type, keeping the current hack until we figure out how to resolve the issue without it. A condensed sketch of the idea follows below.
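
The sketch assumes dask's ._meta and pint's .magnitude attributes as the known wrapper hooks, and a nested pint(dask(cupy)) array like the one earlier in this issue; the actual nested-duck-arrays package is more general than this:

def leaf_array_type(x):
    # recursively unwrap known duck-array layers until we reach the innermost array
    if hasattr(x, "_meta"):      # dask.array.Array exposes its chunk type here
        return leaf_array_type(x._meta)
    if hasattr(x, "magnitude"):  # pint.Quantity exposes the wrapped data here
        return leaf_array_type(x.magnitude)
    return type(x)

leaf_array_type(pint_dask_cupy_array)  # -> cupy.ndarray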

@yt87

yt87 commented Jun 30, 2024

Would this be an acceptable, if temporary, fix for #9195? Modified code in as_shared_dtype:

array_type_cupy = array_type("cupy")
# temporary fix
import nested_duck_arrays.dask

def _maybe_cupy(seq):
    return any(
        isinstance(x, array_type_cupy)
        or (is_duck_dask_array(x) and x.__duck_arrays__()[-1].__module__ == "cupy")
        for x in seq
    )

# if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
if _maybe_cupy(scalars_or_arrays):
    # end of fix
    import cupy as cp

@keewis
Collaborator Author

keewis commented Jul 1, 2024

I'd go with something like

import nested_duck_arrays.dask
import nested_duck_arrays
...

if any(nested_duck_arrays.first_layer(x) is array_type_cupy for x in scalars_or_arrays):
    import cupy as cp

and add nested_duck_arrays.first_layer (with maybe a better name?), which would fall back to returning a 1-tuple containing the type of x in case x is not a duck array (I'd be happy to release that to PyPI / conda-forge relatively quickly).

We'll need to think about what to do if nested_duck_arrays is not installed, though... something like this, maybe?

try:
    from nested_duck_arrays import first_layer
except ImportError:
    def first_layer(x):
        return type(x)

Also, we'll probably want to push the contents of nested_duck_arrays.dask to dask.array.

@keewis keewis linked a pull request Jul 7, 2024 that will close this issue