Skip to content

Commit

Permalink
Add subset_and_align_datasets() to regrid.py (#776)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Aug 21, 2024
1 parent 4b23a9b commit cb070e2
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 140 deletions.
117 changes: 27 additions & 90 deletions e3sm_diags/driver/lat_lon_driver.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List

import xarray as xr

from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots
from e3sm_diags.driver.utils.regrid import (
_apply_land_sea_mask,
_subset_on_region,
align_grids_to_lower_res,
get_z_axis,
has_z_axis,
regrid_z_axis_to_plevs,
subset_and_align_datasets,
)
from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
Expand Down Expand Up @@ -151,18 +149,29 @@ def _run_diags_2d(
parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None)

(
metrics_dict,
ds_test_region,
ds_test_region_regrid,
ds_ref_region,
ds_ref_region_regrid,
ds_diff_region,
) = _get_metrics_by_region(
) = subset_and_align_datasets(
parameter,
ds_test,
ds_ref,
ds_land_sea_mask,
var_key,
region,
)

metrics_dict = _create_metrics_dict(
var_key,
ds_test_region,
ds_test_region_regrid,
ds_ref_region,
ds_ref_region_regrid,
ds_diff_region,
)

_save_data_metrics_and_plots(
parameter,
plot_func,
Expand Down Expand Up @@ -223,11 +232,12 @@ def _run_diags_3d(

for region in regions:
(
metrics_dict,
ds_test_region,
ds_test_region_regrid,
ds_ref_region,
ds_ref_region_regrid,
ds_diff_region,
) = _get_metrics_by_region(
) = subset_and_align_datasets(
parameter,
ds_test_ilev,
ds_ref_ilev,
Expand All @@ -236,6 +246,15 @@ def _run_diags_3d(
region,
)

metrics_dict = _create_metrics_dict(
var_key,
ds_test_region,
ds_test_region_regrid,
ds_ref_region,
ds_ref_region_regrid,
ds_diff_region,
)

parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev)
_save_data_metrics_and_plots(
parameter,
Expand All @@ -248,88 +267,6 @@ def _run_diags_3d(
)


def _get_metrics_by_region(
parameter: CoreParameter,
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
ds_land_sea_mask: xr.Dataset,
var_key: str,
region: str,
) -> Tuple[MetricsDict, xr.Dataset, xr.Dataset | None, xr.Dataset | None]:
"""Get metrics by region and save data (optional), metrics, and plots
Parameters
----------
parameter : CoreParameter
The parameter for the diagnostic.
ds_test : xr.Dataset
The dataset containing the test variable.
ds_ref : xr.Dataset
The dataset containing the ref variable. If this is a model-only run
then it will be the same dataset as ``ds_test``.
ds_land_sea_mask : xr.Dataset
The land sea mask dataset, which is only used for masking if the region
is "land" or "ocean".
var_key : str
The key of the variable.
region : str
The region.
Returns
-------
Tuple[MetricsDict, xr.Dataset, xr.Dataset | None, xr.Dataset | None]
A tuple containing the metrics dictionary, the test dataset, the ref
dataset (optional), and the diffs dataset (optional).
"""
logger.info(f"Selected region: {region}")
parameter.var_region = region

# Apply a land sea mask or subset on a specific region.
if region == "land" or region == "ocean":
ds_test = _apply_land_sea_mask(
ds_test,
ds_land_sea_mask,
var_key,
region, # type: ignore
parameter.regrid_tool,
parameter.regrid_method,
)
ds_ref = _apply_land_sea_mask(
ds_ref,
ds_land_sea_mask,
var_key,
region, # type: ignore
parameter.regrid_tool,
parameter.regrid_method,
)
elif region != "global":
ds_test = _subset_on_region(ds_test, var_key, region)
ds_ref = _subset_on_region(ds_ref, var_key, region)

# Align the grid resolutions if the diagnostic is not model only.
if not parameter.model_only:
ds_test_regrid, ds_ref_regrid = align_grids_to_lower_res(
ds_test,
ds_ref,
var_key,
parameter.regrid_tool,
parameter.regrid_method,
)
ds_diff = ds_test_regrid.copy()
ds_diff[var_key] = ds_test_regrid[var_key] - ds_ref_regrid[var_key]
else:
ds_test_regrid = ds_test
ds_ref = None # type: ignore
ds_ref_regrid = None
ds_diff = None

metrics_dict = _create_metrics_dict(
var_key, ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff
)

return metrics_dict, ds_test, ds_ref, ds_diff


def _create_metrics_dict(
var_key: str,
ds_test: xr.Dataset,
Expand Down
28 changes: 18 additions & 10 deletions e3sm_diags/driver/utils/dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def _get_matching_climo_src_vars(
return {tuple(var_list): target_variable_map[var_tuple]}

raise IOError(
f"The dataset file has no matching souce variables for {target_var}"
f"The dataset file has no matching source variables for {target_var}"
)

# --------------------------------------------------------------------------
Expand Down Expand Up @@ -825,16 +825,16 @@ def _get_time_series_dataset_obj(self, var) -> xr.Dataset:
xr.Dataset
The dataset for the variable.
"""
filename = self._get_timeseries_filepath(self.root_path, var)
filepath = self._get_timeseries_filepath(self.root_path, var)

if filename == "":
if filepath == "":
raise IOError(
f"No time series `.nc` file was found for '{var}' in '{self.root_path}'"
)

time_slice = self._get_time_slice(filename)
time_slice = self._get_time_slice(filepath)

ds = xr.open_dataset(filename, decode_times=True, use_cftime=True)
ds = xr.open_dataset(filepath, decode_times=True, use_cftime=True)
ds_subset = ds.sel(time=time_slice).squeeze()

return ds_subset
Expand Down Expand Up @@ -1043,18 +1043,26 @@ def _squeeze_time_dim(self, ds: xr.Dataset) -> xr.Dataset:
"""Squeeze single coordinate climatology time dimensions.
For example, "ANN" averages over the year and collapses the time dim.
Time bounds are also dropped if they exist.
Parameters
----------
ds : xr.Dataset
_description_
The dataset with a time dimension
Returns
-------
xr.Dataset
_description_
The dataset with a time dimension.
"""
dim = xc.get_dim_keys(ds[self.var], axis="T")
ds = ds.squeeze(dim=dim)
ds = ds.drop_vars(dim)
time_dim = xc.get_dim_coords(ds, axis="T")

if len(time_dim) == 1:
ds = ds.squeeze(dim=time_dim.name)
ds = ds.drop_vars(time_dim.name)

bnds_key = time_dim.attrs.get("bounds")
if bnds_key is not None and bnds_key in ds.data_vars.keys():
ds = ds.drop_vars(bnds_key)

return ds
16 changes: 11 additions & 5 deletions e3sm_diags/driver/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def _save_data_metrics_and_plots(
ds_diff : xr.Dataset | None
The optional difference dataset. If the diagnostic is a model-only run,
then it will be None.
metrics_dict : Metrics
The dictionary containing metrics for the variable.
metrics_dict : Metrics | None
The optional dictionary containing metrics for the variable. Some sets
such as cosp_histogram only calculate spatial average and do not
use ``metrics_dict``.
"""
if parameter.save_netcdf:
_write_vars_to_netcdf(
Expand All @@ -68,13 +70,17 @@ def _save_data_metrics_and_plots(
"long_name", "No long_name attr in test data"
)

plot_func(
# Get the function arguments and pass to the set's plotting function.
args = [
parameter,
ds_test[var_key],
ds_ref[var_key] if ds_ref is not None else None,
ds_diff[var_key] if ds_diff is not None else None,
metrics_dict,
)
]
if metrics_dict is not None:
args = args + [metrics_dict]

plot_func(*args)


def _write_vars_to_netcdf(
Expand Down
81 changes: 80 additions & 1 deletion e3sm_diags/driver/utils/regrid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import annotations

from typing import List, Literal, Tuple
from typing import TYPE_CHECKING, List, Literal, Tuple

import xarray as xr
import xcdat as xc

from e3sm_diags.derivations.default_regions_xr import REGION_SPECS
from e3sm_diags.driver import MASK_REGION_TO_VAR_KEY
from e3sm_diags.logger import custom_logger

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter

logger = custom_logger(__name__)


# Valid hybrid-sigma levels keys that can be found in datasets.
HYBRID_SIGMA_KEYS = {
Expand All @@ -19,6 +26,78 @@
REGRID_TOOLS = Literal["esmf", "xesmf", "regrid2"]


def subset_and_align_datasets(
parameter: CoreParameter,
ds_test: xr.Dataset,
ds_ref: xr.Dataset,
ds_land_sea_mask: xr.Dataset,
var_key: str,
region: str,
) -> Tuple[xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset]:
"""Subset ref and test datasets on a region and regrid to align them.
Parameters
----------
parameter : CoreParameter
The parameter for the diagnostic.
ds_test : xr.Dataset
The dataset containing the test variable.
ds_ref : xr.Dataset
The dataset containing the ref variable.
ds_land_sea_mask : xr.Dataset
The land sea mask dataset, which is only used for masking if the region
is "land" or "ocean".
var_key : str
The key of the variable.
region : str
The region.
Returns
-------
Tuple[xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset]
A tuple containing the test dataset, the regridded test
dataset, the ref dataset, the regridded ref dataset, and the difference
between regridded datasets.
"""
logger.info(f"Selected region: {region}")
parameter.var_region = region

# Apply a land sea mask or subset on a specific region.
if region == "land" or region == "ocean":
ds_test = _apply_land_sea_mask(
ds_test,
ds_land_sea_mask,
var_key,
region, # type: ignore
parameter.regrid_tool,
parameter.regrid_method,
)
ds_ref = _apply_land_sea_mask(
ds_ref,
ds_land_sea_mask,
var_key,
region, # type: ignore
parameter.regrid_tool,
parameter.regrid_method,
)
elif region != "global":
ds_test = _subset_on_region(ds_test, var_key, region)
ds_ref = _subset_on_region(ds_ref, var_key, region)

ds_test_regrid, ds_ref_regrid = align_grids_to_lower_res(
ds_test,
ds_ref,
var_key,
parameter.regrid_tool,
parameter.regrid_method,
)

ds_diff = ds_test_regrid.copy()
ds_diff[var_key] = ds_test_regrid[var_key] - ds_ref_regrid[var_key]

return ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff


def has_z_axis(data_var: xr.DataArray) -> bool:
"""Checks whether the data variable has a Z axis.
Expand Down
4 changes: 2 additions & 2 deletions e3sm_diags/parameter/core_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import importlib
import sys
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from e3sm_diags.derivations.derivations import DerivedVariablesMap
from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(self):
self.output_format_subplot: List[str] = []
self.canvas_size_w: int = 1212
self.canvas_size_h: int = 1628
self.figsize: List[float] = [8.5, 11.0]
self.figsize: Tuple[float, float] = (8.5, 11.0)
self.dpi: int = 150
self.arrows: bool = True
self.logo: bool = False
Expand Down
Loading

0 comments on commit cb070e2

Please sign in to comment.