Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Mar 3, 2020
1 parent ab2148a commit 17b9936
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected_shapes):
if isinstance(template, DataArray):
output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore
else:
output_chunks = template.chunks
output_chunks = dict(template.chunks)

if isinstance(template, DataArray):
result_is_array = True
Expand Down
30 changes: 30 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,36 @@ def test_map_blocks_convert_args_to_list(obj):
assert_identical(actual, expected)


def test_map_blocks_dask_args():
da1 = xr.DataArray(
np.ones((10, 20)),
dims=["x", "y"],
coords={"x": np.arange(10), "y": np.arange(20)},
).chunk({"x": 5, "y": 4})

# check that block shapes are the same
def sumda(da1, da2):
assert da1.shape == da2.shape
return da1 + da2

da2 = da1 + 1
with raise_if_dask_computes():
mapped = xr.map_blocks(sumda, da1, args=[da2])
xr.testing.assert_equal(da1 + da2, mapped)

# one dimension in common
da2 = (da1 + 1).isel(x=1, drop=True)
with raise_if_dask_computes():
mapped = xr.map_blocks(operator.add, da1, args=[da2])
xr.testing.assert_equal(da1 + da2, mapped)

# test that everything works when dimension names are different
da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"})
with raise_if_dask_computes():
mapped = xr.map_blocks(operator.add, da1, args=[da2])
xr.testing.assert_equal(da1 + da2, mapped)


@pytest.mark.parametrize("obj", [make_da(), make_ds()])
def test_map_blocks_add_attrs(obj):
def add_attrs(obj):
Expand Down

0 comments on commit 17b9936

Please sign in to comment.