def test_spatial_average_with_min_weight(self):
    """Check the regional average of ``ts`` when ``min_weight=1.0``.

    A slice of ``ts`` is set to NaN before averaging; the first entry of
    the averaged ``ts`` is then expected to be NaN while the remaining two
    entries are expected to be 1.0.
    """
    dataset = self.ds.copy()
    # Introduce missing values: index 0 of the leading dimension, every
    # index of the second dimension, index 2 of the third dimension.
    dataset["ts"][0, :, 2] = np.nan

    region = {"lat_bounds": (-5.0, 5), "lon_bounds": (-170, -120.1)}
    actual = dataset.spatial.average(
        "ts", axis=["X", "Y"], min_weight=1.0, **region
    )

    # Build the expected dataset from the fixture and swap in the
    # expected time series for ``ts``.
    expected = self.ds.copy()
    expected_ts = np.array([np.nan, 1.0, 1.0])
    expected["ts"] = xr.DataArray(
        data=expected_ts,
        coords={"time": expected.time},
        dims="time",
    )

    xr.testing.assert_allclose(actual, expected)
class TestValidateMinWeight:
    """Tests for the ``_validate_min_weight`` helper."""

    def test_pass_None_returns_0(self):
        # Passing ``None`` must produce a value equal to zero.
        assert _validate_min_weight(None) == 0

    def test_returns_error_if_less_than_0(self):
        # A negative value must raise ``ValueError``.
        with pytest.raises(ValueError):
            _validate_min_weight(-1)

    def test_returns_error_if_greater_than_1(self):
        # A value above one must raise ``ValueError``.
        with pytest.raises(ValueError):
            _validate_min_weight(1.1)

    def test_returns_valid_min_weight(self):
        # An in-range value must come back equal to the input.
        assert _validate_min_weight(1) == 1
= ("X", "Y"), weights: Union[Literal["generate"], xr.DataArray] = "generate", keep_weights: bool = False, - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + min_weight: float | None = None, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -114,17 +119,21 @@ def average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - lat_bounds : Optional[RegionAxisBounds], optional + lat_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional latitude lower and upper boundaries. This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound cannot be larger than the upper bound, by default None. - lon_bounds : Optional[RegionAxisBounds], optional + lon_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional longitude lower and upper boundaries. This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + min_weight : optional, float + Fraction of data coverage (i.e, weight) needed to return a + spatial average value. Value must range from 0 to 1, by default None + (equivalent to ``min_weight=0.0``). 
Returns ------- @@ -184,7 +193,9 @@ def average( """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var) + self._validate_axis_arg(axis) + min_weight = _validate_min_weight(min_weight) if isinstance(weights, str) and weights == "generate": if lat_bounds is not None: @@ -196,7 +207,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis) + ds[dv.name] = self._averager(dv, axis, min_weight=min_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -206,9 +217,9 @@ def average( def get_weights( self, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, - data_var: Optional[str] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + data_var: str | None = None, ) -> xr.DataArray: """ Get area weights for specified axis keys and an optional target domain. @@ -227,13 +238,13 @@ def get_weights( ---------- axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. - lat_bounds : Optional[RegionAxisBounds] + lat_bounds : RegionAxisBounds | None Tuple of latitude boundaries for regional selection, by default None. - lon_bounds : Optional[RegionAxisBounds] + lon_bounds : RegionAxisBounds | None Tuple of longitude boundaries for regional selection, by default None. - data_var: Optional[str] + data_var: str | None The key of the data variable, by default None. Pass this argument when the dataset has more than one bounds per axis (e.g., "lon" and "zlon_bnds" for the "X" axis), or you want weights for a @@ -377,7 +388,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds): ) def _get_longitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the longitude axis. 
@@ -404,7 +415,7 @@ def _get_longitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the longitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for longitude regional selection. Returns @@ -418,7 +429,7 @@ def _get_longitude_weights( If the there are multiple instances in which the domain_bounds[:, 0] > domain_bounds[:, 1] """ - p_meridian_index: Optional[np.ndarray] = None + p_meridian_index: np.ndarray | None = None d_bounds = domain_bounds.copy() pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0] @@ -450,7 +461,7 @@ def _get_longitude_weights( return weights def _get_latitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the latitude axis. @@ -462,7 +473,7 @@ def _get_latitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the latitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for latitude regional selection. Returns @@ -702,7 +713,10 @@ def _validate_weights( ) def _averager( - self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] + self, + data_var: xr.DataArray, + axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], + min_weight: float, ): """Perform a weighted average of a data variable. @@ -721,6 +735,9 @@ def _averager( Data variable inside a Dataset. axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. + min_weight : float + Fraction of data coverage (i.e, weight) needed to return a + spatial average value. Value must range from 0 to 1. 
Returns ------- @@ -734,11 +751,68 @@ def _averager( """ weights = self._weights.fillna(0) - dim = [] + # TODO: This conditional might not be needed because Xarray will + # automatically broadcast the weights to the data variable for + # operations such as .mean() and .where(). + if min_weight > 0.0: + weights, data_var = xr.broadcast(weights, data_var) + + dim: List[str] = [] for key in axis: - dim.append(get_dim_keys(data_var, key)) + dim.append(get_dim_keys(data_var, key)) # type: ignore with xr.set_options(keep_attrs=True): - weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) + dv_mean = data_var.cf.weighted(weights).mean(dim=dim) + + if min_weight > 0.0: + dv_mean = self._mask_var_with_weight_threshold( + dv_mean, dim, weights, min_weight + ) + + return dv_mean + + def _mask_var_with_weight_threshold( + self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float + ) -> xr.DataArray: + """Mask values that do not meet the minimum weight threshold with np.nan. + + This function is useful for cases where the weighting of data might be + skewed based on the availability of data. For example, if a portion of + cells in a region has significantly more missing data than other other + regions, it can result in inaccurate calculations of spatial averaging. + Masking values that do not meet the minimum weight threshold ensures + more accurate calculations. + + Parameters + ---------- + dv : xr.DataArray + The weighted variable. + dim: List[str]: + List of axis dimensions to average over. + weights : xr.DataArray + A DataArray containing either the regional weights used for weighted + averaging. ``weights`` must include the same axis dimensions and + dimensional sizes as the data variable. + min_weight : float + Fraction of data coverage (i.e, weight) needed to return a + spatial average value. Value must range from 0 to 1. + + Returns + ------- + xr.DataArray + The variable with the minimum weight threshold applied. 
+ """ + # Sum all weights, including zero for missing values. + weight_sum_all = weights.sum(dim=dim) + + masked_weights = _get_masked_weights(dv, weights) + weight_sum_masked = masked_weights.sum(dim=dim) + + # Get fraction of the available weight. + frac = weight_sum_masked / weight_sum_all + + # Nan out values that don't meet specified weight threshold. + dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True) + dv_new.name = dv.name - return weighted_mean + return dv_new diff --git a/xcdat/utils.py b/xcdat/utils.py index 83596561..a2f674fa 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import json from typing import Dict, List, Optional, Union @@ -132,3 +134,61 @@ def _if_multidim_dask_array_then_load( return obj.load() return None + + +def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray: + """Get weights with missing data (`np.nan`) receiving no weight (zero). + + Parameters + ---------- + dv : xr.DataArray + The variable. + weights : xr.DataArray + A DataArray containing either the regional or temporal weights used for + weighted averaging. ``weights`` must include the same axis dimensions + and dimensional sizes as the data variable. + + Returns + ------- + xr.DataArray + The masked weights. + """ + masked_weights = xr.where(dv.copy().isnull(), 0.0, weights) + + return masked_weights + + +def _validate_min_weight(min_weight: float | None) -> float: + """Validate the ``min_weight`` value. + + Parameters + ---------- + min_weight : float | None + Fraction of data coverage (i..e, weight) needed to return a + spatial average value. Value must range from 0 to 1. + + Returns + ------- + float + The required weight percentage. + + Raises + ------ + ValueError + If the `min_weight` argument is less than 0. + ValueError + If the `min_weight` argument is greater than 1. 
+ """ + if min_weight is None: + return 0.0 + elif min_weight < 0.0: + raise ValueError( + "min_weight argument is less than 0. " "min_weight must be between 0 and 1." + ) + elif min_weight > 1.0: + raise ValueError( + "min_weight argument is greater than 1. " + "min_weight must be between 0 and 1." + ) + + return min_weight