Merge branch 'main' into map-blocks-indexes-fix
* main:
  Adapt map_blocks to use new Coordinates API (pydata#8560)
  add xeofs to ecosystem.rst (pydata#8561)
  Offer a fixture for unifying DataArray & Dataset tests (pydata#8533)
  Generalize cumulative reduction (scan) to non-dask types (pydata#8019)
dcherian committed Dec 20, 2023
2 parents 84ba745 + b444438 commit bf06e12
Showing 7 changed files with 170 additions and 68 deletions.
1 change: 1 addition & 0 deletions doc/ecosystem.rst
@@ -78,6 +78,7 @@ Extend xarray capabilities
- `xarray-dataclasses <https://github.com/astropenguin/xarray-dataclasses>`_: xarray extension for typed DataArray and Dataset creation.
- `xarray_einstats <https://xarray-einstats.readthedocs.io>`_: Statistics, linear algebra and einops for xarray
- `xarray_extras <https://github.com/crusaderky/xarray_extras>`_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
- `xeofs <https://github.com/nicrie/xeofs>`_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data.
- `xpublish <https://xpublish.readthedocs.io/>`_: Publish Xarray Datasets via a Zarr compatible REST API.
- `xrft <https://github.com/rabernat/xrft>`_: Fourier transforms for xarray data.
- `xr-scipy <https://xr-scipy.readthedocs.io>`_: A lightweight scipy wrapper for xarray.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -589,6 +589,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
  raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntrypoint`,
  potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
  use non-dask chunked array types.
  (:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
  `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
  By `Ian Carroll <https://github.com/itcarroll>`_.
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
@@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> DaskArray:
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
**kwargs,
)

def apply_gufunc(
self,
func: Callable,
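For context, a minimal sketch of what this delegation computes (the array and chunking below are illustrative, not part of the commit): ``cumreduction`` applies the cumulative function blockwise, then stitches chunks together with the binary operator, seeded by the identity.

# Illustrative sketch of the `scan` semantics for a cumulative sum.
import numpy as np
import dask.array as da
from dask.array.reductions import cumreduction

x = da.from_array(np.arange(1, 7), chunks=2)  # chunks: [1 2], [3 4], [5 6]
# `np.cumsum` runs within each chunk; `np.add` carries the running total
# across chunk boundaries, starting from the identity 0.
result = cumreduction(np.cumsum, np.add, 0, x, axis=0, dtype=x.dtype)
print(result.compute())  # [ 1  3  6 10 15 21]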
64 changes: 38 additions & 26 deletions xarray/core/parallel.py
@@ -4,14 +4,15 @@
import itertools
import operator
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np

from xarray.core.alignment import align
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index
from xarray.core.merge import merge
from xarray.core.pycompat import is_dask_collection
from xarray.core.variable import Variable
@@ -20,6 +21,13 @@
from xarray.core.types import T_Xarray


class ExpectedDict(TypedDict):
shapes: dict[Hashable, int]
coords: set[Hashable]
data_vars: set[Hashable]
indexes: dict[Hashable, Index]


def unzip(iterable):
return zip(*iterable)

@@ -34,7 +42,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):


def check_result_variables(
result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
result: DataArray | Dataset,
expected: ExpectedDict,
kind: Literal["coords", "data_vars"],
):
if kind == "coords":
nice_str = "coordinate"
@@ -326,7 +336,7 @@ def _wrapper(
args: list,
kwargs: dict,
arg_is_array: Iterable[bool],
expected: dict,
expected: ExpectedDict,
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -429,6 +439,8 @@ def _wrapper(

merged_coordinates = merge([arg.coords for arg in aligned]).coords

_, npargs = unzip(
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
)
@@ -444,11 +456,11 @@ def _wrapper(
# infer template by providing zero-shaped arrays
template = infer_template(func, aligned[0], *args, **kwargs)
template_coords = set(template.coords)
preserved_coord_names = template_coords & set(merged_coordinates)
new_indexes = set(template.xindexes) - set(merged_coordinates)
preserved_coord_vars = template_coords & set(merged_coordinates)
new_coord_vars = template_coords - set(merged_coordinates)

preserved_coords = merged_coordinates.to_dataset()[preserved_coord_names]
# preserved_coords contains all coordinate variables that share a dimension
preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
# preserved_coords contains all coordinate variables that share a dimension
# with any index variable in preserved_indexes
# Drop any unneeded vars in a second pass, this is required for e.g.
# if the mapped function were to drop a non-dimension coordinate variable.
@@ -457,7 +469,7 @@ def _wrapper(
)

coordinates = merge(
(preserved_coords, template.coords.to_dataset()[new_indexes])
(preserved_coords, template.coords.to_dataset()[new_coord_vars])
).coords
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
@@ -520,7 +532,7 @@ def _wrapper(
dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
}

include_variables = set(template.variables) - set(coordinates.indexes)
computed_variables = set(template.variables) - set(coordinates.indexes)
# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
@@ -533,31 +545,31 @@ def _wrapper(
for isxr, arg in zip(is_xarray, npargs)
]

# expected["shapes", "coords", "data_vars", "indexes"] are used to
# raise nice error messages in _wrapper
expected: dict[Hashable, dict] = {}
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
expected["shapes"] = {
k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
}
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
# Minimize duplication due to broadcasting by only including any new or modified indexes
# Others can be inferred by inputs to wrapper (GH8412)
expected["indexes"] = {
name: coordinates.xindexes[name][
_get_chunk_slicer(name, chunk_index, output_chunk_bounds)
]
for name in (new_indexes | modified_indexes)
expected: ExpectedDict = {
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
"shapes": {
k: output_chunks[k][v]
for k, v in chunk_index.items()
if k in output_chunks
},
"data_vars": set(template.data_vars.keys()),
"coords": set(template.coords.keys()),
"indexes": {
dim: coordinates.xindexes[dim][
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
]
for dim in (new_indexes | modified_indexes)
},
}

from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)

# mapping from variable name to dask graph key
var_key_map: dict[Hashable, str] = {}
for name in include_variables:
for name in computed_variables:
variable = template.variables[name]
gname_l = f"{name}-{gname}"
var_key_map[name] = gname_l
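A minimal usage sketch of the ``map_blocks`` path this refactor touches (the dataset and function are illustrative): the template is inferred from zero-shaped inputs, and the per-block ``ExpectedDict`` metadata (shapes, coords, data_vars, indexes) is what ``_wrapper`` validates the function's output against.

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"x": np.arange(4)}).chunk({"x": 2})
# Each block arrives in `_wrapper` as a Dataset; the result is checked
# against the expected shapes, coords, data_vars, and (new) indexes.
result = xr.map_blocks(lambda block: block + 1, ds)
print(result.compute())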
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
@@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.
Used in ``ffill`` and ``bfill`` in xarray.
Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: dask Array
axis: int, optional
dtype: dtype
Returns
-------
Chunked array
See also
--------
dask.array.cumreduction
"""
raise NotImplementedError()

@abstractmethod
def apply_gufunc(
self,
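To illustrate the generalization, a hedged sketch (``EagerArrayManager`` is hypothetical, not part of xarray): a non-dask chunk manager only needs to override ``scan`` for ``ffill``/``bfill`` to route through it; a real implementation must also provide the other abstract methods of the entrypoint.

from xarray.core.parallelcompat import ChunkManagerEntrypoint

class EagerArrayManager(ChunkManagerEntrypoint):  # hypothetical, fragment only
    def scan(self, func, binop, ident, arr, axis=None, dtype=None, **kwargs):
        # With a single in-memory "chunk" there is nothing to combine across
        # blocks, so the blockwise cumulative function is the full result.
        return func(arr, axis=axis, dtype=dtype)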
43 changes: 43 additions & 0 deletions xarray/tests/conftest.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
@@ -77,3 +79,44 @@ def da(request, backend):
return da
else:
raise ValueError


@pytest.fixture(params=[Dataset, DataArray])
def type(request):
return request.param


@pytest.fixture(params=[1])
def d(request, backend, type) -> DataArray | Dataset:
"""
For tests which can test either a DataArray or a Dataset.
"""
result: DataArray | Dataset
if request.param == 1:
ds = Dataset(
dict(
a=(["x", "z"], np.arange(24).reshape(2, 12)),
b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)),
),
dict(
x=("x", np.linspace(0, 1.0, 2)),
y=range(3),
z=("z", pd.date_range("2000-01-01", periods=12)),
w=("x", ["a", "b"]),
),
)
if type == DataArray:
result = ds["a"].assign_coords(w=ds.coords["w"])
elif type == Dataset:
result = ds
else:
raise ValueError
else:
raise ValueError

if backend == "dask":
return result.chunk()
elif backend == "numpy":
return result
else:
raise ValueError
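A usage sketch for the new fixture (the test below is hypothetical): any test that accepts ``d`` is run for Dataset and DataArray, each with numpy and dask backends.

def test_dims(d) -> None:  # hypothetical consumer of the `d` fixture
    # `d` is a Dataset or DataArray, numpy- or dask-backed (4 combinations).
    assert {"x", "z"} <= set(d.dims)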
67 changes: 25 additions & 42 deletions xarray/tests/test_rolling.py
@@ -36,6 +36,31 @@ def compute_backend(request):
yield request.param


@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("min_periods", [1, 10])
def test_cumulative(d, func, min_periods) -> None:
# One dim
result = getattr(d.cumulative("z", min_periods=min_periods), func)()
expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)()
assert_identical(result, expected)

# Multiple dim
result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)()
expected = getattr(
d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)


def test_cumulative_vs_cum(d) -> None:
result = d.cumulative("z").sum()
expected = d.cumsum("z")
# cumsum drops the coord of the dimension; cumulative doesn't
expected = expected.assign_coords(z=result["z"])
assert_identical(result, expected)
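For reference, a small usage sketch of the ``cumulative`` API these tests exercise (the data are illustrative): ``cumulative`` is equivalent to a full-window ``rolling``, and unlike ``cumsum`` it keeps the dimension coordinate.

import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0], dims="z", coords={"z": [10, 20, 30]})
print(da.cumulative("z").sum())  # same values as da.rolling(z=3, min_periods=1).sum()
print(da.cumsum("z"))            # same values, but the "z" coordinate is dropped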


class TestDataArrayRolling:
@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("center", [True, False])
@@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
):
da.rolling_exp(time=10, keep_attrs=True)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("min_periods", [1, 20])
def test_cumulative(self, da, func, min_periods) -> None:
# One dim
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)

def test_cumulative_vs_cum(self, da) -> None:
result = da.cumulative("time").sum()
expected = da.cumsum("time")
assert_identical(result, expected)


class TestDatasetRolling:
@pytest.mark.parametrize(
@@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
assert_allclose(actual, expected)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("ds", (2,), indirect=True)
@pytest.mark.parametrize("min_periods", [1, 10])
def test_cumulative(self, ds, func, min_periods) -> None:
# One dim
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)


@requires_numbagg
class TestDatasetRollingExp: