From 766da3480f50d7672fe1a7c1cdf3aa32d8181fcf Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Mon, 18 Dec 2023 14:30:18 -0500 Subject: [PATCH 1/4] Generalize cumulative reduction (scan) to non-dask types (#8019) * add scan to ChunkManager ABC * implement scan for dask using cumreduction * generalize push to work for non-dask chunked arrays * whatsnew * fix importerror * Allow arbitrary kwargs Co-authored-by: Deepak Cherian * Type hint return value of T_ChunkedArray Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Type hint return value of Dask array * ffill -> bfill in doc/whats-new.rst Co-authored-by: Deepak Cherian * hopefully fix docs warning --------- Co-authored-by: Deepak Cherian Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- doc/whats-new.rst | 4 ++++ xarray/core/daskmanager.py | 22 +++++++++++++++++++++ xarray/core/parallelcompat.py | 37 +++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4188af98e3f..c0917b7443b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -589,6 +589,10 @@ Internal Changes - :py:func:`as_variable` now consistently includes the variable name in any exceptions raised. (:pull:`7995`). By `Peter Hill `_ +- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to + use non-dask chunked array types. + (:pull:`8019`) By `Tom Nicholas `_. - :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`). `By Ian Carroll `_. diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index 56d8dc9e23a..efa04bc3df2 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -97,6 +97,28 @@ def reduction( keepdims=keepdims, ) + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + **kwargs, + ) -> DaskArray: + from dask.array.reductions import cumreduction + + return cumreduction( + func, + binop, + ident, + arr, + axis=axis, + dtype=dtype, + **kwargs, + ) + def apply_gufunc( self, func: Callable, diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 333059e00ae..37542925dde 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -403,6 +403,43 @@ def reduction( """ raise NotImplementedError() + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + **kwargs, + ) -> T_ChunkedArray: + """ + General version of a 1D scan, also known as a cumulative array reduction. + + Used in ``ffill`` and ``bfill`` in xarray. + + Parameters + ---------- + func: callable + Cumulative function like np.cumsum or np.cumprod + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + ident: Number + Associated identity like ``np.cumsum->0`` or ``np.cumprod->1`` + arr: dask Array + axis: int, optional + dtype: dtype + + Returns + ------- + Chunked array + + See also + -------- + dask.array.cumreduction + """ + raise NotImplementedError() + @abstractmethod def apply_gufunc( self, From 219ef0ce5e5c38f6033b285c356085ea0cce61e5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:30:40 -0800 Subject: [PATCH 2/4] Offer a fixture for unifying DataArray & Dataset tests (#8533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Cumulative aggregation Offer a fixture for unifying `DataArray` & `Dataset` tests (stacked on #8512, worth reviewing after that's merged) Some tests are literally copy & pasted between DataArray & Dataset tests. This change allows them to use a single test. Not everything will work — sometimes we want to check specifics — but sometimes they will... --- xarray/tests/conftest.py | 43 +++++++++++++++++++++++ xarray/tests/test_rolling.py | 67 ++++++++++++++---------------------- 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 6a8cf008f9f..f153c2f4dc0 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pandas as pd import pytest @@ -77,3 +79,44 @@ def da(request, backend): return da else: raise ValueError + + +@pytest.fixture(params=[Dataset, DataArray]) +def type(request): + return request.param + + +@pytest.fixture(params=[1]) +def d(request, backend, type) -> DataArray | Dataset: + """ + For tests which can test either a DataArray or a Dataset. + """ + result: DataArray | Dataset + if request.param == 1: + ds = Dataset( + dict( + a=(["x", "z"], np.arange(24).reshape(2, 12)), + b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)), + ), + dict( + x=("x", np.linspace(0, 1.0, 2)), + y=range(3), + z=("z", pd.date_range("2000-01-01", periods=12)), + w=("x", ["a", "b"]), + ), + ) + if type == DataArray: + result = ds["a"].assign_coords(w=ds.coords["w"]) + elif type == Dataset: + result = ds + else: + raise ValueError + else: + raise ValueError + + if backend == "dask": + return result.chunk() + elif backend == "numpy": + return result + else: + raise ValueError diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 645ec1f85e6..7cb2cd70d29 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -36,6 +36,31 @@ def compute_backend(request): yield request.param +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("min_periods", [1, 10]) +def test_cumulative(d, func, min_periods) -> None: + # One dim + result = getattr(d.cumulative("z", min_periods=min_periods), func)() + expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)() + assert_identical(result, expected) + + # Multiple dim + result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)() + expected = getattr( + d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods), + func, + )() + assert_identical(result, expected) + + +def test_cumulative_vs_cum(d) -> None: + result = d.cumulative("z").sum() + expected = d.cumsum("z") + # cumsum drops the coord of the dimension; cumulative doesn't + expected = expected.assign_coords(z=result["z"]) + assert_identical(result, expected) + + class TestDataArrayRolling: @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", [True, False]) @@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None: ): da.rolling_exp(time=10, keep_attrs=True) - @pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("min_periods", [1, 20]) - def test_cumulative(self, da, func, min_periods) -> None: - # One dim - result = getattr(da.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)() - expected = getattr( - da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - - def test_cumulative_vs_cum(self, da) -> None: - result = da.cumulative("time").sum() - expected = da.cumsum("time") - assert_identical(result, expected) - class TestDatasetRolling: @pytest.mark.parametrize( @@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)() assert_allclose(actual, expected) - @pytest.mark.parametrize("func", ["mean", "sum"]) - @pytest.mark.parametrize("ds", (2,), indirect=True) - @pytest.mark.parametrize("min_periods", [1, 10]) - def test_cumulative(self, ds, func, min_periods) -> None: - # One dim - result = getattr(ds.cumulative("time", min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, min_periods=min_periods), func - )() - assert_identical(result, expected) - - # Multiple dim - result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)() - expected = getattr( - ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods), - func, - )() - assert_identical(result, expected) - @requires_numbagg class TestDatasetRollingExp: From b3890a3859993dc53064ff14c2362bb0134b7c56 Mon Sep 17 00:00:00 2001 From: Niclas Rieger <45175997+nicrie@users.noreply.github.com> Date: Tue, 19 Dec 2023 15:39:37 +0100 Subject: [PATCH 3/4] add xeofs to ecosystem.rst (#8561) Suggestion to include [xeofs](https://github.com/nicrie/xeofs) in the xarray ecosystem documentation. xeofs enables fully multidimensional PCA / EOF analysis and related techniques with large datasets, thanks to the integration of xarray and dask. References: - [Github repository](https://github.com/nicrie/xeofs) - [Documentation](https://xeofs.readthedocs.io/en/latest/) - [JOSS review](https://github.com/openjournals/joss-reviews/issues/6060) --- doc/ecosystem.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index fc5ae963a1d..561e9cdb5b2 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -78,6 +78,7 @@ Extend xarray capabilities - `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation. - `xarray_einstats `_: Statistics, linear algebra and einops for xarray - `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xeofs `_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data. - `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. From b4444388cb0647c4375d6a364290e4fa5e5f94ba Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 20 Dec 2023 10:11:16 -0700 Subject: [PATCH 4/4] Adapt map_blocks to use new Coordinates API (#8560) * Adapt map_blocks to use new Coordinates API * cleanup * typing fixes * optimize * small cleanups * Typing fixes --- xarray/core/coordinates.py | 2 +- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/core/parallel.py | 89 ++++++++++++++++++++++++-------------- xarray/tests/test_dask.py | 19 ++++++++ 5 files changed, 79 insertions(+), 35 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cdf1d354be6..c59c5deba16 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates): :py:class:`~xarray.Coordinates` object is passed, its indexes will be added to the new created object. indexes: dict-like, optional - Mapping of where keys are coordinate names and values are + Mapping where keys are coordinate names and values are :py:class:`~xarray.indexes.Index` objects. If None (default), pandas indexes will be created for each dimension coordinate. Passing an empty dictionary will skip this default behavior. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0335ad3bdda..0f245ff464b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -80,7 +80,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None try: from dask.delayed import Delayed except ImportError: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9ec39e74ad1..a6fc0e2ca18 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -171,7 +171,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None # list of attributes of pd.DatetimeIndex that are ndarrays of time info diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f971556b3f7..ef505b55345 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -4,19 +4,29 @@ import itertools import operator from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.indexes import Index +from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection if TYPE_CHECKING: from xarray.core.types import T_Xarray +class ExpectedDict(TypedDict): + shapes: dict[Hashable, int] + coords: set[Hashable] + data_vars: set[Hashable] + indexes: dict[Hashable, Index] + + def unzip(iterable): return zip(*iterable) @@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset): def check_result_variables( - result: DataArray | Dataset, expected: Mapping[str, Any], kind: str + result: DataArray | Dataset, + expected: ExpectedDict, + kind: Literal["coords", "data_vars"], ): if kind == "coords": nice_str = "coordinate" @@ -254,7 +266,7 @@ def _wrapper( args: list, kwargs: dict, arg_is_array: Iterable[bool], - expected: dict, + expected: ExpectedDict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -345,33 +357,45 @@ def _wrapper( for arg in aligned ) + merged_coordinates = merge([arg.coords for arg in aligned]).coords + _, npargs = unzip( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg._indexes) + coordinates: Coordinates if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template._indexes) - preserved_indexes = template_indexes & set(input_indexes) - new_indexes = template_indexes - set(input_indexes) - indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template._indexes[k] for k in new_indexes}) + template_coords = set(template.coords) + preserved_coord_vars = template_coords & set(merged_coordinates) + new_coord_vars = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars] + # preserved_coords contains all coordinates bariables that share a dimension + # with any index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. + preserved_coords = preserved_coords.drop_vars( + tuple(k for k in preserved_coords.variables if k not in template_coords) + ) + + coordinates = merge( + (preserved_coords, template.coords.to_dataset()[new_coord_vars]) + ).coords output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template._indexes) + coordinates = template.coords output_chunks = template.chunksizes if not output_chunks: raise ValueError( @@ -473,6 +497,9 @@ def subset_dataset_to_block( return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + # variable names that depend on the computation. Currently, indexes + # cannot be modified in the mapped function, so we exclude thos + computed_variables = set(template.variables) - set(coordinates.xindexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index @@ -485,19 +512,23 @@ def subset_dataset_to_block( for isxr, arg in zip(is_xarray, npargs) ] - # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected = {} - # input chunk 0 along a dimension maps to output chunk 0 along the same dimension - # even if length of dimension is changed by the applied function - expected["shapes"] = { - k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks - } - expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] - expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes + expected: ExpectedDict = { + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + "indexes": { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in coordinates.xindexes + }, } from_wrapper = (gname,) + chunk_tuple @@ -505,9 +536,8 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name, variable in template.variables.items(): - if name in indexes: - continue + for name in computed_variables: + variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l @@ -543,12 +573,7 @@ def subset_dataset_to_block( }, ) - # TODO: benbovy - flexible indexes: make it work with custom indexes - # this will need to pass both indexes and coords to the Dataset constructor - result = Dataset( - coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, - attrs=template.attrs, - ) + result = Dataset(coords=coordinates, attrs=template.attrs) for index in result._indexes: result[index].attrs = template[index].attrs diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c2a77c97d85..137d6020829 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj): assert_identical(actual, template) +def test_map_blocks_roundtrip_string_index(): + ds = xr.Dataset( + {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]} + ).chunk(label=1) + assert ds.label.dtype == np.dtype("