Skip to content

Commit

Permalink
Merge pull request #185 from chrishavlin/dask_unyt
Browse files Browse the repository at this point in the history
daskified unyt arrays
  • Loading branch information
neutrinoceros authored Sep 10, 2022
2 parents cc3a737 + 3ba5eef commit de9a5e0
Show file tree
Hide file tree
Showing 6 changed files with 1,153 additions and 9 deletions.
49 changes: 49 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1377,3 +1377,52 @@ There are three ways to use the context manager:
>>> import unyt
>>> unyt.matplotlib_support()
>>> import matplotlib.pyplot as plt

Working with Dask arrays
++++++++++++++++++++++++

:mod:`unyt` provides the ability to wrap dask arrays with :mod:`unyt`
behavior. The main access point is the :mod:`unyt.dask_array.unyt_from_dask`
function, which allows you to build a :mod:`unyt_dask_array` from a plain dask array
analogous to the creation of a :mod:`unyt_array` from a plain :mod:`numpy.ndarray`:

>>> from unyt import dask_array as uda
>>> import dask.array as da
>>> x = da.arange(10000, chunks=(1000,))
>>> x_da = uda.unyt_from_dask(x, 'm')

Methods that hang off of a :mod:`unyt_dask_array` object and operations on
:mod:`unyt_dask_array` objects will generally preserve units:

>>> x_da.sum().compute()
unyt_quantity(49995000, 'm')
>>> (x_da[:5000] * x_da[5000:]).compute()[:5]
unyt_array([ 0, 5001, 10004, 15009, 20016], 'm**2')

One important caveat is that using Dask array functions may strip units:

>>> da.sum(x_da).compute()
49995000

For simple reductions, you can use the :mod:`reduce_with_units` function:

>>> result = uda.reduce_with_units(da.sum, x_da)
>>> result.compute()
unyt_quantity(49995000, 'm')

But more complex operations may require more careful management of units. Note
that :mod:`reduce_with_units` will accept any of the positional or keyword
arguments for the array function:

>>> import numpy as np
>>> x = da.ones((10000, 3), chunks=(1000, 1000))
>>> x[:,0] = np.nan
>>> x_da = uda.unyt_from_dask(x, 'm')
>>> result = uda.reduce_with_units(da.nansum, x_da, axis=1)
>>> result.compute()[:5]
unyt_array([2., 2., 2., 2., 2.], 'm')

As a final note: the initial Dask array provided to :mod:`dask_array.unyt_from_dask` can be
constructed in any of the usual ways of constructing Dask arrays -- from :mod:`numpy`-like
array instantiation as in the above examples to reading from file or delayed operations.
For more on creating arrays, check out the `Dask documentation <https://docs.dask.org/en/stable/array-creation.html>`_.
4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ deps =
setuptools
matplotlib!=3.5.0
docutils
dask[array,diagnostics]>=2021.04.1
commands =
pytest --cov=unyt --cov-append --doctest-modules --doctest-plus --doctest-rst --basetemp={envtmpdir} -W once
coverage report --omit='.tox/*'
Expand All @@ -38,6 +39,7 @@ deps =
coverage>=5.0
pytest-cov
pytest-doctestplus
dask[array,diagnostics]>=2021.04.1
commands =
# don't do doctests on old numpy versions
pytest --cov=unyt --cov-append --basetemp={envtmpdir} -W once
Expand All @@ -52,6 +54,7 @@ deps =
coverage>=5.0
pytest-cov
pytest-doctestplus
dask[array,diagnostics]>=2021.04.1
depends = begin
commands =
# don't do doctests in rst files due to lack of way to specify optional
Expand All @@ -69,6 +72,7 @@ deps =
numpy
sympy
matplotlib!=3.5.0
dask[array,diagnostics]>=2021.04.1
commands =
make clean
python -m sphinx -M html "." "_build" -W
Expand Down
34 changes: 34 additions & 0 deletions unyt/_on_demand_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,37 @@ def use(self):


_matplotlib = matplotlib_imports()


class dask_imports(object):
_name = "dask"
_array = None
_version = None

def __init__(self):
self._available = not isinstance(self.array, NotAModule)

@property
def array(self):
if self._array is None:
try:
from dask import array
except ImportError:
array = NotAModule(self._name)
self._array = array
return self._array

@property
def __version__(self):
if self._version is None:
try:
import dask

version = dask.__version__
except ImportError:
version = NotAModule(self._name)
self._version = version
return self._version


_dask = dask_imports()
31 changes: 22 additions & 9 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@

from sympy import Rational

from unyt._on_demand_imports import _astropy, _pint
from unyt._on_demand_imports import _astropy, _dask, _pint
from unyt._pint_conversions import convert_pint_units
from unyt._unit_lookup_table import default_unit_symbol_lut
from unyt.dimensions import angle, temperature
Expand Down Expand Up @@ -280,6 +280,18 @@ def _sanitize_units_convert(possible_units, registry):
return unit


def _apply_power_mapping(ufunc, in_unit, in_size, in_shape, input_kwarg_dict):
# a reduction of a multiply or divide corresponds to
# a repeated product which we implement as an exponent
mul = 1
power_map = POWER_MAPPING[ufunc]
if input_kwarg_dict.get("axis", None) is not None:
unit = in_unit ** (power_map(in_shape[input_kwarg_dict["axis"]]))
else:
unit = in_unit ** (power_map(in_size))
return mul, unit


unary_operators = (
negative,
absolute,
Expand Down Expand Up @@ -1757,14 +1769,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# evaluate the ufunc
out_arr = func(np.asarray(inp), out=out_func, **kwargs)
if ufunc in (multiply, divide) and method == "reduce":
# a reduction of a multiply or divide corresponds to
# a repeated product which we implement as an exponent
mul = 1
power_map = POWER_MAPPING[ufunc]
if "axis" in kwargs and kwargs["axis"] is not None:
unit = u ** (power_map(inp.shape[kwargs["axis"]]))
else:
unit = u ** (power_map(inp.size))
mul, unit = _apply_power_mapping(ufunc, u, inp.size, inp.shape, kwargs)
else:
# get unit of result
mul, unit = self._ufunc_registry[ufunc](u)
Expand All @@ -1775,6 +1780,14 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# binary ufuncs
i0 = inputs[0]
i1 = inputs[1]

if _dask._available and isinstance(i1, _dask.array.core.Array):
# need to short circuit all this to handle binary operations
# like unyt_quantity(2,'m') / unyt_dask_array_instance
# only need to check the second argument as if the first arg
# is a unyt_dask_array, it won't end up here.
return i1.__array_ufunc__(ufunc, method, *inputs, **kwargs)

# coerce inputs to be ndarrays if they aren't already
inp0 = _coerce_iterable_units(i0)
inp1 = _coerce_iterable_units(i1)
Expand Down
Loading

0 comments on commit de9a5e0

Please sign in to comment.