diff --git a/e3sm_diags/driver/lat_lon_driver.py b/e3sm_diags/driver/lat_lon_driver.py index bea50fb94..3961be59e 100755 --- a/e3sm_diags/driver/lat_lon_driver.py +++ b/e3sm_diags/driver/lat_lon_driver.py @@ -1,18 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List import xarray as xr from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots from e3sm_diags.driver.utils.regrid import ( - _apply_land_sea_mask, - _subset_on_region, - align_grids_to_lower_res, get_z_axis, has_z_axis, regrid_z_axis_to_plevs, + subset_and_align_datasets, ) from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import custom_logger @@ -151,11 +149,12 @@ def _run_diags_2d( parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) ( - metrics_dict, ds_test_region, + ds_test_region_regrid, ds_ref_region, + ds_ref_region_regrid, ds_diff_region, - ) = _get_metrics_by_region( + ) = subset_and_align_datasets( parameter, ds_test, ds_ref, @@ -163,6 +162,16 @@ def _run_diags_2d( var_key, region, ) + + metrics_dict = _create_metrics_dict( + var_key, + ds_test_region, + ds_test_region_regrid, + ds_ref_region, + ds_ref_region_regrid, + ds_diff_region, + ) + _save_data_metrics_and_plots( parameter, plot_func, @@ -223,11 +232,12 @@ def _run_diags_3d( for region in regions: ( - metrics_dict, ds_test_region, + ds_test_region_regrid, ds_ref_region, + ds_ref_region_regrid, ds_diff_region, - ) = _get_metrics_by_region( + ) = subset_and_align_datasets( parameter, ds_test_ilev, ds_ref_ilev, @@ -236,6 +246,15 @@ def _run_diags_3d( region, ) + metrics_dict = _create_metrics_dict( + var_key, + ds_test_region, + ds_test_region_regrid, + ds_ref_region, + ds_ref_region_regrid, + ds_diff_region, + ) + parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) _save_data_metrics_and_plots( parameter, @@ -248,88 +267,6 @@ def _run_diags_3d( ) -def _get_metrics_by_region( - parameter: CoreParameter, - ds_test: xr.Dataset, - ds_ref: xr.Dataset, - ds_land_sea_mask: xr.Dataset, - var_key: str, - region: str, -) -> Tuple[MetricsDict, xr.Dataset, xr.Dataset | None, xr.Dataset | None]: - """Get metrics by region and save data (optional), metrics, and plots - - Parameters - ---------- - parameter : CoreParameter - The parameter for the diagnostic. - ds_test : xr.Dataset - The dataset containing the test variable. - ds_ref : xr.Dataset - The dataset containing the ref variable. If this is a model-only run - then it will be the same dataset as ``ds_test``. - ds_land_sea_mask : xr.Dataset - The land sea mask dataset, which is only used for masking if the region - is "land" or "ocean". - var_key : str - The key of the variable. - region : str - The region. - - Returns - ------- - Tuple[MetricsDict, xr.Dataset, xr.Dataset | None, xr.Dataset | None] - A tuple containing the metrics dictionary, the test dataset, the ref - dataset (optional), and the diffs dataset (optional). - """ - logger.info(f"Selected region: {region}") - parameter.var_region = region - - # Apply a land sea mask or subset on a specific region. - if region == "land" or region == "ocean": - ds_test = _apply_land_sea_mask( - ds_test, - ds_land_sea_mask, - var_key, - region, # type: ignore - parameter.regrid_tool, - parameter.regrid_method, - ) - ds_ref = _apply_land_sea_mask( - ds_ref, - ds_land_sea_mask, - var_key, - region, # type: ignore - parameter.regrid_tool, - parameter.regrid_method, - ) - elif region != "global": - ds_test = _subset_on_region(ds_test, var_key, region) - ds_ref = _subset_on_region(ds_ref, var_key, region) - - # Align the grid resolutions if the diagnostic is not model only. - if not parameter.model_only: - ds_test_regrid, ds_ref_regrid = align_grids_to_lower_res( - ds_test, - ds_ref, - var_key, - parameter.regrid_tool, - parameter.regrid_method, - ) - ds_diff = ds_test_regrid.copy() - ds_diff[var_key] = ds_test_regrid[var_key] - ds_ref_regrid[var_key] - else: - ds_test_regrid = ds_test - ds_ref = None # type: ignore - ds_ref_regrid = None - ds_diff = None - - metrics_dict = _create_metrics_dict( - var_key, ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff - ) - - return metrics_dict, ds_test, ds_ref, ds_diff - - def _create_metrics_dict( var_key: str, ds_test: xr.Dataset, diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 251e56c49..e77ad938a 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -643,7 +643,7 @@ def _get_matching_climo_src_vars( return {tuple(var_list): target_variable_map[var_tuple]} raise IOError( - f"The dataset file has no matching souce variables for {target_var}" + f"The dataset file has no matching source variables for {target_var}" ) # -------------------------------------------------------------------------- @@ -825,16 +825,16 @@ def _get_time_series_dataset_obj(self, var) -> xr.Dataset: xr.Dataset The dataset for the variable. """ - filename = self._get_timeseries_filepath(self.root_path, var) + filepath = self._get_timeseries_filepath(self.root_path, var) - if filename == "": + if filepath == "": raise IOError( f"No time series `.nc` file was found for '{var}' in '{self.root_path}'" ) - time_slice = self._get_time_slice(filename) + time_slice = self._get_time_slice(filepath) - ds = xr.open_dataset(filename, decode_times=True, use_cftime=True) + ds = xr.open_dataset(filepath, decode_times=True, use_cftime=True) ds_subset = ds.sel(time=time_slice).squeeze() return ds_subset @@ -1043,18 +1043,26 @@ def _squeeze_time_dim(self, ds: xr.Dataset) -> xr.Dataset: """Squeeze single coordinate climatology time dimensions. For example, "ANN" averages over the year and collapses the time dim. + Time bounds are also dropped if they exist. + Parameters ---------- ds : xr.Dataset - _description_ + The dataset with a time dimension Returns ------- xr.Dataset - _description_ + The dataset with a time dimension. """ - dim = xc.get_dim_keys(ds[self.var], axis="T") - ds = ds.squeeze(dim=dim) - ds = ds.drop_vars(dim) + time_dim = xc.get_dim_coords(ds, axis="T") + + if len(time_dim) == 1: + ds = ds.squeeze(dim=time_dim.name) + ds = ds.drop_vars(time_dim.name) + + bnds_key = time_dim.attrs.get("bounds") + if bnds_key is not None and bnds_key in ds.data_vars.keys(): + ds = ds.drop_vars(bnds_key) return ds diff --git a/e3sm_diags/driver/utils/io.py b/e3sm_diags/driver/utils/io.py index 09e4794da..9b559a065 100644 --- a/e3sm_diags/driver/utils/io.py +++ b/e3sm_diags/driver/utils/io.py @@ -41,8 +41,10 @@ def _save_data_metrics_and_plots( ds_diff : xr.Dataset | None The optional difference dataset. If the diagnostic is a model-only run, then it will be None. - metrics_dict : Metrics - The dictionary containing metrics for the variable. + metrics_dict : Metrics | None + The optional dictionary containing metrics for the variable. Some sets + such as cosp_histogram only calculate spatial average and do not + use ``metrics_dict``. """ if parameter.save_netcdf: _write_vars_to_netcdf( @@ -68,13 +70,17 @@ def _save_data_metrics_and_plots( "long_name", "No long_name attr in test data" ) - plot_func( + # Get the function arguments and pass to the set's plotting function. + args = [ parameter, ds_test[var_key], ds_ref[var_key] if ds_ref is not None else None, ds_diff[var_key] if ds_diff is not None else None, - metrics_dict, - ) + ] + if metrics_dict is not None: + args = args + [metrics_dict] + + plot_func(*args) def _write_vars_to_netcdf( diff --git a/e3sm_diags/driver/utils/regrid.py b/e3sm_diags/driver/utils/regrid.py index e61476e95..0d8f39372 100644 --- a/e3sm_diags/driver/utils/regrid.py +++ b/e3sm_diags/driver/utils/regrid.py @@ -1,12 +1,19 @@ from __future__ import annotations -from typing import List, Literal, Tuple +from typing import TYPE_CHECKING, List, Literal, Tuple import xarray as xr import xcdat as xc from e3sm_diags.derivations.default_regions_xr import REGION_SPECS from e3sm_diags.driver import MASK_REGION_TO_VAR_KEY +from e3sm_diags.logger import custom_logger + +if TYPE_CHECKING: + from e3sm_diags.parameter.core_parameter import CoreParameter + +logger = custom_logger(__name__) + # Valid hybrid-sigma levels keys that can be found in datasets. HYBRID_SIGMA_KEYS = { @@ -19,6 +26,78 @@ REGRID_TOOLS = Literal["esmf", "xesmf", "regrid2"] +def subset_and_align_datasets( + parameter: CoreParameter, + ds_test: xr.Dataset, + ds_ref: xr.Dataset, + ds_land_sea_mask: xr.Dataset, + var_key: str, + region: str, +) -> Tuple[xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset]: + """Subset ref and test datasets on a region and regrid to align them. + + Parameters + ---------- + parameter : CoreParameter + The parameter for the diagnostic. + ds_test : xr.Dataset + The dataset containing the test variable. + ds_ref : xr.Dataset + The dataset containing the ref variable. + ds_land_sea_mask : xr.Dataset + The land sea mask dataset, which is only used for masking if the region + is "land" or "ocean". + var_key : str + The key of the variable. + region : str + The region. + + Returns + ------- + Tuple[xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset, xr.Dataset] + A tuple containing the test dataset, the regridded test + dataset, the ref dataset, the regridded ref dataset, and the difference + between regridded datasets. + """ + logger.info(f"Selected region: {region}") + parameter.var_region = region + + # Apply a land sea mask or subset on a specific region. + if region == "land" or region == "ocean": + ds_test = _apply_land_sea_mask( + ds_test, + ds_land_sea_mask, + var_key, + region, # type: ignore + parameter.regrid_tool, + parameter.regrid_method, + ) + ds_ref = _apply_land_sea_mask( + ds_ref, + ds_land_sea_mask, + var_key, + region, # type: ignore + parameter.regrid_tool, + parameter.regrid_method, + ) + elif region != "global": + ds_test = _subset_on_region(ds_test, var_key, region) + ds_ref = _subset_on_region(ds_ref, var_key, region) + + ds_test_regrid, ds_ref_regrid = align_grids_to_lower_res( + ds_test, + ds_ref, + var_key, + parameter.regrid_tool, + parameter.regrid_method, + ) + + ds_diff = ds_test_regrid.copy() + ds_diff[var_key] = ds_test_regrid[var_key] - ds_ref_regrid[var_key] + + return ds_test, ds_test_regrid, ds_ref, ds_ref_regrid, ds_diff + + def has_z_axis(data_var: xr.DataArray) -> bool: """Checks whether the data variable has a Z axis. diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index 961a050c6..7b0da9e2c 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -3,7 +3,7 @@ import copy import importlib import sys -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Tuple from e3sm_diags.derivations.derivations import DerivedVariablesMap from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ @@ -130,7 +130,7 @@ def __init__(self): self.output_format_subplot: List[str] = [] self.canvas_size_w: int = 1212 self.canvas_size_h: int = 1628 - self.figsize: List[float] = [8.5, 11.0] + self.figsize: Tuple[float, float] = (8.5, 11.0) self.dpi: int = 150 self.arrows: bool = True self.logo: bool = False diff --git a/e3sm_diags/plot/utils.py b/e3sm_diags/plot/utils.py index c4057cbbe..e2d9de5cb 100644 --- a/e3sm_diags/plot/utils.py +++ b/e3sm_diags/plot/utils.py @@ -40,7 +40,7 @@ BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03) -def _save_plot(fig: plt.figure, parameter: CoreParameter): +def _save_plot(fig: plt.Figure, parameter: CoreParameter): """Save the plot using the figure object and parameter configs. This function creates the output filename to save the plot. It also @@ -48,7 +48,7 @@ def _save_plot(fig: plt.figure, parameter: CoreParameter): Parameters ---------- - fig : plt.figure + fig : plt.Figure The plot figure. parameter : CoreParameter The CoreParameter with file configurations. @@ -98,7 +98,7 @@ def _save_plot(fig: plt.figure, parameter: CoreParameter): def _add_colormap( subplot_num: int, var: xr.DataArray, - fig: plt.figure, + fig: plt.Figure, parameter: CoreParameter, color_map: str, contour_levels: List[float], @@ -117,7 +117,7 @@ def _add_colormap( The subplot number. var : xr.DataArray The variable to plot. - fig : plt.figure + fig : plt.Figure The figure object to add the subplot to. parameter : CoreParameter The CoreParameter object containing plot configurations. diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 1fdf6de3e..75eed59e7 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -1,3 +1,4 @@ +import copy import logging from collections import OrderedDict from typing import Literal @@ -23,7 +24,9 @@ def _create_parameter_object( start_yr: str, end_yr: str, ): - parameter = CoreParameter() + # NOTE: Make sure to create deep copies to avoid references in memory to + # the same object. + parameter = copy.deepcopy(CoreParameter()) if dataset_type == "ref": if data_type == "time_series": @@ -83,7 +86,7 @@ def test_raises_error_if_type_attr_is_invalid(self): def test_sets_start_yr_and_end_yr_for_area_mean_time_series_set(self): parameter = AreaMeanTimeSeriesParameter() - parameter.sets[0] = "area_mean_time_series" + parameter.sets = ["area_mean_time_series"] parameter.start_yr = "2000" parameter.end_yr = "2001" @@ -96,7 +99,7 @@ def test_sets_sub_monthly_if_diurnal_cycle_or_arms_diags_set(self): parameter = _create_parameter_object( "ref", "time_series", self.data_path, "2000", "2001" ) - parameter.sets[0] = "diurnal_cycle" + parameter.sets = ["diurnal_cycle"] ds = Dataset(parameter, data_type="ref") @@ -363,7 +366,7 @@ def test_returns_reference_climo_dataset_from_file(self): result = ds.get_ref_climo_dataset("ts", "ANN", self.ds_climo.copy()) expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) assert not ds.model_only def test_returns_test_dataset_as_default_value_if_climo_dataset_not_found(self): @@ -554,7 +557,7 @@ def test_returns_climo_dataset_using_ref_file_variable(self): result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_test_file_variable(self): parameter = _create_parameter_object( @@ -568,9 +571,11 @@ def test_returns_climo_dataset_using_test_file_variable(self): result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) - def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season(self): + def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season( + self, + ): # Example: {test_data_path}/{test_name}_{season}.nc parameter = _create_parameter_object( "ref", "climo", self.data_path, "2000", "2001" @@ -582,9 +587,11 @@ def test_returns_climo_dataset_using_ref_file_variable_test_name_and_season(self result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) - def test_returns_climo_dataset_using_test_file_variable_test_name_and_season(self): + def test_returns_climo_dataset_using_test_file_variable_test_name_and_season( + self, + ): # Example: {test_data_path}/{test_name}_{season}.nc parameter = _create_parameter_object( "test", "climo", self.data_path, "2000", "2001" @@ -596,7 +603,7 @@ def test_returns_climo_dataset_using_test_file_variable_test_name_and_season(sel result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nested_pattern_1( self, @@ -616,7 +623,7 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nested_pattern_2( self, @@ -638,7 +645,7 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest result = ds.get_climo_dataset("ts", "ANN") expected = self.ds_climo.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_with_derived_variable(self): # We will derive the "PRECT" variable using the "pr" variable. @@ -693,7 +700,7 @@ def test_returns_climo_dataset_with_derived_variable(self): expected["PRECT"] = expected["pr"] * 3600 * 24 expected["PRECT"].attrs["units"] = "mm/day" - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_derived_var_directly_from_dataset(self): ds_precst = xr.Dataset( @@ -744,7 +751,7 @@ def test_returns_climo_dataset_using_derived_var_directly_from_dataset(self): result = ds.get_climo_dataset("PRECST", season="ANN") expected = ds_precst.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_source_variable_with_wildcard(self): ds_precst = xr.Dataset( @@ -805,7 +812,7 @@ def test_returns_climo_dataset_using_source_variable_with_wildcard(self): expected = ds_precst.squeeze(dim="time").drop_vars("time") expected["bc_DDF"] = expected["bc_a?DDF"] + expected["bc_c?DDF"] - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_climo_dataset_using_climo_of_time_series_files(self): parameter = _create_parameter_object( @@ -826,7 +833,7 @@ def test_returns_climo_dataset_using_climo_of_time_series_files(self): name="ts", data=np.array([[1.0, 1.0], [1.0, 1.0]]), dims=["lat", "lon"] ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_if_no_filepath_found_for_variable(self): parameter = _create_parameter_object( @@ -1059,7 +1066,7 @@ def test_returns_time_series_dataset_using_file(self): # is dropped when subsetting with the middle of the month (2000-01-15). expected = self.ds_ts.isel(time=slice(1, 4)) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_time_series_dataset_using_sub_monthly_sets(self): parameter = _create_parameter_object( @@ -1078,7 +1085,7 @@ def test_returns_time_series_dataset_using_sub_monthly_sets(self): result = ds.get_time_series_dataset("ts") expected = self.ds_ts.copy() - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_time_series_dataset_using_derived_var(self): # We will derive the "PRECT" variable using the "pr" variable. @@ -1143,10 +1150,9 @@ def test_returns_time_series_dataset_using_derived_var(self): expected["PRECT"] = expected["pr"] * 3600 * 24 expected["PRECT"].attrs["units"] = "mm/day" - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_time_series_dataset_using_derived_var_directly_from_dataset(self): - # We will derive the "PRECT" variable using the "pr" variable. ds_precst = xr.Dataset( coords={ "lat": [-90, 90], @@ -1206,7 +1212,7 @@ def test_returns_time_series_dataset_using_derived_var_directly_from_dataset(sel result = ds.get_time_series_dataset("PRECST") expected = ds_precst.copy() - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_if_no_datasets_found_to_derive_variable(self): # In this test, we don't create a dataset and write it out to `.nc`. @@ -1241,9 +1247,10 @@ def test_returns_time_series_dataset_with_centered_time_if_single_point(self): dtype=object, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_time_series_dataset_using_file_with_ref_name_prepended(self): + ds_ts = self.ds_ts.copy() parameter = _create_parameter_object( "ref", "time_series", self.data_path, "2000", "2001" ) @@ -1251,16 +1258,16 @@ def test_returns_time_series_dataset_using_file_with_ref_name_prepended(self): ref_data_path = self.data_path / parameter.ref_name ref_data_path.mkdir() - self.ds_ts.to_netcdf(f"{ref_data_path}/ts_200001_200112.nc") + ds_ts.to_netcdf(f"{ref_data_path}/ts_200001_200112.nc") ds = Dataset(parameter, data_type="ref") result = ds.get_time_series_dataset("ts") # Since the data is not sub-monthly, the first time coord (2001-01-01) # is dropped when subsetting with the middle of the month (2000-01-15). - expected = self.ds_ts.isel(time=slice(1, 4)) + expected = ds_ts.isel(time=slice(1, 4)) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_if_time_series_dataset_could_not_be_found(self): self.ds_ts.to_netcdf(self.ts_path) @@ -1385,7 +1392,7 @@ def test_returns_land_sea_mask_if_matching_vars_in_dataset(self): expected = ds_climo.copy() expected = expected.squeeze(dim="time").drop_vars("time") - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_returns_default_land_sea_mask_if_one_or_no_matching_vars_in_dataset( self, caplog @@ -1405,9 +1412,9 @@ def test_returns_default_land_sea_mask_if_one_or_no_matching_vars_in_dataset( result = ds._get_land_sea_mask("ANN") expected = xr.open_dataset(LAND_OCEAN_MASK_PATH) - expected = expected.squeeze(dim="time").drop_vars("time") + expected = expected.squeeze(dim="time").drop_vars(["time", "time_bnds"]) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestGetNameAndYearsAttr: