MVP for dask collections in args
dcherian committed Mar 21, 2020
1 parent d934b3f commit 4d40da2
Showing 2 changed files with 61 additions and 34 deletions.
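In short: map_blocks can now accept xarray objects in args and subsets them block-by-block alongside obj. A minimal sketch of the new behavior (variable names here are illustrative, not from the diff):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10), dims="x").chunk({"x": 5})
weights = xr.DataArray(np.ones(10), dims="x").chunk({"x": 5})

# previously, a chunked object in args raised TypeError; now each
# positional xarray object is split into blocks matching `da`
result = xr.map_blocks(lambda a, w: a * w, da, args=[weights])
result.compute()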
92 changes: 61 additions & 31 deletions xarray/core/parallel.py
@@ -83,7 +83,7 @@ def subset_dataset_to_chunk(graph, gname, dataset, input_chunks, chunk_tuple):
chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple
graph[chunk_variable_task] = (tuple, [variable.dims, chunk, variable.attrs])
else:
-             # non-dask array with possibly chunked dimensions
+             # non-dask array with dimensions that may be chunked on other variables
# index into variable appropriately
subsetter = {}
for dim in variable.dims:
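For context, this branch subsets variables that are not dask arrays but live on dimensions that other variables chunk. A standalone illustration of that indexing, with hypothetical names:

import numpy as np

input_chunks = {"x": (5, 5)}   # chunk sizes along dimension "x"
chunk_index = {"x": 1}         # building the second block

data = np.arange(10)           # numpy-backed variable on "x"
start = sum(input_chunks["x"][: chunk_index["x"]])
stop = start + input_chunks["x"][chunk_index["x"]]
block = data[start:stop]       # array([5, 6, 7, 8, 9])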
@@ -121,6 +121,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray:
return next(iter(obj.data_vars.values()))


+ def dataarray_to_dataset(obj: DataArray) -> Dataset:
+     # only using _to_temp_dataset would break
+     # func = lambda x: x.to_dataset()
+     # since that relies on preserving name.
+     if obj.name is None:
+         dataset = obj._to_temp_dataset()
+     else:
+         dataset = obj.to_dataset()
+     return dataset
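A short sketch of why the helper branches on name: to_dataset() refuses unnamed arrays, while routing named arrays through _to_temp_dataset() would hide the name behind a sentinel key and break func = lambda x: x.to_dataset():

import xarray as xr

named = xr.DataArray([1, 2], dims="x", name="foo")
unnamed = xr.DataArray([1, 2], dims="x")

named.to_dataset()          # Dataset with data variable "foo"
# unnamed.to_dataset()      # raises ValueError: needs a name
unnamed._to_temp_dataset()  # private fallback for nameless arrays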


def make_meta(obj):
"""If obj is a DataArray or Dataset, return a new object of the same type and with
the same variables and dtypes, but where all variables have size 0 and numpy
@@ -208,8 +219,8 @@ def map_blocks(
obj: DataArray, Dataset
Passed to the function as its first argument, one dask chunk at a time.
args: Sequence
-         Passed verbatim to func after unpacking, after the sliced obj. xarray objects,
-         if any, will not be split by chunks. Passing dask collections is not allowed.
+         Passed verbatim to func after unpacking, after the sliced obj.
+         Any xarray objects will also be split by chunks and then passed on.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
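A sketch of the updated contract, assuming identically chunked inputs: xarray objects in args are now split per block, while dask collections in kwargs must still be computed first:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6), dims="x").chunk({"x": 3})
other = xr.DataArray(np.arange(6), dims="x").chunk({"x": 3})

# args: `other` is subset to the matching block before each call
xr.map_blocks(lambda a, b: a + b, da, args=[other])

# kwargs: still no dask collections; compute or load values first
scale = other.sum().compute()
xr.map_blocks(lambda a, s=None: a * s, da, kwargs={"s": scale})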
@@ -282,14 +293,16 @@ def map_blocks(
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""

- def _wrapper(func, obj, to_array, args, kwargs, expected):
-     check_shapes = dict(obj.dims)
+ def _wrapper(func, args, kwargs, arg_is_array, expected):
+     check_shapes = dict(args[0].dims)
check_shapes.update(expected["shapes"])

-     if to_array:
-         obj = dataset_to_dataarray(obj)
+     converted_args = [
+         dataset_to_dataarray(arg) if is_array else arg
+         for is_array, arg in zip(arg_is_array, args)
+     ]

-     result = func(obj, *args, **kwargs)
+     result = func(*converted_args, **kwargs)

# check all dims are present
# this next bit allows lambda x: x.isel(x=1) to work
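A standalone sketch of the conversion above: blocks arrive as Datasets, and arg_is_array records which positional arguments began life as DataArrays so they can be converted back before func sees them (the helper below is a toy stand-in for dataset_to_dataarray):

import xarray as xr

def to_dataarray(ds):
    # pull the single data variable out of a one-variable Dataset
    return next(iter(ds.data_vars.values()))

blocks = [xr.Dataset({"a": ("x", [1, 2])}), xr.Dataset({"b": ("x", [3, 4])})]
arg_is_array = [True, False]   # only the first began as a DataArray

converted = [
    to_dataarray(block) if is_array else block
    for is_array, block in zip(arg_is_array, blocks)
]
# converted[0] is a DataArray named "a"; converted[1] stays a Dataset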
@@ -323,38 +336,46 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
elif not isinstance(kwargs, Mapping):
raise TypeError("kwargs must be a mapping (for example, a dict)")

-     for value in list(args) + list(kwargs.values()):
+     for value in kwargs.values():
if dask.is_dask_collection(value):
raise TypeError(
"Cannot pass dask collections in args or kwargs yet. Please compute or "
"Cannot pass dask collections in kwargs yet. Please compute or "
"load values before passing to map_blocks."
)

if not dask.is_dask_collection(obj):
return func(obj, *args, **kwargs)

if isinstance(obj, DataArray):
-         # only using _to_temp_dataset would break
-         # func = lambda x: x.to_dataset()
-         # since that relies on preserving name.
-         if obj.name is None:
-             dataset = obj._to_temp_dataset()
-         else:
-             dataset = obj.to_dataset()
+         dataset = dataarray_to_dataset(obj)
input_is_array = True
else:
dataset = obj
input_is_array = False

-     input_chunks = dataset.chunks
-     dataset_indexes = set(dataset.indexes)
+     # TODO: align args and dataset here?
+     # TODO: unify_chunks for args and dataset
+     input_chunks = dict(dataset.chunks)
+     input_indexes = dict(dataset.indexes)
+     converted_args = []
+     arg_is_array = []
+     for arg in args:
+         arg_is_array.append(isinstance(arg, DataArray))
+         if isinstance(arg, (Dataset, DataArray)):
+             if isinstance(arg, DataArray):
+                 arg = dataarray_to_dataset(arg)
+             # append Datasets too so converted_args stays aligned with args
+             converted_args.append(arg)
+             input_chunks.update(converted_args[-1].chunks)
+             input_indexes.update(converted_args[-1].indexes)
+         else:
+             converted_args.append(arg)
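To make the loop above concrete: chunk sizes and indexes from every xarray argument are merged into single mappings. Per the TODOs, alignment and unify_chunks are not performed yet, so this sketch assumes the arguments already chunk consistently:

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4))}).chunk({"x": 2})
arg = xr.Dataset({"b": ("y", np.arange(6))}).chunk({"y": 3})

input_chunks = dict(ds.chunks)
input_chunks.update(arg.chunks)    # {"x": (2, 2), "y": (3, 3)}
input_indexes = dict(ds.indexes)
input_indexes.update(arg.indexes)  # empty here: no coordinate indexes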

if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, obj, *args, **kwargs)
template_indexes = set(template.indexes)
preserved_indexes = template_indexes & dataset_indexes
new_indexes = template_indexes - dataset_indexes
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template.indexes[k] for k in new_indexes})
output_chunks = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
@@ -363,7 +384,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
else:
# template xarray object has been provided with proper sizes and chunk shapes
template_indexes = set(template.indexes)
-         indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes}
+         indexes = input_indexes
indexes.update({k: template.indexes[k] for k in template_indexes})
if isinstance(template, DataArray):
output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore
@@ -399,10 +420,16 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
# mapping from dimension name to chunk index
input_chunk_index = dict(zip(input_chunks.keys(), chunk_tuple))

-         # make a Dataset creation task for one block
-         dataset_chunk_task = subset_dataset_to_chunk(
-             graph, gname, dataset, input_chunks, chunk_tuple
-         )
+         chunked_args = []
+         for arg in (dataset,) + tuple(converted_args):
+             if isinstance(arg, (DataArray, Dataset)):
+                 chunked_args.append(
+                     subset_dataset_to_chunk(
+                         graph, gname, arg, input_chunks, chunk_tuple
+                     )
+                 )
+             else:
+                 chunked_args.append(arg)

# expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper
expected = {}
@@ -420,10 +447,9 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
graph[from_wrapper] = (
_wrapper,
func,
-             dataset_chunk_task,
-             input_is_array,
-             args,
+             chunked_args,
              kwargs,
+             [input_is_array] + arg_is_array,
expected,
)

@@ -455,7 +481,11 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
# layer.
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)

-     hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
+     hlg = HighLevelGraph.from_collections(
+         gname,
+         graph,
+         dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)],
+     )
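Why the extra dependencies matter: a dask collection passed in args contributes its own task layers, and the combined graph can only reference that collection's keys if it is declared as a dependency. A toy illustration with dask's public API (layer contents are hypothetical):

import operator

import dask
from dask.highlevelgraph import HighLevelGraph

x = dask.delayed(lambda: 1)()                 # stand-in dask collection
layer = {("g", 0): (operator.add, x.key, 2)}  # task reading x's output
hlg = HighLevelGraph.from_collections("g", layer, dependencies=[x])
dask.get(hlg, ("g", 0))                       # -> 3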

for gname_l, layer in new_layers.items():
# This adds in the getitems for each variable in the dataset.
3 changes: 0 additions & 3 deletions xarray/tests/test_dask.py
@@ -1066,9 +1066,6 @@ def really_bad_func(darray):
with raises_regex(ValueError, "inconsistent chunks"):
xr.map_blocks(bad_func, ds_copy)

-     with raises_regex(TypeError, "Cannot pass dask collections"):
-         xr.map_blocks(bad_func, map_da, args=[map_da.chunk()])

with raises_regex(TypeError, "Cannot pass dask collections"):
xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk()))

