map_blocks: Allow passing dask-backed objects in args #3818

Merged: 23 commits merged into pydata:master from the map-blocks-dask-args branch on Jun 7, 2020

Conversation

dcherian (Contributor) commented Mar 2, 2020

  • Tests added
  • Passes isort -rc . && black . && mypy . && flake8
  • Fully documented, including whats-new.rst for all changes and api.rst for new API

This PR parses args and breaks any xarray objects it finds into the appropriate blocks before passing them on to the user function.

e.g.

import numpy as np
import xarray as xr

da1 = xr.DataArray(
    np.ones((10, 20)), dims=["x", "y"], coords={"x": np.arange(10), "y": np.arange(20)}
).chunk({"x": 5, "y": 4})


def sumda(da1, da2):
    # print(da1.shape)
    # print(da2.shape)
    return da1 - da2


da3 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"})
mapped = xr.map_blocks(sumda, da1, args=[da3])
xr.testing.assert_equal(da1 - da3, mapped)  # passes

dcherian force-pushed the map-blocks-dask-args branch 2 times, most recently from 17b9936 to b962053 on March 9, 2020 09:35
dcherian mentioned this pull request on Mar 19, 2020
jhamman (Member) commented Mar 21, 2020

I've started testing this out and have run into one problem. Here's a simple example that uses template and dask-backed DataArrays in args:

import xarray as xr

ds = xr.tutorial.load_dataset('air_temperature')


def func(X, y):
    '''A simple reduction (assume the output can't be inferred automatically).'''
    return X.sum('time') + y.min('time')


ds = ds.chunk({'lat': 10, 'lon': 10})
X = ds['air']
y = ds['air'] ** 2
template = X.sum('time')

expected = func(X, y)
actual = xr.map_blocks(func, X, args=[y], template=template)
xr.testing.assert_identical(actual, expected)

This raises:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-9-5491345c8970> in <module>
      5 
      6 expected = func(X, y)
----> 7 actual = xr.map_blocks(func, X, args=[y], template=template)
      8 xr.testing.assert_identical(actual, expected)

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in map_blocks(func, obj, args, kwargs, template)
    424         # even if length of dimension is changed by the applied function
    425         expected["shapes"] = {
--> 426             k: output_chunks[k][v] for k, v in input_chunk_index.items()
    427         }
    428         expected["data_vars"] = set(template.data_vars.keys())  # type: ignore

/srv/conda/envs/notebook/lib/python3.7/site-packages/xarray/core/parallel.py in <dictcomp>(.0)
    424         # even if length of dimension is changed by the applied function
    425         expected["shapes"] = {
--> 426             k: output_chunks[k][v] for k, v in input_chunk_index.items()
    427         }
    428         expected["data_vars"] = set(template.data_vars.keys())  # type: ignore

KeyError: 'time'

dcherian force-pushed the map-blocks-dask-args branch from b962053 to e99033e on March 21, 2020 22:02
pep8speaks commented Mar 21, 2020

Hello @dcherian! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-05-29 17:44:28 UTC

dcherian (Contributor, Author) commented:

I fixed that on the template branch. After a rebase, this example works with assert_equal.
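For reference, a minimal sketch of how the example above can be re-run after the rebase, assuming a build that includes this PR and that the tutorial dataset is available; only the final check changes (assert_equal ignores names and attributes, which is presumably why it is the one that passes here):

import xarray as xr

ds = xr.tutorial.load_dataset('air_temperature').chunk({'lat': 10, 'lon': 10})


def func(X, y):
    return X.sum('time') + y.min('time')


X = ds['air']
y = ds['air'] ** 2
template = X.sum('time')

actual = xr.map_blocks(func, X, args=[y], template=template)
xr.testing.assert_equal(actual, func(X, y))  # passes after the rebase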

dcherian force-pushed the map-blocks-dask-args branch from e99033e to a6838f8 on March 28, 2020 19:07
xarray/core/parallel.py (review thread: outdated, resolved)
aligned = align(*npargs[is_xarray], join="left")
# assigning to object arrays works better when RHS is object array
# https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
npargs[is_xarray] = to_object_array(aligned)
dcherian (Contributor, Author) commented Mar 28, 2020:
Is there a better way to do this assignment?

np.array(args) ends up computing things.
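For context, here is a rough sketch of the pattern being discussed, under the assumption that it mirrors the to_object_array helper referenced in the diff above: np.array(...) on a list containing dask-backed xarray objects coerces (and may compute) them, whereas pre-allocating an empty object-dtype array and filling it element by element keeps the objects intact.

import numpy as np


def to_object_array(iterable):
    """Rough sketch of the helper referenced above (not necessarily the PR's exact code)."""
    # Pre-allocate an object-dtype array and assign element by element, so the
    # xarray/dask objects are stored as-is instead of being coerced by np.array(...).
    out = np.empty((len(iterable),), dtype=object)
    for index, item in enumerate(iterable):
        out[index] = item
    return out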

jhamman (Member) commented May 6, 2020

@dcherian - can you resolve conflicts here?

dcherian force-pushed the map-blocks-dask-args branch from a11bf48 to 04ffa6c on May 6, 2020 19:36
xarray/core/parallel.py (review threads: outdated, resolved)
else:
    dataset = obj
    input_is_array = False

npargs = to_object_array([obj] + list(args))
dcherian (Contributor, Author) commented:

Converting to an object array so that we can use boolean indexing to pull out the xarray objects.
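A small illustration of that idea (hypothetical example, reusing the to_object_array sketch above): a boolean mask over the object array pulls out only the xarray objects for alignment, and the aligned results can be assigned back without disturbing the non-xarray arguments.

import numpy as np
import xarray as xr

obj = xr.DataArray(np.arange(4), dims="x", coords={"x": np.arange(4)})
args = [obj + 1, "not an xarray object", 42]

npargs = to_object_array([obj] + list(args))
is_xarray = np.array([isinstance(a, (xr.DataArray, xr.Dataset)) for a in npargs])

# Boolean indexing pulls out only the xarray objects ...
aligned = xr.align(*npargs[is_xarray], join="left")
# ... and assigning an object array back keeps the other entries in place.
npargs[is_xarray] = to_object_array(aligned)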

dcherian added 4 commits May 9, 2020 07:01
indexes should contain only the indexes of the output variables. When template was
provided, I was initializing indexes to contain all input indexes;
it should contain only the indexes from template. Otherwise, indexes for
any indexed dimensions removed by func would still be propagated.
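A short sketch of the behaviour that commit is after (illustrative only, assuming a build that includes this fix): when func drops an indexed dimension, the result should carry only the template's indexes rather than inheriting the dropped input index.

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.ones((4, 6)),
    dims=["time", "x"],
    coords={"time": np.arange(4), "x": np.arange(6)},
).chunk({"x": 3})

template = da.sum("time")  # the "time" index is gone from the template
mapped = xr.map_blocks(lambda block: block.sum("time"), da, template=template)
assert "time" not in mapped.indexes  # the input's "time" index is not propagated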
@dcherian dcherian changed the title [WIP] map_blocks: Allow passing dask-backed objects in args map_blocks: Allow passing dask-backed objects in args May 25, 2020
dcherian (Contributor, Author) commented:

This is ready for review.

I've minimized the diff. Once this is merged, I'll do some refactoring.

xarray/core/dataarray.py (review thread: outdated, resolved)
jhamman (Member) left a review comment:

Thanks @dcherian, this is quite close. Would love to get either @shoyer or @TomAugspurger to look this over but everything seems good to me. Just a series of small comments.

xarray/core/parallel.py (review threads: outdated, resolved)
jhamman (Member) commented Jun 2, 2020

Things seem to have died down here. I suggest we merge this as is. As a reminder, the map_blocks function is still marked as an experimental feature so I'm not too concerned about breaking things in the wild. Better to get some early feedback and iterate on the design.

shoyer (Member) commented Jun 4, 2020 via email

jhamman merged commit 2a288f6 into pydata:master on Jun 7, 2020
dcherian deleted the map-blocks-dask-args branch on June 11, 2020 18:22
dcherian added a commit to TomNicholas/xarray that referenced this pull request Jun 24, 2020
…o-combine

* 'master' of github.com:pydata/xarray: (81 commits)
  use builtin python types instead of the numpy alias (pydata#4170)
  Revise pull request template (pydata#4039)
  pint support for Dataset (pydata#3975)
  drop eccodes in docs (pydata#4162)
  Update issue templates inspired/based on dask (pydata#4154)
  Fix failing upstream-dev build & remove docs build (pydata#4160)
  Improve typehints of xr.Dataset.__getitem__ (pydata#4144)
  provide a error summary for assert_allclose (pydata#3847)
  built-in accessor documentation (pydata#3988)
  Recommend installing cftime when time decoding fails. (pydata#4134)
  parameter documentation for DataArray.sel (pydata#4150)
  speed up map_blocks (pydata#4149)
  Remove outdated note from datetime accessor docstring (pydata#4148)
  Fix the upstream-dev pandas build failure (pydata#4138)
  map_blocks: Allow passing dask-backed objects in args (pydata#3818)
  keep attrs in reset_index (pydata#4103)
  Fix open_rasterio() for WarpedVRT with specified src_crs (pydata#4104)
  Allow non-unique and non-monotonic coordinates in get_clean_interp_index and polyfit (pydata#4099)
  update numpy's intersphinx url (pydata#4117)
  xr.infer_freq (pydata#4033)
  ...