Skip to content

Commit

Permalink
Fix map_blocks example (#4305)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger authored Aug 4, 2020
1 parent 8cea798 commit e1dafe6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 43 deletions.
31 changes: 17 additions & 14 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3358,9 +3358,12 @@ def map_blocks(
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... np.random.rand(len(time)),
... dims=["time"],
... coords={"time": time, "month": month},
... ).chunk()
>>> array.map_blocks(calculate_anomaly, template=array).compute()
<xarray.DataArray (time: 24)>
Expand All @@ -3371,21 +3374,19 @@ def map_blocks(
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
>>> array.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
... )
... ) # doctest: +ELLIPSIS
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
dask.array<calculate_anomaly-...-<this, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray>
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
"""
from .parallel import map_blocks

Expand Down Expand Up @@ -3875,9 +3876,10 @@ def argmin(
>>> array.isel(array.argmin(...))
array(-1)
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
... [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
... dims=("x", "y", "z"))
>>> array = xr.DataArray(
... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
... dims=("x", "y", "z"),
... )
>>> array.min(dim="x")
<xarray.DataArray (y: 3, z: 3)>
array([[ 1, 2, 1],
Expand Down Expand Up @@ -3977,9 +3979,10 @@ def argmax(
<xarray.DataArray ()>
array(3)
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
... [[1, 3, 2], [2, 5, 1], [2, 3, 1]]],
... dims=("x", "y", "z"))
>>> array = xr.DataArray(
... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, 5, 1], [2, 3, 1]]],
... dims=("x", "y", "z"),
... )
>>> array.max(dim="x")
<xarray.DataArray (y: 3, z: 3)>
array([[3, 3, 2],
Expand Down
29 changes: 15 additions & 14 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5817,35 +5817,36 @@ def map_blocks(
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... np.random.rand(len(time)),
... dims=["time"],
... coords={"time": time, "month": month},
... ).chunk()
>>> ds = xr.Dataset({"a": array})
>>> ds.map_blocks(calculate_anomaly, template=ds).compute()
<xarray.DataArray (time: 24)>
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
-0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 ,
0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
<xarray.Dataset>
Dimensions: (time: 24)
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
Data variables:
a (time) float64 0.1289 0.1132 -0.0856 ... 0.2287 0.1906 -0.05901
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
>>> ds.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds,
... )
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
<xarray.Dataset>
Dimensions: (time: 24)
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
Data variables:
a (time) float64 dask.array<chunksize=(24,), meta=np.ndarray>
"""
from .parallel import map_blocks

Expand Down
28 changes: 13 additions & 15 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,14 @@ def map_blocks(
... clim = gb.mean(dim="time")
... return gb - clim
>>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
>>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
>>> np.random.seed(123)
>>> array = xr.DataArray(
... np.random.rand(len(time)), dims="time", coords=[time]
... np.random.rand(len(time)),
... dims=["time"],
... coords={"time": time, "month": month},
... ).chunk()
>>> xr.map_blocks(calculate_anomaly, array, template=array).compute()
>>> array.map_blocks(calculate_anomaly, template=array).compute()
<xarray.DataArray (time: 24)>
array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862,
0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714,
Expand All @@ -248,25 +251,20 @@ def map_blocks(
0.07673453, 0.22865714, 0.19063865, -0.0590131 ])
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
to the function being applied in ``xr.map_blocks()``:
>>> xr.map_blocks(
... calculate_anomaly,
... array,
... kwargs={"groupby_type": "time.year"},
... template=array,
... )
>>> array.map_blocks(
... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array,
... ) # doctest: +ELLIPSIS
<xarray.DataArray (time: 24)>
array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
-0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425,
-0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273,
0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 ,
0.14482397, 0.35985481, 0.23487834, 0.12144652])
dask.array<calculate_anomaly-...-<this, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray>
Coordinates:
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
"""

def _wrapper(
func: Callable,
Expand Down

0 comments on commit e1dafe6

Please sign in to comment.