Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed May 6, 2020
1 parent 32a37c7 commit a0e699f
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,24 @@ def check_result_variables(


def subset_dataset_to_block(
graph: dict,
gname: str,
dataset: Dataset,
input_chunks: dict,
input_chunk_bounds,
chunk_tuple: tuple,
graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
):
"""
Creates a task that creates a subsets xarray dataset to a block determined by chunk_index;
whose extents are determined by input_chunk_bounds.
There are subtasks that create subsets of constituent variables.
# mapping from dimension name to chunk index
chunk_index = dict(zip(dataset.dims, chunk_tuple))
TODO: This is modifying graph in-place!
"""

# this will become [[name1, variable1],
# [name2, variable2],
# ...]
# [name2, variable2],
# ...]
# which is passed to dict and then to Dataset
data_vars = []
coords = []

chunk_tuple = tuple(chunk_index.values())
for name, variable in dataset.variables.items():
# make a task that creates tuple of (dims, chunk)
if dask.is_dask_collection(variable.data):
Expand Down Expand Up @@ -410,12 +410,13 @@ def map_blocks(
[dataarray_to_dataset(da) for da in npargs[is_array]]
)

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = get_index_vars(npargs[0])
for arg in npargs[1:][is_xarray[1:]]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(arg.indexes)
input_indexes.update(get_index_vars(arg))

if template is None:
# infer template by providing zero-shaped arrays
Expand Down Expand Up @@ -481,12 +482,10 @@ def map_blocks(
# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
chunk_index = dict(zip(input_chunks.keys(), chunk_tuple))
chunk_index = dict(zip(ichunk.keys(), chunk_tuple))

blocked_args = [
subset_dataset_to_block(
graph, gname, arg, input_chunks, input_chunk_bounds, chunk_tuple
)
subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index)
if isxr
else arg
for isxr, arg in zip(is_xarray, npargs)
Expand Down

0 comments on commit a0e699f

Please sign in to comment.