diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 54c01be84d2..5d4a0661288 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -389,7 +389,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected): if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: - output_chunks = template.chunks # type: ignore + output_chunks = dict(template.chunks) if isinstance(template, DataArray): result_is_array = True diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8d3f9092cd1..9f381b2c766 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1092,6 +1092,36 @@ def test_map_blocks_convert_args_to_list(obj): assert_identical(actual, expected) +def test_map_blocks_dask_args(): + da1 = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(20)}, + ).chunk({"x": 5, "y": 4}) + + # check that block shapes are the same + def sumda(da1, da2): + assert da1.shape == da2.shape + return da1 + da2 + + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(sumda, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # one dimension in common + da2 = (da1 + 1).isel(x=1, drop=True) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # test that everything works when dimension names are different + da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"}) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): def add_attrs(obj):