From e45fd526d3d48d1dc4cd511bc45a0fa9b849bcd3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 15:39:15 -0700 Subject: [PATCH 01/21] Adapt map_blocks to use new Coordinates API --- xarray/core/coordinates.py | 2 +- xarray/core/parallel.py | 44 ++++++++++++++++++++++++-------------- xarray/tests/test_dask.py | 19 ++++++++++++++++ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cdf1d354be6..c59c5deba16 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates): :py:class:`~xarray.Coordinates` object is passed, its indexes will be added to the new created object. indexes: dict-like, optional - Mapping of where keys are coordinate names and values are + Mapping where keys are coordinate names and values are :py:class:`~xarray.indexes.Index` objects. If None (default), pandas indexes will be created for each dimension coordinate. Passing an empty dictionary will skip this default behavior. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f971556b3f7..afe045965f5 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -11,6 +11,7 @@ from xarray.core.alignment import align from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection if TYPE_CHECKING: @@ -345,13 +346,16 @@ def _wrapper( for arg in aligned ) + merged_coordinates = merge([arg.coords for arg in aligned]).coords + _, npargs = unzip( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) + # TODO (dcherian) cleanup to just use a Coordinates object + input_indexes = dict(npargs[0]._indexes) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) @@ -360,18 +364,29 @@ def _wrapper( if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template._indexes) - preserved_indexes = template_indexes & set(input_indexes) - new_indexes = template_indexes - set(input_indexes) - indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template._indexes[k] for k in new_indexes}) + template_coords = set(template.coords) + preserved_indexes = template_coords & set(merged_coordinates) + new_indexes = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_indexes] + # preserved_coords contains all coordinates bariables that share a dimension + # with any index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. 
+        preserved_coords = preserved_coords.drop_vars(
+            tuple(k for k in preserved_coords.variables if k not in template_coords)
+        )
+
+        coordinates = merge(
+            (preserved_coords, template.coords.to_dataset()[new_indexes])
+        ).coords
         output_chunks: Mapping[Hashable, tuple[int, ...]] = {
             dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
         }
 
     else:
         # template xarray object has been provided with proper sizes and chunk shapes
-        indexes = dict(template._indexes)
+        coordinates = template.coords
         output_chunks = template.chunksizes
         if not output_chunks:
             raise ValueError(
@@ -496,8 +511,10 @@ def subset_dataset_to_block(
     expected["data_vars"] = set(template.data_vars.keys())  # type: ignore[assignment]
     expected["coords"] = set(template.coords.keys())  # type: ignore[assignment]
     expected["indexes"] = {
-        dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
-        for dim in indexes
+        dim: coordinates.xindexes[dim][
+            _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
+        ]
+        for dim in coordinates.xindexes
     }
 
     from_wrapper = (gname,) + chunk_tuple
@@ -506,7 +523,7 @@ def subset_dataset_to_block(
     # mapping from variable name to dask graph key
     var_key_map: dict[Hashable, str] = {}
     for name, variable in template.variables.items():
-        if name in indexes:
+        if name in coordinates.indexes:
             continue
         gname_l = f"{name}-{gname}"
         var_key_map[name] = gname_l
@@ -543,12 +560,7 @@ def subset_dataset_to_block(
         },
     )
 
-    # TODO: benbovy - flexible indexes: make it work with custom indexes
-    # this will need to pass both indexes and coords to the Dataset constructor
-    result = Dataset(
-        coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
-        attrs=template.attrs,
-    )
+    result = Dataset(coords=coordinates, attrs=template.attrs)
 
     for index in result._indexes:
         result[index].attrs = template[index].attrs
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index c2a77c97d85..137d6020829 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj):
     assert_identical(actual, template)
 
 
+def test_map_blocks_roundtrip_string_index():
+    ds = xr.Dataset(
+        {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]}
+    ).chunk(label=1)
+    assert ds.label.dtype == np.dtype("<U3")
+
+    mapped = ds.map_blocks(lambda x: x, template=ds)
+    assert mapped.label.dtype == ds.label.dtype
+
+    mapped = ds.map_blocks(lambda x: x, template=None)
+    assert mapped.label.dtype == ds.label.dtype
+
+    mapped = ds.data.map_blocks(lambda x: x, template=ds.data)
+    assert mapped.label.dtype == ds.label.dtype
+
+    mapped = ds.data.map_blocks(lambda x: x, template=None)
+    assert mapped.label.dtype == ds.label.dtype
+
+
From: Deepak Cherian
Date: Mon, 18 Dec 2023 16:13:47 -0700
Subject: [PATCH 02/21] cleanup

---
 xarray/core/parallel.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index afe045965f5..1d707c31c14 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -352,14 +352,11 @@ def _wrapper(
         sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
     )
 
-    # TODO (dcherian) cleanup to just use a Coordinates object
-    input_indexes = dict(npargs[0]._indexes)
     # check that chunk sizes are compatible
     input_chunks = dict(npargs[0].chunks)
     for arg in xarray_objs[1:]:
         assert_chunks_compatible(npargs[0], arg)
         input_chunks.update(arg.chunks)
-        input_indexes.update(arg._indexes)
 
     if template is None:
         # infer template by providing zero-shaped arrays
         template = infer_template(func, aligned[0], *args, **kwargs)

From b5f763d223c3c7c1312b944d7bf1441418151ebd Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Mon, 18 Dec 2023 16:28:57 -0700
Subject: [PATCH 03/21] typing fixes

---
 xarray/core/dataarray.py | 2 +-
 xarray/core/dataset.py   | 2 +-
 xarray/core/parallel.py  | 4 +++-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 0335ad3bdda..0f245ff464b 100644
---
a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -80,7 +80,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None try: from dask.delayed import Delayed except ImportError: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9ec39e74ad1..a6fc0e2ca18 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -171,7 +171,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None # list of attributes of pd.DatetimeIndex that are ndarrays of time info diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 1d707c31c14..6b4797af841 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -9,6 +9,7 @@ import numpy as np from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.merge import merge @@ -358,6 +359,7 @@ def _wrapper( assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) + coordinates: Coordinates if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) @@ -499,7 +501,7 @@ def subset_dataset_to_block( # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected = {} + expected: dict[Hashable, dict] = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function expected["shapes"] = { From d903384f4f34767f25ba7a42813a665ae732a870 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 3 Nov 2023 12:25:11 -0600 Subject: [PATCH 04/21] Minimize duplication in `map_blocks` task graph Closes #8409 --- xarray/core/parallel.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 6b4797af841..1c76833983b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -471,13 +471,21 @@ def subset_dataset_to_block( for dim in variable.dims } subset = variable.isel(subsetter) + if name in chunk_index: + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. 
+ this_var_chunk_tuple = (chunk_index[name],) + else: + this_var_chunk_tuple = chunk_tuple + chunk_variable_task = ( f"{name}-{gname}-{dask.base.tokenize(subset)}", - ) + chunk_tuple - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset, subset.attrs], - ) + ) + this_var_chunk_tuple + if chunk_variable_task not in graph: + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset, subset.attrs], + ) # this task creates dict mapping variable name to above tuple if name in dataset._coord_names: From 9790d3bda2def13bbabb6ab321b8c44ce4540eb2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 3 Nov 2023 12:43:12 -0600 Subject: [PATCH 05/21] Some more optimization --- xarray/core/parallel.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 1c76833983b..45c6fdffb82 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -453,10 +453,11 @@ def subset_dataset_to_block( for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): - # recursively index into dask_keys nested list to get chunk - chunk = variable.__dask_keys__() - for dim in variable.dims: - chunk = chunk[chunk_index[dim]] + # get task name for chunk + chunk = ( + variable.data.name, + *tuple(chunk_index[dim] for dim in variable.dims), + ) chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple graph[chunk_variable_task] = ( @@ -484,7 +485,7 @@ def subset_dataset_to_block( if chunk_variable_task not in graph: graph[chunk_variable_task] = ( tuple, - [subset.dims, subset, subset.attrs], + [subset.dims, subset._data, subset.attrs], ) # this task creates dict mapping variable name to above tuple From 7af206a48763e6e20c86cd76f840e6da08119e15 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 3 Nov 2023 14:24:09 -0600 Subject: [PATCH 06/21] Refactor inserting of in memory data --- xarray/core/parallel.py | 59 ++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 45c6fdffb82..a57858006ca 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -146,6 +146,35 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) +def _insert_in_memory_data_in_graph( + graph, gname, name, variable, chunk_index, chunk_bounds +): + import dask + + # non-dask array possibly with dimensions chunked on other variables + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, chunk_bounds) for dim in variable.dims + } + subset = variable.isel(subsetter) + if name in chunk_index: + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. 
+ chunk_tuple = (chunk_index[name],) + else: + chunk_tuple = tuple(chunk_index.values()) + + chunk_variable_task = ( + f"{name}-{gname}-{dask.base.tokenize(subset)}", + ) + chunk_tuple + if chunk_variable_task not in graph: + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset._data, subset.attrs], + ) + return chunk_variable_task + + def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -465,28 +494,14 @@ def subset_dataset_to_block( [variable.dims, chunk, variable.attrs], ) else: - # non-dask array possibly with dimensions chunked on other variables - # index into variable appropriately - subsetter = { - dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) - for dim in variable.dims - } - subset = variable.isel(subsetter) - if name in chunk_index: - # We are including a dimension coordinate, - # minimize duplication by not copying it in the graph for every chunk. - this_var_chunk_tuple = (chunk_index[name],) - else: - this_var_chunk_tuple = chunk_tuple - - chunk_variable_task = ( - f"{name}-{gname}-{dask.base.tokenize(subset)}", - ) + this_var_chunk_tuple - if chunk_variable_task not in graph: - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset._data, subset.attrs], - ) + chunk_variable_task = _insert_in_memory_data_in_graph( + graph, + gname, + name, + variable, + chunk_index, + input_chunk_bounds, + ) # this task creates dict mapping variable name to above tuple if name in dataset._coord_names: From 62c7e528f696619ed8e44f505c0363fb84e63b63 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 3 Nov 2023 14:25:26 -0600 Subject: [PATCH 07/21] [WIP] De-duplicate in expected["indexes"] --- xarray/core/parallel.py | 56 ++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a57858006ca..e3a61fbee7a 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -14,6 +14,7 @@ from xarray.core.dataset import Dataset from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection +from xarray.core.variable import IndexVariable if TYPE_CHECKING: from xarray.core.types import T_Xarray @@ -169,12 +170,18 @@ def _insert_in_memory_data_in_graph( ) + chunk_tuple if chunk_variable_task not in graph: graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset._data, subset.attrs], + type(variable), + subset.dims, + subset._data, + subset.attrs, ) return chunk_variable_task +def assemble_dict(*args): + return {k: v for k, v in zip(args[::2], args[1::2])} + + def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -286,6 +293,7 @@ def _wrapper( kwargs: dict, arg_is_array: Iterable[bool], expected: dict, + expected_indexes: dict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -314,8 +322,8 @@ def _wrapper( f"Received dimension {name!r} of length {result.sizes[name]}. " f"Expected length {expected['shapes'][name]}." ) - if name in expected["indexes"]: - expected_index = expected["indexes"][name] + if name in expected_indexes: + expected_index = expected_indexes[name] if not index.equals(expected_index): raise ValueError( f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." 
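
Note on the `assemble_dict` helper added above: dask's graph spec only substitutes
task keys that appear directly in a task tuple (or in a nested list), not keys
buried inside a dict value, which is why the following hunk builds
`expected_indexes` as an `(assemble_dict, key, task, ...)` task rather than a
plain dict. A minimal sketch of that behavior with a toy graph (the keys and
values here are invented purely for illustration):

    import dask

    def assemble_dict(*args):
        # pair up alternating keys and values: (k1, v1, k2, v2, ...) -> dict
        return {k: v for k, v in zip(args[::2], args[1::2])}

    graph = {
        "chunk-0": (lambda: 1,),
        # "chunk-0" sits at a task-tuple position, so dask swaps in its result;
        # hidden inside a dict literal it would be passed through as a string.
        "expected": (assemble_dict, "x", "chunk-0"),
    }
    assert dask.get(graph, "expected") == {"x": 1}
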
@@ -495,12 +503,7 @@ def subset_dataset_to_block( ) else: chunk_variable_task = _insert_in_memory_data_in_graph( - graph, - gname, - name, - variable, - chunk_index, - input_chunk_bounds, + graph, gname, name, variable, chunk_index, input_chunk_bounds ) # this task creates dict mapping variable name to above tuple @@ -533,15 +536,34 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in coordinates.xindexes - } + + # dask does not replace task names present in the values of a dictionary + # So, we have to be very indirect about it. + expect_ = [] + for dim, index in coordinates.indexes.items(): + expect_ += [dim] + expect_.append( + _insert_in_memory_data_in_graph( + graph, + gname + "-expected", + dim, + IndexVariable(dim, index.index), + chunk_index, + output_chunk_bounds, + ) + ) + expected_indexes = (assemble_dict, *expect_) from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) + graph[from_wrapper] = ( + _wrapper, + func, + blocked_args, + kwargs, + is_array, + expected, + expected_indexes, + ) # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} From a7a63c036a75bd832820c873a3085cb7a262012e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 5 Nov 2023 13:00:33 -0700 Subject: [PATCH 08/21] Revert "[WIP] De-duplicate in expected["indexes"]" This reverts commit 7276cbf0d198dea94538425556676233e27b6a6d. --- xarray/core/parallel.py | 54 ++++++++++++----------------------------- 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index e3a61fbee7a..e2c95eea03d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -14,7 +14,6 @@ from xarray.core.dataset import Dataset from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection -from xarray.core.variable import IndexVariable if TYPE_CHECKING: from xarray.core.types import T_Xarray @@ -170,18 +169,12 @@ def _insert_in_memory_data_in_graph( ) + chunk_tuple if chunk_variable_task not in graph: graph[chunk_variable_task] = ( - type(variable), - subset.dims, - subset._data, - subset.attrs, + tuple, + [subset.dims, subset._data, subset.attrs], ) return chunk_variable_task -def assemble_dict(*args): - return {k: v for k, v in zip(args[::2], args[1::2])} - - def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -293,7 +286,6 @@ def _wrapper( kwargs: dict, arg_is_array: Iterable[bool], expected: dict, - expected_indexes: dict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -322,8 +314,8 @@ def _wrapper( f"Received dimension {name!r} of length {result.sizes[name]}. " f"Expected length {expected['shapes'][name]}." ) - if name in expected_indexes: - expected_index = expected_indexes[name] + if name in expected["indexes"]: + expected_index = expected["indexes"][name] if not index.equals(expected_index): raise ValueError( f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." 
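
This revert restores the eager approach: a concrete slice of each output index
is embedded in `expected["indexes"]` for every block (see the following hunk).
Those slices come from `_get_chunk_slicer`, which maps a block's chunk index to
array positions via cumulative chunk bounds; roughly, with sizes invented for
illustration:

    import numpy as np

    chunks = (20, 20, 10)                    # chunk sizes along one dimension
    chunk_bounds = np.cumsum((0,) + chunks)  # array([ 0, 20, 40, 50])
    which_chunk = 1
    slicer = slice(chunk_bounds[which_chunk], chunk_bounds[which_chunk + 1])
    assert slicer == slice(20, 40)           # block 1 covers positions 20..39
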
@@ -503,7 +495,12 @@ def subset_dataset_to_block( ) else: chunk_variable_task = _insert_in_memory_data_in_graph( - graph, gname, name, variable, chunk_index, input_chunk_bounds + graph, + gname, + name, + variable, + chunk_index, + input_chunk_bounds, ) # this task creates dict mapping variable name to above tuple @@ -536,34 +533,13 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - - # dask does not replace task names present in the values of a dictionary - # So, we have to be very indirect about it. - expect_ = [] - for dim, index in coordinates.indexes.items(): - expect_ += [dim] - expect_.append( - _insert_in_memory_data_in_graph( - graph, - gname + "-expected", - dim, - IndexVariable(dim, index.index), - chunk_index, - output_chunk_bounds, - ) - ) - expected_indexes = (assemble_dict, *expect_) + expected["indexes"] = { + dim: coordinates.xindexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim in coordinates.xindexes + } from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = ( - _wrapper, - func, - blocked_args, - kwargs, - is_array, - expected, - expected_indexes, - ) + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} From da95bc48ac13f52ac8b942442dd574ffd08707c2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 5 Nov 2023 13:00:44 -0700 Subject: [PATCH 09/21] Revert "Refactor inserting of in memory data" This reverts commit f6557f70eeb34b4f5a6303a9a5a4a8ab7c512bb1. --- xarray/core/parallel.py | 59 +++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index e2c95eea03d..557a7682c72 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -146,35 +146,6 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) -def _insert_in_memory_data_in_graph( - graph, gname, name, variable, chunk_index, chunk_bounds -): - import dask - - # non-dask array possibly with dimensions chunked on other variables - # index into variable appropriately - subsetter = { - dim: _get_chunk_slicer(dim, chunk_index, chunk_bounds) for dim in variable.dims - } - subset = variable.isel(subsetter) - if name in chunk_index: - # We are including a dimension coordinate, - # minimize duplication by not copying it in the graph for every chunk. 
- chunk_tuple = (chunk_index[name],) - else: - chunk_tuple = tuple(chunk_index.values()) - - chunk_variable_task = ( - f"{name}-{gname}-{dask.base.tokenize(subset)}", - ) + chunk_tuple - if chunk_variable_task not in graph: - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset._data, subset.attrs], - ) - return chunk_variable_task - - def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -494,14 +465,28 @@ def subset_dataset_to_block( [variable.dims, chunk, variable.attrs], ) else: - chunk_variable_task = _insert_in_memory_data_in_graph( - graph, - gname, - name, - variable, - chunk_index, - input_chunk_bounds, - ) + # non-dask array possibly with dimensions chunked on other variables + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + subset = variable.isel(subsetter) + if name in chunk_index: + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. + this_var_chunk_tuple = (chunk_index[name],) + else: + this_var_chunk_tuple = chunk_tuple + + chunk_variable_task = ( + f"{name}-{gname}-{dask.base.tokenize(subset)}", + ) + this_var_chunk_tuple + if chunk_variable_task not in graph: + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset._data, subset.attrs], + ) # this task creates dict mapping variable name to above tuple if name in dataset._coord_names: From b9974b4fc1345cf88a8307c83fbcd3bb45e5acdb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 14:54:54 -0700 Subject: [PATCH 10/21] Be more clever about scalar broadcasting --- xarray/core/parallel.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 557a7682c72..25d05bf64cc 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -450,6 +450,7 @@ def subset_dataset_to_block( coords = [] chunk_tuple = tuple(chunk_index.values()) + chunk_dims_set = set(chunk_index) for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): @@ -472,15 +473,17 @@ def subset_dataset_to_block( for dim in variable.dims } subset = variable.isel(subsetter) - if name in chunk_index: + if set(variable.dims) < chunk_dims_set: # We are including a dimension coordinate, # minimize duplication by not copying it in the graph for every chunk. 
- this_var_chunk_tuple = (chunk_index[name],) + this_var_chunk_tuple = tuple( + chunk_index[dim] for dim in variable.dims + ) else: this_var_chunk_tuple = chunk_tuple chunk_variable_task = ( - f"{name}-{gname}-{dask.base.tokenize(subset)}", + f"{name}-{gname}-{dask.base.tokenize(subsetter)}", ) + this_var_chunk_tuple if chunk_variable_task not in graph: graph[chunk_variable_task] = ( From 9b5e4cc5146891f7cd322bae86173e540bdb5879 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 19:23:51 -0700 Subject: [PATCH 11/21] Small speedup --- xarray/core/parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 25d05bf64cc..67fa9f3b806 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -472,7 +472,6 @@ def subset_dataset_to_block( dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) for dim in variable.dims } - subset = variable.isel(subsetter) if set(variable.dims) < chunk_dims_set: # We are including a dimension coordinate, # minimize duplication by not copying it in the graph for every chunk. @@ -486,6 +485,7 @@ def subset_dataset_to_block( f"{name}-{gname}-{dask.base.tokenize(subsetter)}", ) + this_var_chunk_tuple if chunk_variable_task not in graph: + subset = variable.isel(subsetter) graph[chunk_variable_task] = ( tuple, [subset.dims, subset._data, subset.attrs], @@ -522,7 +522,9 @@ def subset_dataset_to_block( expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] expected["indexes"] = { - dim: coordinates.xindexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] for dim in coordinates.xindexes } From a106569783bb81d4f5e51cf0ea8fd68859596216 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 19:42:15 -0700 Subject: [PATCH 12/21] Small improvement --- xarray/core/parallel.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 67fa9f3b806..6c5295ff0b2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -522,10 +522,8 @@ def subset_dataset_to_block( expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] expected["indexes"] = { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in coordinates.xindexes + dim: index[_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim, index in coordinates.xindexes.items() } from_wrapper = (gname,) + chunk_tuple From 1334009d4bd781d1116e8a075a68927961be51f8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 19:56:20 -0700 Subject: [PATCH 13/21] Trim some more. 
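
Only indexes that are new or modified relative to the inputs now ride along in
`expected["indexes"]`; unchanged indexes are rebuilt inside `_wrapper` from the
block's own arguments, and the two sources are consulted through a
`collections.ChainMap`, where the explicitly shipped indexes shadow the
reconstructed ones. A toy sketch of that lookup order (the dict contents are
invented for illustration):

    from collections import ChainMap

    shipped = {"x2": "modified-index"}
    rebuilt = {"x": "index-from-inputs", "x2": "index-from-inputs"}
    merged = ChainMap(shipped, rebuilt)
    assert merged["x2"] == "modified-index"    # first mapping wins
    assert merged["x"] == "index-from-inputs"  # falls back to the second
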
--- xarray/core/parallel.py | 43 +++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 6c5295ff0b2..d58d9052c57 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -270,6 +270,10 @@ def _wrapper( result = func(*converted_args, **kwargs) + merged_coordinates = merge( + [arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))] + ).coords + # check all dims are present missing_dimensions = set(expected["shapes"]) - set(result.sizes) if missing_dimensions: @@ -285,12 +289,15 @@ def _wrapper( f"Received dimension {name!r} of length {result.sizes[name]}. " f"Expected length {expected['shapes'][name]}." ) - if name in expected["indexes"]: - expected_index = expected["indexes"][name] - if not index.equals(expected_index): - raise ValueError( - f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." - ) + + merged_indexes = collections.ChainMap( + expected["indexes"], merged_coordinates.xindexes + ) + expected_index = merged_indexes.get(name, None) + if expected_index is not None and not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) # check that all expected variables were returned check_result_variables(result, expected, "coords") @@ -364,11 +371,11 @@ def _wrapper( # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) template_coords = set(template.coords) - preserved_indexes = template_coords & set(merged_coordinates) - new_indexes = template_coords - set(merged_coordinates) + preserved_coord_names = template_coords & set(merged_coordinates) + new_indexes = set(template.xindexes) - set(merged_coordinates) - preserved_coords = merged_coordinates.to_dataset()[preserved_indexes] - # preserved_coords contains all coordinates bariables that share a dimension + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_names] + # preserved_coords contains all coordinate variables that share a dimension # with any index variable in preserved_indexes # Drop any unneeded vars in a second pass, this is required for e.g. # if the mapped function were to drop a non-dimension coordinate variable. @@ -393,6 +400,13 @@ def _wrapper( " Please construct a template with appropriately chunked dask arrays." 
) + new_indexes = set(template.xindexes) - set(merged_coordinates) + modified_indexes = set( + name + for name, xindex in coordinates.xindexes.items() + if not xindex.equals(merged_coordinates.xindexes.get(name, None)) + ) + for dim in output_chunks: if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): raise ValueError( @@ -521,9 +535,14 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] + + # Minimize duplication due to broadcasting by only including any new or modified indexes + # Others can be inferred by inputs to wrapper (GH8412) expected["indexes"] = { - dim: index[_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim, index in coordinates.xindexes.items() + name: coordinates.xindexes[name][ + _get_chunk_slicer(name, chunk_index, output_chunk_bounds) + ] + for name in (new_indexes | modified_indexes) } from_wrapper = (gname,) + chunk_tuple From 21b09498fb9d177a59c423a2b0060c9246a29217 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 20:09:27 -0700 Subject: [PATCH 14/21] Restrict numpy code path only for scalars and indexes --- xarray/core/parallel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index d58d9052c57..f40e0e0f3a8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -353,6 +353,8 @@ def _wrapper( dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg for arg in aligned ) + # rechunk any numpy variables appropriately + xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) merged_coordinates = merge([arg.coords for arg in aligned]).coords @@ -480,6 +482,8 @@ def subset_dataset_to_block( [variable.dims, chunk, variable.attrs], ) else: + assert name in dataset.dims or variable.ndim == 0 + # non-dask array possibly with dimensions chunked on other variables # index into variable appropriately subsetter = { @@ -498,7 +502,7 @@ def subset_dataset_to_block( chunk_variable_task = ( f"{name}-{gname}-{dask.base.tokenize(subsetter)}", ) + this_var_chunk_tuple - if chunk_variable_task not in graph: + if variable.ndim == 0 or chunk_variable_task not in graph: subset = variable.isel(subsetter) graph[chunk_variable_task] = ( tuple, From 1933e0b9befbb696ef4829f239f4de8164cf83d8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 20:13:33 -0700 Subject: [PATCH 15/21] Small cleanup --- xarray/core/parallel.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f40e0e0f3a8..7635d77d8fd 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -491,8 +491,6 @@ def subset_dataset_to_block( for dim in variable.dims } if set(variable.dims) < chunk_dims_set: - # We are including a dimension coordinate, - # minimize duplication by not copying it in the graph for every chunk. this_var_chunk_tuple = tuple( chunk_index[dim] for dim in variable.dims ) @@ -502,6 +500,8 @@ def subset_dataset_to_block( chunk_variable_task = ( f"{name}-{gname}-{dask.base.tokenize(subsetter)}", ) + this_var_chunk_tuple + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. 
if variable.ndim == 0 or chunk_variable_task not in graph: subset = variable.isel(subsetter) graph[chunk_variable_task] = ( @@ -539,7 +539,6 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - # Minimize duplication due to broadcasting by only including any new or modified indexes # Others can be inferred by inputs to wrapper (GH8412) expected["indexes"] = { @@ -560,14 +559,11 @@ def subset_dataset_to_block( gname_l = f"{name}-{gname}" var_key_map[name] = gname_l - key: tuple[Any, ...] = (gname_l,) - for dim in variable.dims: - if dim in chunk_index: - key += (chunk_index[dim],) - else: - # unchunked dimensions in the input have one chunk in the result - # output can have new dimensions with exactly one chunk - key += (0,) + # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk + key: tuple[Any, ...] = (gname_l,) + tuple( + chunk_index[dim] if dim in chunk_index else 0 for dim in variable.dims + ) # We're adding multiple new layers to the graph: # The first new layer is the result of the computation on From a4bda14fae54a33b6ee15095e27b5f89b6fc935f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 20:30:00 -0700 Subject: [PATCH 16/21] Add test --- xarray/tests/test_dask.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 137d6020829..8043cfa81d6 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1746,3 +1746,20 @@ def test_new_index_var_computes_once(): data = dask.array.from_array(np.array([100, 200])) with raise_if_dask_computes(max_computes=1): Dataset(coords={"z": ("z", data)}) + + +def test_minimize_graph_size(): + import cloudpickle + + # regression test for https://github.com/pydata/xarray/issues/8409 + ds = Dataset( + {"foo": (("x", "y", "z"), dask.array.ones((120, 120, 120)))}, + coords={"x": np.arange(120), "y": np.arange(120), "z": np.arange(120)}, + ).chunk(x=20, y=20, z=1) + + actual = len(cloudpickle.dumps(ds.map_blocks(lambda x: x))) + expected = len(cloudpickle.dumps(ds.drop_vars(ds.dims).map_blocks(lambda x: x))) + + # prior to https://github.com/pydata/xarray/pull/8412 + # actual is ~5x expected + assert actual < 2 * expected From 79f15ef2c9f55175e54c5950be98aca5d4e2acf4 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 20:34:10 -0700 Subject: [PATCH 17/21] typing fixes --- pyproject.toml | 1 + xarray/core/parallel.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3975468d50e..014dec7a6e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ module = [ "cf_units.*", "cfgrib.*", "cftime.*", + "cloudpickle.*", "cubed.*", "cupy.*", "dask.types.*", diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 7635d77d8fd..10017398c29 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -14,6 +14,7 @@ from xarray.core.dataset import Dataset from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection +from xarray.core.variable import Variable if TYPE_CHECKING: from xarray.core.types import T_Xarray @@ -290,8 +291,9 @@ def _wrapper( f"Expected length {expected['shapes'][name]}." 
) + # ChainMap wants MutableMapping, but xindexes is Mapping merged_indexes = collections.ChainMap( - expected["indexes"], merged_coordinates.xindexes + expected["indexes"], merged_coordinates.xindexes # type: ignore[arg-type] ) expected_index = merged_indexes.get(name, None) if expected_index is not None and not index.equals(expected_index): @@ -467,6 +469,7 @@ def subset_dataset_to_block( chunk_tuple = tuple(chunk_index.values()) chunk_dims_set = set(chunk_index) + variable: Variable for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): From 9befd55e1d83ac820832f3ca2cba872f11ab1bb2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Dec 2023 09:32:01 -0700 Subject: [PATCH 18/21] optimize --- xarray/core/parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 10017398c29..622cb7063ea 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -520,6 +520,7 @@ def subset_dataset_to_block( return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + include_variables = set(template.variables) - set(coordinates.indexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index @@ -556,9 +557,8 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name, variable in template.variables.items(): - if name in coordinates.indexes: - continue + for name in include_variables: + variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l From 84ba7453b0d6f68044e00309c1f785348a13369f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Dec 2023 09:32:18 -0700 Subject: [PATCH 19/21] reorder --- xarray/core/parallel.py | 138 ++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 622cb7063ea..24ceb089c89 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -147,6 +147,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) +def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index +): + """ + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. + """ + import dask + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] 
+ # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + chunk_tuple = tuple(chunk_index.values()) + chunk_dims_set = set(chunk_index) + variable: Variable + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # get task name for chunk + chunk = ( + variable.data.name, + *tuple(chunk_index[dim] for dim in variable.dims), + ) + + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + assert name in dataset.dims or variable.ndim == 0 + + # non-dask array possibly with dimensions chunked on other variables + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + if set(variable.dims) < chunk_dims_set: + this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims) + else: + this_var_chunk_tuple = chunk_tuple + + chunk_variable_task = ( + f"{name}-{gname}-{dask.base.tokenize(subsetter)}", + ) + this_var_chunk_tuple + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. + if variable.ndim == 0 or chunk_variable_task not in graph: + subset = variable.isel(subsetter) + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset._data, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -451,75 +520,6 @@ def _wrapper( dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } - def subset_dataset_to_block( - graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index - ): - """ - Creates a task that subsets an xarray dataset to a block determined by chunk_index. - Block extents are determined by input_chunk_bounds. - Also subtasks that subset the constituent variables of a dataset. - """ - - # this will become [[name1, variable1], - # [name2, variable2], - # ...] 
# which is passed to dict and then to Dataset
-            data_vars = []
-            coords = []
-
-            chunk_tuple = tuple(chunk_index.values())
-            chunk_dims_set = set(chunk_index)
-            variable: Variable
-            for name, variable in dataset.variables.items():
-                # make a task that creates tuple of (dims, chunk)
-                if dask.is_dask_collection(variable.data):
-                    # get task name for chunk
-                    chunk = (
-                        variable.data.name,
-                        *tuple(chunk_index[dim] for dim in variable.dims),
-                    )
-
-                    chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
-                    graph[chunk_variable_task] = (
-                        tuple,
-                        [variable.dims, chunk, variable.attrs],
-                    )
-                else:
-                    assert name in dataset.dims or variable.ndim == 0
-
-                    # non-dask array possibly with dimensions chunked on other variables
-                    # index into variable appropriately
-                    subsetter = {
-                        dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
-                        for dim in variable.dims
-                    }
-                    if set(variable.dims) < chunk_dims_set:
-                        this_var_chunk_tuple = tuple(
-                            chunk_index[dim] for dim in variable.dims
-                        )
-                    else:
-                        this_var_chunk_tuple = chunk_tuple
-
-                    chunk_variable_task = (
-                        f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
-                    ) + this_var_chunk_tuple
-                    # We are including a dimension coordinate,
-                    # minimize duplication by not copying it in the graph for every chunk.
-                    if variable.ndim == 0 or chunk_variable_task not in graph:
-                        subset = variable.isel(subsetter)
-                        graph[chunk_variable_task] = (
-                            tuple,
-                            [subset.dims, subset._data, subset.attrs],
-                        )
-
-                # this task creates dict mapping variable name to above tuple
-                if name in dataset._coord_names:
-                    coords.append([name, chunk_variable_task])
-                else:
-                    data_vars.append([name, chunk_variable_task])
-
-            return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
-
     include_variables = set(template.variables) - set(coordinates.indexes)
     # iterate over all possible chunk combinations
     for chunk_tuple in itertools.product(*ichunk.values()):
         # mapping from dimension name to chunk index

From 9b8c0b36e109ba7dc35b739ceeaa34d5ec986d0c Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Wed, 20 Dec 2023 10:34:43 -0700
Subject: [PATCH 20/21] better test

---
 xarray/tests/test_dask.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 8043cfa81d6..386f1479c26 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1749,17 +1749,25 @@ def test_new_index_var_computes_once():
 
 
 def test_minimize_graph_size():
-    import cloudpickle
-
     # regression test for https://github.com/pydata/xarray/issues/8409
     ds = Dataset(
-        {"foo": (("x", "y", "z"), dask.array.ones((120, 120, 120)))},
+        {
+            "foo": (
+                ("x", "y", "z"),
+                dask.array.ones((120, 120, 120), chunks=(20, 20, 1)),
+            )
+        },
        coords={"x": np.arange(120), "y": np.arange(120), "z": np.arange(120)},
-    ).chunk(x=20, y=20, z=1)
-
-    actual = len(cloudpickle.dumps(ds.map_blocks(lambda x: x)))
-    expected = len(cloudpickle.dumps(ds.drop_vars(ds.dims).map_blocks(lambda x: x)))
+    )
 
-    # prior to https://github.com/pydata/xarray/pull/8412
-    # actual is ~5x expected
-    assert actual < 2 * expected
+    mapped = ds.map_blocks(lambda x: x)
+    graph = dict(mapped.__dask_graph__())
+
+    numchunks = {k: len(v) for k, v in ds.chunksizes.items()}
+    for var in "xyz":
+        actual = len([key for key in graph if var in key[0]])
+        # assert that each chunk of an index variable is included only
+        # once, not the product of the number of chunks of all the
+        # other dimensions,
+        # e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
+        assert actual == numchunks[var], (actual, numchunks[var])

From 7dd26bba9d1342fb8e0d03d8d464d2561dcb14eb Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Wed, 20 Dec 2023 10:39:57 -0700
Subject: [PATCH 21/21] cleanup + whats-new

---
 doc/whats-new.rst       | 5 ++++-
 xarray/core/parallel.py | 5 ++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index c0917b7443b..5361e5c0c5e 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -45,7 +45,10 @@ Documentation
 
 Internal Changes
 ~~~~~~~~~~~~~~~~
-
+- The implementation of :py:func:`map_blocks` has changed to minimize graph size and duplication of data.
+  This should be a strict improvement even though the graphs are not always embarrassingly parallel any more.
+  Please open an issue if you spot a regression. (:pull:`8412`, :issue:`8409`).
+  By `Deepak Cherian `_.
 - Remove null values before plotting. (:pull:`8535`).
   By `Jimmy Westling `_.
 
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index f7b9fe85dfa..3b47520a78c 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -439,8 +439,6 @@ def _wrapper(
 
     merged_coordinates = merge([arg.coords for arg in aligned]).coords
 
-    merged_coordinates = merge([arg.coords for arg in aligned]).coords
-
     _, npargs = unzip(
         sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
     )
@@ -460,7 +458,7 @@ def _wrapper(
         new_coord_vars = template_coords - set(merged_coordinates)
 
         preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
-        # preserved_coords contains all coordinates bariables that share a dimension
+        # preserved_coords contains all coordinate variables that share a dimension
         # with any index variable in preserved_indexes
         # Drop any unneeded vars in a second pass, this is required for e.g.
         # if the mapped function were to drop a non-dimension coordinate variable.
@@ -556,6 +554,7 @@ def _wrapper(
         },
         "data_vars": set(template.data_vars.keys()),
        "coords": set(template.coords.keys()),
+        # only include new or modified indexes to minimize duplication of data and graph size
        "indexes": {
            dim: coordinates.xindexes[dim][
                _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)