From c50100f34f4946ae0a69529a6699a4372bd8118b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 6 Oct 2024 17:43:49 +0900 Subject: [PATCH] Reimplement DataTree aggregations They now allow for dimensions that are missing on particular nodes, and use Xarray's standard generate_aggregations machinery, like aggregations for DataArray and Dataset. Fixes https://github.com/pydata/xarray/issues/8949, https://github.com/pydata/xarray/issues/8963 --- xarray/core/_aggregations.py | 1308 ++++++++++++++++++++++++++ xarray/core/dataset.py | 14 +- xarray/core/datatree.py | 100 +- xarray/core/utils.py | 20 + xarray/namedarray/_aggregations.py | 103 +- xarray/tests/test_datatree.py | 41 +- xarray/util/generate_aggregations.py | 44 +- 7 files changed, 1536 insertions(+), 94 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index b557ad44a32..6b1029791ea 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -19,6 +19,1314 @@ flox_available = module_available("flox") +class DataTreeAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. 
+ **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.count() + + Group: / + Dimensions: () + Data variables: + foo int64 8B 5 + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. 
+ These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.all() + + Group: / + Dimensions: () + Data variables: + foo bool 1B False + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.any() + + Group: / + Dimensions: () + Data variables: + foo bool 1B True + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. 
+ **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.max() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 3.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.max(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). 
By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.min() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.min(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``mean`` along some dimension(s). 
+ + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.mean() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.6 + + Use ``skipna`` to control whether NaNs are ignored. 
+ + >>> dt.mean(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.prod() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.prod(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.prod(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. 
+ skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 + labels (time) >> dt.sum() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.sum(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.sum(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.std() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.02 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.std(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.std(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.14 + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). 
By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.var() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.04 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.var(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. 
+ + >>> dt.var(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.3 + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... 
pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.median() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 2.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.median(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. 
+ + Returns + ------- + reduced : DataTree + New DataTree with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumsum() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumsum(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. 
``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 + labels (time) >> dt.cumprod() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumprod(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + class DatasetAggregations: __slots__ = () diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d57a6957553..f3c24be74b6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -109,6 +109,7 @@ OrderedSet, _default, decode_numpy_dict_values, + dim_arg_to_dims_set, drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, @@ -6986,18 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - if dim is None or dim is ...: - dims = set(self.dims) - elif isinstance(dim, str) or not isinstance(dim, Iterable): - dims = {dim} - else: - dims = set(dim) - - missing_dimensions = tuple(d for d in dims if d not in self.dims) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}" - ) + dims = dim_arg_to_dims_set(dim, self.dims) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f8da26c28b5..04f761fc3f9 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -11,9 +11,10 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Self, Union, overload from xarray.core import utils +from xarray.core._aggregations import DataTreeAggregations from xarray.core.alignment import align from xarray.core.common 
def _get_all_dims(self) -> set:
    """Return the union of ``_node_dims`` over every node in ``self.subtree``."""
    all_dims = set()
    for node in self.subtree:
        all_dims.update(node._node_dims)
    return all_dims


def _selective_indexing(
    self,
    func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
    indexers: Mapping[Any, Any],
    missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
    """Apply an indexing function node by node, skipping indexers a node lacks.

    Parameters
    ----------
    func : callable
        Called as ``func(node.dataset, node_indexers)`` for each node in the
        subtree; its result is stored under the node's path and must expose
        ``.coords``.
    indexers : mapping
        Mapping from dimension name to indexer. Keys are first validated
        against the dimensions of the whole subtree via
        ``drop_dims_from_indexers``.
    missing_dims : {"raise", "warn", "ignore"}, default: "raise"
        Forwarded to ``drop_dims_from_indexers`` to decide what happens for
        keys that match no dimension anywhere in the subtree.

    Returns
    -------
    A new tree built with ``type(self).from_dict`` from the per-node results,
    keeping ``self.name``.
    """
    indexers = drop_dims_from_indexers(indexers, self._get_all_dims(), missing_dims)
    result = {}
    for node in self.subtree:
        # Only pass the indexers whose dimension is present on this node.
        node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
        node_result = func(node.dataset, node_indexers)
        for k in node_indexers:
            # Drop a coordinate for an indexed dimension when the node itself
            # did not define it but the indexed result carries it.
            # NOTE(review): presumably these are coordinates inherited from a
            # parent node -- confirm against DataTree's inheritance rules.
            if k not in node.coords and k in node_result.coords:
                del node_result.coords[k]
        result[node.path] = node_result
    return type(self).from_dict(result, name=self.name)  # type: ignore
def dim_arg_to_dims_set(dim: Dims, all_dims: Collection) -> set:
    """Convert a `dim` argument from Dataset/DataTree into a set of dimensions.

    Parameters
    ----------
    dim : str, Iterable of Hashable, "..." or None
        ``None`` or ``...`` selects every entry of ``all_dims``. A string (or
        any non-iterable object) is treated as a single dimension name. Any
        other iterable is treated as a collection of dimension names.
    all_dims : Collection
        The dimension names that are valid for the object being reduced.

    Returns
    -------
    set
        The requested dimension names.

    Raises
    ------
    ValueError
        If any requested name is not in ``all_dims``. Missing names are
        reported in the order the caller supplied them, without duplicates.
    """
    if dim is None or dim is ...:
        return set(all_dims)

    if isinstance(dim, str) or not isinstance(dim, Iterable):
        # TODO: consider dropping `not isinstance(dim, Iterable)`, which is not
        # allowed per the type signature
        requested = (dim,)
    else:
        # Materialise once so one-shot iterables are only consumed a single time.
        requested = tuple(dim)

    # dict.fromkeys() removes duplicates while keeping the caller's order, so
    # the error message does not depend on set iteration order.
    missing_dimensions = tuple(
        d for d in dict.fromkeys(requested) if d not in all_dims
    )
    if missing_dimensions:
        raise ValueError(
            f"Dimensions {missing_dimensions} not found in data dimensions {tuple(all_dims)}"
        )
    return set(requested)
) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.max() Size: 8B - array(3.) + np.float64(3.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.max(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.max, @@ -298,23 +290,20 @@ def min( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.min() Size: 8B - array(0.) + np.float64(0.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.min(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.min, @@ -370,23 +359,20 @@ def mean( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.mean() Size: 8B - array(1.6) + np.float64(1.6) Use ``skipna`` to control whether NaNs are ignored. >>> na.mean(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.mean, @@ -449,23 +435,20 @@ def prod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.prod() Size: 8B - array(0.) + np.float64(0.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.prod(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``min_count`` for finer control over when NaNs are ignored. 
@@ -535,23 +518,20 @@ def sum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.sum() Size: 8B - array(8.) + np.float64(8.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.sum(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``min_count`` for finer control over when NaNs are ignored. @@ -618,29 +598,26 @@ def std( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.std() Size: 8B - array(1.0198039) + np.float64(1.019803902718557) Use ``skipna`` to control whether NaNs are ignored. >>> na.std(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``ddof=1`` for an unbiased estimate. >>> na.std(skipna=True, ddof=1) Size: 8B - array(1.14017543) + np.float64(1.140175425099138) """ return self.reduce( duck_array_ops.std, @@ -701,29 +678,26 @@ def var( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.var() Size: 8B - array(1.04) + np.float64(1.04) Use ``skipna`` to control whether NaNs are ignored. >>> na.var(skipna=False) Size: 8B - array(nan) + np.float64(nan) Specify ``ddof=1`` for an unbiased estimate. >>> na.var(skipna=True, ddof=1) Size: 8B - array(1.3) + np.float64(1.3) """ return self.reduce( duck_array_ops.var, @@ -780,23 +754,20 @@ def median( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... 
np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) >>> na.median() Size: 8B - array(2.) + np.float64(2.0) Use ``skipna`` to control whether NaNs are ignored. >>> na.median(skipna=False) Size: 8B - array(nan) + np.float64(nan) """ return self.reduce( duck_array_ops.median, @@ -857,10 +828,7 @@ def cumsum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -934,10 +902,7 @@ def cumprod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3a3afb0647a..9eb819be48d 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1582,7 +1582,9 @@ def test_dataset_method(self): result = dt.isel(x=1) assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") + +class TestAggregations: + def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1592,7 +1594,6 @@ def test_reduce_method(self): result = dt.any() assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1602,7 +1603,6 @@ def test_nan_reduce_method(self): result = dt.mean() assert_equal(result, expected) - @pytest.mark.xfail(reason="cum methods not implemented yet") def test_cum_method(self): ds = 
def test_dim_argument(self):
    """Check how ``mean`` treats ``dim`` when nodes have different dimensions."""
    tree = DataTree.from_dict(
        {
            "/a": xr.Dataset({"A": ("x", [1, 2])}),
            "/b": xr.Dataset({"B": ("y", [1, 2])}),
        }
    )

    # Omitting ``dim`` and passing ``...`` must both reduce every dimension.
    fully_reduced = DataTree.from_dict(
        {
            "/a": xr.Dataset({"A": 1.5}),
            "/b": xr.Dataset({"B": 1.5}),
        }
    )
    for mean_kwargs in ({}, {"dim": ...}):
        assert_equal(fully_reduced, tree.mean(**mean_kwargs))

    # Reducing over "x" leaves the node that only has "y" unreduced.
    only_x_reduced = DataTree.from_dict(
        {
            "/a": xr.Dataset({"A": 1.5}),
            "/b": xr.Dataset({"B": ("y", [1.0, 2.0])}),
        }
    )
    assert_equal(only_x_reduced, tree.mean("x"))

    # A dimension found on no node at all is an error.
    message = re.escape("Dimensions ('invalid',) not found in data dimensions")
    with pytest.raises(ValueError, match=message):
        tree.mean("invalid")
np.array([1, 2, 3, 0, 2, np.nan])""" + self.np_example_array = """np.array([1, 2, 3, 0, 2, np.nan])""" @dataclass @@ -541,10 +541,27 @@ def generate_code(self, method, has_keep_attrs): ) +DATATREE_OBJECT = DataStructure( + name="DataTree", + create_example=""" + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", {example_array})), + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... )""", + example_var_name="dt", + numeric_only=True, + see_also_modules=("Dataset", "DataArray"), +) DATASET_OBJECT = DataStructure( name="Dataset", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -559,7 +576,8 @@ def generate_code(self, method, has_keep_attrs): DATAARRAY_OBJECT = DataStructure( name="DataArray", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -570,6 +588,15 @@ def generate_code(self, method, has_keep_attrs): numeric_only=False, see_also_modules=("Dataset",), ) +DATATREE_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATATREE_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) DATASET_GENERATOR = GenericAggregationGenerator( cls="", datastructure=DATASET_OBJECT, @@ -634,7 +661,7 @@ def generate_code(self, method, has_keep_attrs): create_example=""" >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x",{example_array}, + ... "x", {example_array} ... 
)""", example_var_name="na", numeric_only=False, @@ -670,6 +697,7 @@ def write_methods(filepath, generators, preamble): write_methods( filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", generators=[ + DATATREE_GENERATOR, DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR,