Align all xarray objects

dcherian committed Mar 28, 2020
1 parent 162133e commit a6838f8
Showing 2 changed files with 38 additions and 35 deletions.
69 changes: 34 additions & 35 deletions xarray/core/parallel.py
@@ -25,6 +25,7 @@
 
 import numpy as np
 
+from .alignment import align
 from .dataarray import DataArray
 from .dataset import Dataset
 
@@ -35,6 +36,13 @@ def get_index_vars(obj):
     return {dim: obj[dim] for dim in obj.indexes}
 
 
+def to_object_array(iterable):
+    npargs = np.empty((len(iterable),), dtype=np.object)
+    for idx, item in enumerate(iterable):
+        npargs[idx] = item
+    return npargs
+
+
 def assert_chunks_compatible(a: Dataset, b: Dataset):
     a = a.unify_chunks()
     b = b.unify_chunks()
@@ -358,32 +366,30 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
     if not dask.is_dask_collection(obj):
         return func(obj, *args, **kwargs)
 
-    if isinstance(obj, DataArray):
-        dataset = dataarray_to_dataset(obj)
-        input_is_array = True
-    else:
-        dataset = obj
-        input_is_array = False
-
-    # TODO: align args and dataset here?
-    input_chunks = dict(dataset.chunks)
-    input_indexes = get_index_vars(dataset)
-    converted_args = []
-    arg_is_array = []
-    for arg in args:
-        arg_is_array.append(isinstance(arg, DataArray))
-        if isinstance(arg, (Dataset, DataArray)):
-            if isinstance(arg, DataArray):
-                converted_args.append(dataarray_to_dataset(arg))
-            assert_chunks_compatible(dataset, converted_args[-1])
-            input_chunks.update(converted_args[-1].chunks)
-            input_indexes.update(converted_args[-1].indexes)
-        else:
-            converted_args.append(arg)
+    npargs = to_object_array([obj] + list(args))
+    is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
+    is_array = [isinstance(arg, DataArray) for arg in npargs]
+
+    # align all xarray objects
+    # TODO: should we allow join as a kwarg or force everything to be aligned to the first object?
+    aligned = align(*npargs[is_xarray], join="left")
+    # assigning to object arrays works better when RHS is object array
+    # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
+    npargs[is_xarray] = to_object_array(aligned)
+    npargs[is_array] = to_object_array(
+        [dataarray_to_dataset(da) for da in npargs[is_array]]
+    )
+
+    input_chunks = dict(npargs[0].chunks)
+    input_indexes = get_index_vars(npargs[0])
+    for arg in npargs[1:][is_xarray[1:]]:
+        assert_chunks_compatible(npargs[0], arg)
+        input_chunks.update(arg.chunks)
+        input_indexes.update(arg.indexes)
 
     if template is None:
         # infer template by providing zero-shaped arrays
-        template = infer_template(func, obj, *args, **kwargs)
+        template = infer_template(func, aligned[0], *args, **kwargs)
         template_indexes = set(template.indexes)
         preserved_indexes = template_indexes & set(input_indexes)
         new_indexes = template_indexes - set(input_indexes)
@@ -420,7 +426,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
     graph: Dict[Any, Any] = {}
     new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
     gname = "{}-{}".format(
-        dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
+        dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs)
     )
 
     # map dims to list of chunk indexes
@@ -433,9 +439,9 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
 
         blocked_args = [
            subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple)
-            if isinstance(arg, (DataArray, Dataset))
+            if isxr
            else arg
-            for arg in (dataset,) + tuple(converted_args)
+            for isxr, arg in zip(is_xarray, npargs)
        ]
 
        # expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper
@@ -451,14 +457,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
        expected["coords"] = set(template.coords.keys())  # type: ignore
 
        from_wrapper = (gname,) + chunk_tuple
-        graph[from_wrapper] = (
-            _wrapper,
-            func,
-            blocked_args,
-            kwargs,
-            [input_is_array] + arg_is_array,
-            expected,
-        )
+        graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
 
        # mapping from variable name to dask graph key
        var_key_map: Dict[Hashable, str] = {}
@@ -491,7 +490,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
    hlg = HighLevelGraph.from_collections(
        gname,
        graph,
-        dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)],
+        dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
    )
 
    for gname_l, layer in new_layers.items():
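The change above replaces the per-argument conversion loop with a NumPy object array holding all inputs, aligns every xarray object to the first one with a left join, and writes the aligned results back through boolean-mask assignment. A minimal standalone sketch of that pattern, using made-up toy data and `dtype=object` in place of the `np.object` alias, could look like this:

import numpy as np
import xarray as xr


def to_object_array(iterable):
    # A 1-D object array keeps each Dataset/DataArray as a single element,
    # instead of letting NumPy try to broadcast their contents.
    out = np.empty((len(iterable),), dtype=object)
    for idx, item in enumerate(iterable):
        out[idx] = item
    return out


# Toy inputs: the second object carries extra labels along "x".
da = xr.DataArray(np.arange(10), dims="x", coords={"x": np.arange(10)})
extra = da.reindex(x=np.arange(20))

npargs = to_object_array([da, extra, "not-xarray"])
is_xarray = [isinstance(arg, (xr.Dataset, xr.DataArray)) for arg in npargs]

# join="left" aligns everything to the indexes of the first object,
# so `extra` is trimmed back to x = 0..9.
aligned = xr.align(*npargs[is_xarray], join="left")

# Assigning through a boolean mask is reliable only when the right-hand
# side is itself an object array; otherwise NumPy may try to unpack the
# xarray objects element-wise (see the StackOverflow link in the diff).
npargs[is_xarray] = to_object_array(aligned)

print(npargs[1].sizes)  # x is back to length 10

The non-xarray element ("not-xarray") is left untouched, which mirrors how plain arguments pass through map_blocks unchanged.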
4 changes: 4 additions & 0 deletions xarray/tests/test_dask.py
@@ -1124,6 +1124,10 @@ def sumda(da1, da2):
    with raises_regex(ValueError, "Chunk sizes along dimension 'x'"):
        xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})])
 
+    with raise_if_dask_computes():
+        mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))])
+    xr.testing.assert_equal(da1 + da1, mapped)
+
 
 @pytest.mark.parametrize("obj", [make_da(), make_ds()])
 def test_map_blocks_add_attrs(obj):
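The added test relies on the left-join semantics introduced above: the reindexed argument is aligned back to da1's coordinates before any blocks are built, so the lazy result equals da1 + da1 and no dask computation is triggered while the graph is constructed. A rough standalone illustration of that behaviour, with a hypothetical small array outside the test suite, might be:

import operator

import numpy as np
import xarray as xr

da1 = xr.DataArray(
    np.arange(10, dtype=float),
    dims="x",
    coords={"x": np.arange(10)},
).chunk({"x": 4})

# The second argument carries ten extra "x" labels (filled with NaN).
longer = da1.reindex(x=np.arange(20))

# map_blocks left-aligns args to the first object, so the extra labels
# are dropped and the computation matches plain da1 + da1.
mapped = xr.map_blocks(operator.add, da1, args=[longer])
xr.testing.assert_equal(da1 + da1, mapped)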
