From a6838f859ade76265d11edc02272a190734b66cc Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 28 Mar 2020 09:26:13 -0600 Subject: [PATCH] Align all xarray objects --- xarray/core/parallel.py | 69 +++++++++++++++++++-------------------- xarray/tests/test_dask.py | 4 +++ 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2d4c01963ab..77aaf352092 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -25,6 +25,7 @@ import numpy as np +from .alignment import align from .dataarray import DataArray from .dataset import Dataset @@ -35,6 +36,13 @@ def get_index_vars(obj): return {dim: obj[dim] for dim in obj.indexes} +def to_object_array(iterable): + npargs = np.empty((len(iterable),), dtype=np.object) + for idx, item in enumerate(iterable): + npargs[idx] = item + return npargs + + def assert_chunks_compatible(a: Dataset, b: Dataset): a = a.unify_chunks() b = b.unify_chunks() @@ -358,32 +366,30 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) - if isinstance(obj, DataArray): - dataset = dataarray_to_dataset(obj) - input_is_array = True - else: - dataset = obj - input_is_array = False - - # TODO: align args and dataset here? - input_chunks = dict(dataset.chunks) - input_indexes = get_index_vars(dataset) - converted_args = [] - arg_is_array = [] - for arg in args: - arg_is_array.append(isinstance(arg, DataArray)) - if isinstance(arg, (Dataset, DataArray)): - if isinstance(arg, DataArray): - converted_args.append(dataarray_to_dataset(arg)) - assert_chunks_compatible(dataset, converted_args[-1]) - input_chunks.update(converted_args[-1].chunks) - input_indexes.update(converted_args[-1].indexes) - else: - converted_args.append(arg) + npargs = to_object_array([obj] + list(args)) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] + is_array = [isinstance(arg, DataArray) for arg in npargs] + + # align all xarray objects + # TODO: should we allow join as a kwarg or force everything to be aligned to the first object? + aligned = align(*npargs[is_xarray], join="left") + # assigning to object arrays works better when RHS is object array + # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array + npargs[is_xarray] = to_object_array(aligned) + npargs[is_array] = to_object_array( + [dataarray_to_dataset(da) for da in npargs[is_array]] + ) + + input_chunks = dict(npargs[0].chunks) + input_indexes = get_index_vars(npargs[0]) + for arg in npargs[1:][is_xarray[1:]]: + assert_chunks_compatible(npargs[0], arg) + input_chunks.update(arg.chunks) + input_indexes.update(arg.indexes) if template is None: # infer template by providing zero-shaped arrays - template = infer_template(func, obj, *args, **kwargs) + template = infer_template(func, aligned[0], *args, **kwargs) template_indexes = set(template.indexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) @@ -420,7 +426,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): graph: Dict[Any, Any] = {} new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( - dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) + dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) ) # map dims to list of chunk indexes @@ -433,9 +439,9 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): blocked_args = [ subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple) - if isinstance(arg, (DataArray, Dataset)) + if isxr else arg - for arg in (dataset,) + tuple(converted_args) + for isxr, arg in zip(is_xarray, npargs) ] # expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper @@ -451,14 +457,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): expected["coords"] = set(template.coords.keys()) # type: ignore from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = ( - _wrapper, - func, - blocked_args, - kwargs, - [input_is_array] + arg_is_array, - expected, - ) + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) # mapping from variable name to dask graph key var_key_map: Dict[Hashable, str] = {} @@ -491,7 +490,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): hlg = HighLevelGraph.from_collections( gname, graph, - dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)], + dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], ) for gname_l, layer in new_layers.items(): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e054c86797b..7f1a9d8a2e1 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1124,6 +1124,10 @@ def sumda(da1, da2): with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) + xr.testing.assert_equal(da1 + da1, mapped) + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj):