From 01f0b6ee46cd62954e5b5d1a11546104147aaa7b Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Fri, 28 Jun 2024 12:54:47 -0700 Subject: [PATCH] initial attempt at #531 (for spatial averaging) --- tests/test_spatial.py | 34 +++++++++++++++++++++++++++ xcdat/spatial.py | 53 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index fe0361cd..4f27b226 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims( with pytest.raises(ValueError): self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights) + def test_raises_error_if_required_weight_not_between_zero_and_one( + self, + ): + # ensure error if required_weight less than zero + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01) + + # ensure error if required_weight greater than 1 + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01) + def test_spatial_average_for_lat_region_and_keep_weights(self): ds = self.ds.copy() @@ -254,6 +265,29 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) + def test_spatial_average_with_required_weight(self): + ds = self.ds.copy() + + # insert a nan + ds["ts"][0, :, 2] = np.nan + + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + lat_bounds=(-5.0, 5), + lon_bounds=(-170, -120.1), + required_weight=1.0, + ) + + expected = self.ds.copy() + expected["ts"] = xr.DataArray( + data=np.array([np.nan, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): ds = self.ds.copy() diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 2c50595a..ad7bb8de 100644 --- 
a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -73,6 +73,7 @@ def average( keep_weights: bool = False, lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + required_weight: Optional[float] = 0.0, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -122,6 +123,9 @@ def average( ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + required_weight : optional, float + Fraction of data coverage (i.e., weight) needed to return a + spatial average value. Value must range from 0 to 1. Returns ------- @@ -193,7 +197,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis) + ds[dv.name] = self._averager(dv, axis, required_weight=required_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -698,7 +702,12 @@ def _validate_weights(self, data_var: xr.DataArray, axis: List[SpatialAxis]): f"and the data variable {dim_sizes} are misaligned." ) - def _averager(self, data_var: xr.DataArray, axis: List[SpatialAxis]): + def _averager( + self, + data_var: xr.DataArray, + axis: List[SpatialAxis], + required_weight: Optional[float] = 0.0, + ): """Perform a weighted average of a data variable. This method assumes all specified keys in ``axis`` exists in the data @@ -716,6 +725,9 @@ def _averager(self, data_var: xr.DataArray, axis: List[SpatialAxis]): Data variable inside a Dataset. axis : List[SpatialAxis] List of axis dimensions to average over. + required_weight : optional, float + Fraction of data coverage (i.e., weight) needed to return a + spatial average value. Value must range from 0 to 1. 
Returns ------- @@ -729,11 +741,48 @@ def _averager(self, data_var: xr.DataArray, axis: List[SpatialAxis]): """ weights = self._weights.fillna(0) + # ensure required weight is between 0 and 1 + if required_weight is None: + required_weight = 0.0 + + if required_weight < 0.0: + raise ValueError( + "required_weight argument is less than zero. " + "required_weight must be between 0 and 1." + ) + + if required_weight > 1.0: + raise ValueError( + "required_weight argument is greater than one. " + "required_weight must be between 0 and 1." + ) + + # need weights to match data_var dimensionality + if required_weight > 0.0: + weights, data_var = xr.broadcast(weights, data_var) + + # get averaging dimensions dim = [] for key in axis: dim.append(get_dim_keys(data_var, key)) + # compute weighted mean with xr.set_options(keep_attrs=True): weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) + # if weight thresholds applied, calculate fraction of data availability + # replace values that do not meet minimum weight with nan + if required_weight > 0.0: + # sum all weights (assuming no missing values exist) + # (weights were broadcast to match data_var above) + weight_sum_all = weights.sum(dim=dim) # type: ignore + # zero out cells with missing values in data_var + weights = xr.where(~np.isnan(data_var), weights, 0) + # sum all weights (including zero for missing values) + weight_sum_masked = weights.sum(dim=dim) # type: ignore + # get fraction of weight available + frac = weight_sum_masked / weight_sum_all + # nan out values below the threshold (.where keeps attrs and name) + weighted_mean = weighted_mean.where(frac >= required_weight) + return weighted_mean