diff --git a/xarray/__init__.py b/xarray/__init__.py index 0c0d5995f72..10e09bbf734 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -14,6 +14,7 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.coding.frequencies import infer_freq from xarray.conventions import SerializationWarning, decode_cf +from xarray.core import groupers from xarray.core.alignment import align, broadcast from xarray.core.combine import combine_by_coords, combine_nested from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like @@ -55,6 +56,7 @@ # `mypy --strict` running in projects that import xarray. __all__ = ( # Sub-packages + "groupers", "testing", "tutorial", # Top-level functions diff --git a/xarray/core/common.py b/xarray/core/common.py index 7b9a049c662..3ddfa4fd0a5 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1049,7 +1049,7 @@ def _resample( # TODO support non-string indexer after removing the old API. from xarray.core.dataarray import DataArray - from xarray.core.groupby import ResolvedGrouper, TimeResampler + from xarray.core.groupby import Resampler, ResolvedGrouper, TimeResampler from xarray.core.resample import RESAMPLE_DIM # note: the second argument (now 'skipna') use to be 'dim' @@ -1079,15 +1079,19 @@ def _resample( name=RESAMPLE_DIM, ) - grouper = TimeResampler( - freq=freq, - closed=closed, - label=label, - origin=origin, - offset=offset, - loffset=loffset, - base=base, - ) + if isinstance(freq, str): + grouper = TimeResampler( + freq=freq, + closed=closed, + label=label, + origin=origin, + offset=offset, + loffset=loffset, + base=base, + ) + else: + assert isinstance(freq, Resampler) + grouper = freq rgrouper = ResolvedGrouper(grouper, group, self) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7a0bdbc4d4c..d0530d71fc2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6636,9 +6636,10 @@ def interp_calendar( def groupby( self, - group: Hashable | DataArray | 
IndexVariable, + group: Hashable | DataArray | IndexVariable | None = None, squeeze: bool | None = None, restore_coord_dims: bool = False, + **groupers, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6710,7 +6711,19 @@ ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + + if group is not None: + assert not groupers + grouper = UniqueGrouper() + else: + if len(groupers) > 1: + raise ValueError("grouping by multiple variables is not supported yet.") + if not groupers: + raise ValueError + group, grouper = next(iter(groupers.items())) + + rgrouper = ResolvedGrouper(grouper, group, self) + return DataArrayGroupBy( self, (rgrouper,), diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b4c00b66ed8..faaf5ee9be2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10134,9 +10134,10 @@ def interp_calendar( def groupby( self, - group: Hashable | DataArray | IndexVariable, + group: Hashable | DataArray | IndexVariable | None = None, squeeze: bool | None = None, restore_coord_dims: bool = False, + **groupers, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. 
@@ -10186,7 +10187,16 @@ def groupby( ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + if group is not None: + assert not groupers + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + else: + if len(groupers) > 1: + raise ValueError("grouping by multiple variables is not supported yet.") + if not groupers: + raise ValueError + for group, grouper in groupers.items(): + rgrouper = ResolvedGrouper(grouper, group, self) return DatasetGroupBy( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3fbfb74d985..1ef4255a0d9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -254,7 +254,7 @@ def attrs(self) -> dict: def __getitem__(self, key): if isinstance(key, tuple): - key = key[0] + (key,) = key return self.values[key] def to_index(self) -> pd.Index: diff --git a/xarray/core/groupers.py b/xarray/core/groupers.py new file mode 100644 index 00000000000..eb6a23a47ca --- /dev/null +++ b/xarray/core/groupers.py @@ -0,0 +1,224 @@ +import itertools +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd + +from xarray.core.groupby import Grouper, Resampler +from xarray.core.variable import IndexVariable + + +## From toolz +## TODO: move to compat file, add license +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ) + ) + + +def season_to_month_tuple(seasons: Sequence[str]) -> Sequence[Sequence[int]]: + easy = {"D": 12, 
"F": 2, "S": 9, "O": 10, "N": 11} + harder = {"DJF": 1, "FMA": 3, "MAM": 4, "AMJ": 5, "MJJ": 6, "JJA": 7, "JAS": 8} + + if len("".join(seasons)) != 12: + raise ValueError("SeasonGrouper requires exactly 12 months in total.") + + # Slide through with a window of 3. + # A 3 letter string is enough to unambiguously + # assign the right month number of the middle letter + WINDOW = 3 + + perseason = [seasons[-1], *seasons, seasons[0]] + + season_inds = [] + for sprev, sthis, snxt in sliding_window(WINDOW, perseason): + inds = [] + permonth = "".join([sprev[-1], *sthis, snxt[0]]) + for mprev, mthis, mnxt in sliding_window(WINDOW, permonth): + if mthis in easy: + inds.append(easy[mthis]) + else: + concatted = "".join([mprev, mthis, mnxt]) + # print(concatted) + inds.append(harder[concatted]) + + season_inds.append(inds) + return season_inds + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. 
 + + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + season_inds: Sequence[Sequence[int]] = field(init=False) + drop_incomplete: bool = field(default=True) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + + def __repr__(self): + return f"SeasonGrouper over {self.seasons!r}" + + def factorize(self, group): + seasons = self.seasons + season_inds = self.season_inds + + months = group.dt.month + codes_ = np.full(group.shape, -1) + group_indices = [[]] * len(seasons) + + index = np.arange(group.size) + for idx, season in enumerate(season_inds): + mask = months.isin(season) + codes_[mask] = idx + group_indices[idx] = index[mask] + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = group.copy(data=codes_).rename("season") + unique_coord = IndexVariable("season", seasons, attrs=group.attrs) + full_index = unique_coord + return codes, group_indices, unique_coord, full_index + + +@dataclass +class SeasonResampler(Resampler): + """Allows grouping using a custom definition of seasons. 
+ + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + # drop_incomplete: bool = field(default=True) # TODO: + season_inds: Sequence[Sequence[int]] = field(init=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + self.season_tuples = dict(zip(self.seasons, self.season_inds)) + + def factorize(self, group): + assert group.ndim == 1 + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + year[month.isin(after_dec)] -= 1 + + frame = pd.DataFrame( + data={"index": np.arange(group.size), "month": month}, + index=pd.MultiIndex.from_arrays( + [year.data, season_label], names=["year", "season"] + ), + ) + + series = frame["index"] + g = series.groupby(["year", "season"], sort=False) + first_items = g.first() + counts = g.count() + + # these are the seasons that are present + unique_coord = pd.DatetimeIndex( + [ + pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + sbins = first_items.values.astype(int) + group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + group_indices += [slice(sbins[-1], None)] + + # Make sure the first and last timestamps + # are for the correct months,if not we have incomplete seasons + unique_codes = np.arange(len(unique_coord)) + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1))): + stamp_year, 
stamp_season = frame.index[idx] + code = seasons.index(stamp_season) + stamp_month = season_inds[code][idx] + if stamp_month != month[idx].item(): + # we have an incomplete season! + group_indices = group_indices[slicer] + unique_coord = unique_coord[slicer] + if idx == 0: + unique_codes -= 1 + unique_codes[idx] = -1 + + # all years and seasons + complete_index = pd.DatetimeIndex( + # This sorted call is a hack. It's hard to figure out how + # to start the iteration + sorted( + [ + pd.Timestamp(f"{y}-{m}-01") + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + # only keep that included in data + range_ = complete_index.get_indexer(unique_coord[[0, -1]]) + full_index = complete_index[slice(range_[0], range_[-1] + 1)] + # check that there are no "missing" seasons in the middle + # print(full_index, unique_coord) + if not full_index.equals(unique_coord): + raise ValueError("Are there seasons missing in the middle of the dataset?") + + codes = group.copy(data=np.repeat(unique_codes, counts)) + unique_coord_var = IndexVariable(group.name, unique_coord, group.attrs) + + return codes, group_indices, unique_coord_var, full_index diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d927550e424..2f63c518bf5 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,7 +12,11 @@ import xarray as xr from xarray import DataArray, Dataset, Variable -from xarray.core.groupby import _consolidate_slices +from xarray.core.groupby import ( + BinGrouper, + UniqueGrouper, + _consolidate_slices, +) from xarray.tests import ( InaccessibleArray, assert_allclose, @@ -112,8 +116,9 @@ def test_multi_index_groupby_map(dataset) -> None: assert_equal(expected, actual) -def test_reduce_numeric_only(dataset) -> None: - gb = dataset.groupby("x", squeeze=False) +@pytest.mark.parametrize("grouper", [dict(group="x"), dict(x=UniqueGrouper())]) +def 
test_reduce_numeric_only(dataset, grouper) -> None: + gb = dataset.groupby(**grouper, squeeze=False) with xr.set_options(use_flox=False): expected = gb.sum() with xr.set_options(use_flox=True): @@ -830,11 +835,12 @@ def test_groupby_dataset_reduce() -> None: expected = data.mean("y") expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) - actual = data.groupby("x").mean(...) - assert_allclose(expected, actual) + for gb in [data.groupby("x"), data.groupby(x=UniqueGrouper())]: + actual = gb.mean(...) + assert_allclose(expected, actual) - actual = data.groupby("x").mean("y") - assert_allclose(expected, actual) + actual = gb.mean("y") + assert_allclose(expected, actual) letters = data["letters"] expected = Dataset( @@ -844,8 +850,9 @@ def test_groupby_dataset_reduce() -> None: "yonly": data["yonly"].groupby(letters).mean(), } ) - actual = data.groupby("letters").mean(...) - assert_allclose(expected, actual) + for gb in [data.groupby("letters"), data.groupby(letters=UniqueGrouper())]: + actual = gb.mean(...) 
+ assert_allclose(expected, actual) @pytest.mark.parametrize("squeeze", [True, False]) @@ -975,6 +982,14 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: ) assert_identical(expected, actual) + with xr.set_options(use_flox=use_flox): + actual = da.groupby( + x=BinGrouper( + bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False) + ), + ).mean() + assert_identical(expected, actual) + @pytest.mark.parametrize("indexed_coord", [True, False]) def test_groupby_bins_math(indexed_coord) -> None: @@ -983,11 +998,17 @@ def test_groupby_bins_math(indexed_coord) -> None: if indexed_coord: da["x"] = np.arange(N) da["y"] = np.arange(N) - g = da.groupby_bins("x", np.arange(0, N + 1, 3)) - mean = g.mean() - expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1])) - actual = g - mean - assert_identical(expected, actual) + + for g in [ + da.groupby_bins("x", np.arange(0, N + 1, 3)), + da.groupby(x=BinGrouper(bins=np.arange(0, N + 1, 3))), + ]: + mean = g.mean() + expected = da.isel(x=slice(1, None)) - mean.isel( + x_bins=("x", [0, 0, 0, 1, 1, 1]) + ) + actual = g - mean + assert_identical(expected, actual) def test_groupby_math_nD_group() -> None: @@ -2520,3 +2541,20 @@ def test_default_flox_method(): assert kwargs["method"] == "cohorts" else: assert "method" not in kwargs + + +def test_season_to_month_tuple(): + from xarray.core.groupers import season_to_month_tuple + + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == [ + [1, 2], + [3, 4, 5], + [6, 7, 8, 9], + [10, 11, 12], + ] + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == [ + [12, 1, 2, 3], + [4, 5], + [6, 7, 8, 9], + [10, 11], + ]