-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PR]: Add Z axis support for spatial averaging #606
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
import xarray as xr | ||
|
||
from tests import requires_dask | ||
from tests.fixtures import generate_dataset | ||
from tests.fixtures import generate_dataset, generate_lev_dataset | ||
from xcdat.spatial import SpatialAccessor | ||
|
||
|
||
|
@@ -45,6 +45,35 @@ def test_raises_error_if_data_var_not_in_dataset(self): | |
with pytest.raises(KeyError): | ||
self.ds.spatial.average("not_a_data_var", axis=["Y", "incorrect_axis"]) | ||
|
||
def test_vertical_average_with_weights(self):
    """Check the vertical ("Z") average and the weights kept alongside it."""
    # Reduce the level dataset to a single column so that "lev" is the only
    # remaining dimension, then overwrite the data with distinct values.
    column = generate_lev_dataset()
    column = column.isel(time=[0], lat=[0], lon=[0]).squeeze()
    salinity = column["so"]
    salinity[:] = np.array([1, 2, 3, 4])
    column["so"] = salinity

    result = column.spatial.average(
        "so", lev_bounds=(4000, 10000), axis=["Z"], keep_weights=True
    )

    # Part 1: the weighted mean.
    # With weights [2000, 2000, 1000, 0] and values [1, 2, 3, 4]:
    # (2000 * 1 + 2000 * 2 + 1000 * 3 + 0 * 4) / 5000 == 1.8
    expected_mean = xr.DataArray(
        data=np.array(1.8),
        coords={"time": column.time, "lat": column.lat, "lon": column.lon},
    )
    xr.testing.assert_allclose(result["so"], expected_mean)

    # Part 2: the weights stored in the result because keep_weights=True.
    expected_weights = xr.DataArray(
        data=np.array([2000, 2000, 1000, 0.0]),
        coords={
            "time": column.time,
            "lev": column.lev,
            "lat": column.lat,
            "lon": column.lon,
        },
        dims=["lev"],
        attrs={"xcdat_bounds": True},
    )
    xr.testing.assert_allclose(result["lev_wts"], expected_weights)
Comment on lines
+65
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are two-in-one tests permitted? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah there's no problem with that here. As long as the tests are relatively easy to maintain. |
||
|
||
def test_raises_error_if_axis_list_contains_unsupported_axis(self): | ||
with pytest.raises(ValueError): | ||
self.ds.spatial.average("ts", axis=["Y", "incorrect_axis"]) | ||
|
@@ -313,6 +342,23 @@ def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self) | |
with pytest.raises(TypeError): | ||
ds.spatial.get_weights(axis=["Y", "X"]) | ||
|
||
def test_vertical_weighting(self):
    """Check `_get_vertical_weights` for a region given as level bounds."""
    # Dataset that carries a vertical coordinate ("lev") and its bounds.
    lev_ds = generate_lev_dataset()
    region = np.array([4000, 10000])

    result = lev_ds.spatial._get_vertical_weights(
        domain_bounds=lev_ds.lev_bnds, region_bounds=region
    )

    # Expected weights along "lev" for the 4000-10000 region.
    lev_attrs = {
        "units": "m",
        "positive": "down",
        "axis": "Z",
        "bounds": "lev_bnds",
    }
    expected = xr.DataArray(
        data=np.array([2000, 2000, 1000, 0.0]),
        coords={"lev": lev_ds.lev},
        dims=["lev"],
        attrs=lev_attrs,
    )
    xr.testing.assert_allclose(result, expected)
|
||
def test_data_var_weights_for_region_in_lat_and_lon_domains(self): | ||
ds = self.ds.copy() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
#: Type alias for a dictionary of axis keys mapped to their bounds. | ||
AxisWeights = Dict[Hashable, xr.DataArray] | ||
#: Type alias for supported spatial axis keys. | ||
SpatialAxis = Literal["X", "Y"] | ||
SpatialAxis = Literal["X", "Y", "Z"] | ||
SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis) | ||
#: Type alias for a tuple of floats/ints for the regional selection bounds. | ||
RegionAxisBounds = Tuple[float, float] | ||
|
@@ -73,10 +73,12 @@ def average( | |
keep_weights: bool = False, | ||
lat_bounds: Optional[RegionAxisBounds] = None, | ||
lon_bounds: Optional[RegionAxisBounds] = None, | ||
lev_bounds: Optional[RegionAxisBounds] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me |
||
) -> xr.Dataset: | ||
""" | ||
Calculates the spatial average for a rectilinear grid over an optionally | ||
specified regional domain. | ||
Calculates the weighted spatial and/or vertical average for a | ||
rectilinear grid over an optionally specified regional and/or vertical | ||
domain. | ||
tomvothecoder marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Operations include: | ||
|
||
|
@@ -101,7 +103,7 @@ def average( | |
average. | ||
axis : List[SpatialAxis] | ||
List of axis dimensions to average over, by default ["X", "Y"]. | ||
Valid axis keys include "X" and "Y". | ||
Valid axis keys include "X", "Y", and "Z". | ||
weights : {"generate", xr.DataArray}, optional | ||
If "generate", then weights are generated. Otherwise, pass a | ||
DataArray containing the regional weights used for weighted | ||
|
@@ -122,6 +124,10 @@ def average( | |
ignored if ``weights`` are supplied. The lower bound can be larger | ||
than the upper bound (e.g., across the prime meridian, dateline), by | ||
default None. | ||
lev_bounds : Optional[RegionAxisBounds], optional | ||
A tuple of floats/ints for the regional lower and upper level | ||
boundaries. This arg is used when calculating axis weights, but is | ||
ignored if ``weights`` are supplied. The default is None. | ||
|
||
Returns | ||
------- | ||
|
@@ -143,11 +149,15 @@ def average( | |
>>> | ||
>>> ds.lon.attrs["axis"] | ||
>>> X | ||
>>> | ||
>>> ds.level.attrs["axis"] | ||
>>> Z | ||
|
||
Set the 'axis' attribute for the required coordinates if it isn't: | ||
|
||
>>> ds.lat.attrs["axis"] = "Y" | ||
>>> ds.lon.attrs["axis"] = "X" | ||
>>> ds.level.attrs["axis"] = "Z" | ||
|
||
Call spatial averaging method: | ||
|
||
|
@@ -167,6 +177,10 @@ def average( | |
|
||
>>> ts_zonal = ds.spatial.average("tas", axis=["X"])["tas"] | ||
|
||
Get the vertical average (between 100 and 1000 hPa): | ||
|
||
>>> ta_column = ds.spatial.average("ta", axis=["Z"], lev_bounds=(100, 1000))["ta"] | ||
|
||
Using custom weights for averaging: | ||
|
||
>>> # The shape of the weights must align with the data var. | ||
|
@@ -178,6 +192,13 @@ def average( | |
>>> | ||
>>> ts_global = ds.spatial.average("tas", axis=["X", "Y"], | ||
>>> weights=weights)["tas"] | ||
|
||
Notes | ||
----- | ||
Weights are generally computed as the difference between the bounds. If | ||
sub-selecting a region, the units must match the axis units (e.g., | ||
Pa/hPa or m/km). The sub-selected region must be in numerical order | ||
(e.g., (100, 1000) and not (1000, 100)). | ||
""" | ||
ds = self._dataset.copy() | ||
dv = _get_data_var(ds, data_var) | ||
|
@@ -188,7 +209,11 @@ def average( | |
self._validate_region_bounds("Y", lat_bounds) | ||
if lon_bounds is not None: | ||
self._validate_region_bounds("X", lon_bounds) | ||
self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var) | ||
if lev_bounds is not None: | ||
self._validate_region_bounds("Z", lev_bounds) | ||
self._weights = self.get_weights( | ||
axis, lat_bounds, lon_bounds, lev_bounds, data_var | ||
) | ||
elif isinstance(weights, xr.DataArray): | ||
self._weights = weights | ||
|
||
|
@@ -205,6 +230,7 @@ def get_weights( | |
axis: List[SpatialAxis], | ||
lat_bounds: Optional[RegionAxisBounds] = None, | ||
lon_bounds: Optional[RegionAxisBounds] = None, | ||
lev_bounds: Optional[RegionAxisBounds] = None, | ||
data_var: Optional[str] = None, | ||
) -> xr.DataArray: | ||
""" | ||
|
@@ -216,9 +242,9 @@ def get_weights( | |
weights are then combined to form a DataArray of weights that can be | ||
used to perform a weighted (spatial) average. | ||
|
||
If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells | ||
outside this selected regional domain are given zero weight. Grid cells | ||
that are partially in this domain are given partial weight. | ||
If ``lat_bounds``, ``lon_bounds``, or ``lev_bounds`` are supplied, then | ||
grid cells outside this selected regional domain are given zero weight. | ||
Grid cells that are partially in this domain are given partial weight. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -230,6 +256,9 @@ def get_weights( | |
lon_bounds : Optional[RegionAxisBounds] | ||
Tuple of longitude boundaries for regional selection, by default | ||
None. | ||
lev_bounds : Optional[RegionAxisBounds] | ||
Tuple of level boundaries for vertical selection, by default | ||
None. | ||
data_var: Optional[str] | ||
The key of the data variable, by default None. Pass this argument | ||
when the dataset has more than one bounds per axis (e.g., "lon" | ||
|
@@ -246,9 +275,7 @@ def get_weights( | |
Notes | ||
----- | ||
This method was developed for rectilinear grids only. ``get_weights()`` | ||
recognizes and operate on latitude and longitude, but could be extended | ||
to work with other standard geophysical dimensions (e.g., time, depth, | ||
and pressure). | ||
recognizes and operates on latitude, longitude, and vertical levels. | ||
""" | ||
Bounds = TypedDict( | ||
"Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} | ||
|
@@ -267,6 +294,12 @@ def get_weights( | |
if lat_bounds is not None | ||
else None, | ||
}, | ||
"Z": { | ||
"weights_method": self._get_vertical_weights, | ||
"region": np.array(lev_bounds, dtype="float") | ||
if lev_bounds is not None | ||
else None, | ||
}, | ||
} | ||
|
||
axis_weights: AxisWeights = {} | ||
|
@@ -476,6 +509,32 @@ def _get_latitude_weights( | |
weights = self._calculate_weights(d_bounds) | ||
return weights | ||
|
||
def _get_vertical_weights(
    self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
) -> xr.DataArray:
    """Gets weights for the vertical ("Z") axis.

    When a region is selected, the domain bounds are first scaled to that
    region. The weights are then calculated from the (possibly scaled)
    bounds, proportional to the difference between each pair of level
    bounds.

    Parameters
    ----------
    domain_bounds : xr.DataArray
        The array of bounds for the vertical domain.
    region_bounds : Optional[np.ndarray]
        The array of bounds for vertical selection. If None, the domain
        bounds are used unchanged.

    Returns
    -------
    xr.DataArray
        The vertical axis weights.
    """
    # No region selected: weight the full vertical domain as-is.
    if region_bounds is None:
        return self._calculate_weights(domain_bounds)

    # Region selected: scale the domain to it before computing the weights.
    scaled_bounds = self._scale_domain_to_region(domain_bounds, region_bounds)
    return self._calculate_weights(scaled_bounds)
|
||
Comment on lines
+512
to
+537
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is sufficiently different from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree |
||
def _calculate_weights(self, domain_bounds: xr.DataArray): | ||
"""Calculate weights for the domain. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assigning the numpy array directly to the DataArray will work too