Skip to content

Commit

Permalink
map_blocks: allow user function to add new unindexed dimension. (#3817)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian authored Mar 21, 2020
1 parent 5354679 commit b6409f0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ New Features
arguments and should instead pass a single list of dimensions.
(:pull:`3802`)
By `Maximilian Roos <https://github.com/max-sixty>`_
- :py:func:`map_blocks` can now apply functions that add new unindexed dimensions.
By `Deepak Cherian <https://github.com/dcherian>`_
- The new ``Dataset._repr_html_`` and ``DataArray._repr_html_`` (introduced
in 0.14.1) is now on by default. To disable, use
``xarray.set_options(display_style="text")``.
Expand All @@ -60,7 +62,6 @@ New Features
(:issue:`3843`, :pull:`3844`)
By `Aaron Spring <https://github.com/aaronspring>`_.


Bug fixes
~~~~~~~~~
- Fix :py:meth:`Dataset.interp` when indexing array shares coordinates with the
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ def _wrapper(func, obj, to_array, args, kwargs):
var_chunks.append(input_chunks[dim])
elif dim in indexes:
var_chunks.append((len(indexes[dim]),))
elif dim in template.dims:
# new unindexed dimension
var_chunks.append((template.sizes[dim],))

data = dask.array.Array(
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
Expand Down
2 changes: 2 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ def test_map_blocks_to_array(map_ds):
lambda x: x.to_dataset(),
lambda x: x.drop_vars("x"),
lambda x: x.expand_dims(k=[1, 2, 3]),
lambda x: x.expand_dims(k=3),
lambda x: x.assign_coords(new_coord=("y", x.y * 2)),
lambda x: x.astype(np.int32),
# TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da],
Expand All @@ -1167,6 +1168,7 @@ def test_map_blocks_da_transformations(func, map_da):
lambda x: x.drop_vars("a"),
lambda x: x.drop_vars("x"),
lambda x: x.expand_dims(k=[1, 2, 3]),
lambda x: x.expand_dims(k=3),
lambda x: x.rename({"a": "new1", "b": "new2"}),
# TODO: [lambda x: x.isel(x=1)],
],
Expand Down

0 comments on commit b6409f0

Please sign in to comment.