Skip to content

Commit

Permalink
[PR]: CDAT Migration: Refactor aerosol_aeronet set (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Dec 4, 2024
1 parent 10dd405 commit 8bf49f3
Show file tree
Hide file tree
Showing 14 changed files with 226 additions and 709 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from auxiliary_tools.cdat_regression_testing.base_run_script import run_set

SET_NAME = "aerosol_aeronet"
SET_DIR = "672-aerosol-aeronet"
CFG_PATH: str | None = None
MULTIPROCESSING = True

run_set(SET_NAME, SET_DIR, CFG_PATH, MULTIPROCESSING)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
160 changes: 100 additions & 60 deletions e3sm_diags/driver/aerosol_aeronet_driver.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import xarray as xr
import xcdat as xc
from scipy import interpolate

import e3sm_diags
from e3sm_diags.driver import utils
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.logger import custom_logger
from e3sm_diags.plot.cartopy import aerosol_aeronet_plot
from e3sm_diags.metrics.metrics import spatial_avg
from e3sm_diags.plot import aerosol_aeronet_plot

if TYPE_CHECKING:
from cdms2.tvariable import TransientVariable

from e3sm_diags.parameter.core_parameter import CoreParameter


Expand All @@ -25,74 +27,110 @@


def run_diag(parameter: CoreParameter) -> CoreParameter:
"""Runs the aerosol aeronet diagnostic.
:param parameter: Parameters for the run
:type parameter: CoreParameter
:raises ValueError: Invalid run type
:return: Parameters for the run
:rtype: CoreParameter
"""Run the aerosol aeronet diagnostics.
Parameters
----------
parameter : CoreParameter
The parameter for the diagnostic.
Returns
-------
CoreParameter
The parameter for the diagnostic with the result (completed or failed).
Raises
------
ValueError
If the run type is not valid.
"""
variables = parameter.variables
run_type = parameter.run_type
seasons = parameter.seasons

for season in seasons:
test_data = utils.dataset.Dataset(parameter, test=True)
parameter.test_name_yrs = utils.general.get_name_and_yrs(
parameter, test_data, season
)
parameter.ref_name_yrs = "AERONET (2006-2015)"
test_ds = Dataset(parameter, data_type="test")

for var in variables:
logger.info("Variable: {}".format(var))
parameter.var_id = var
for var_key in variables:
logger.info("Variable: {}".format(var_key))
parameter.var_id = var_key

test = test_data.get_climo_variable(var, season)
test_site = interpolate_model_output_to_obs_sites(test, var)
for season in seasons:
ds_test = test_ds.get_climo_dataset(var_key, season)
da_test = ds_test[var_key]

if run_type == "model_vs_model":
ref_data = utils.dataset.Dataset(parameter, ref=True)
parameter.ref_name_yrs = utils.general.get_name_and_yrs(
parameter, ref_data, season
test_site_arr = interpolate_model_output_to_obs_sites(
ds_test[var_key], var_key
)
ref = ref_data.get_climo_variable(var, season)
ref_site = interpolate_model_output_to_obs_sites(ref, var)

elif run_type == "model_vs_obs":
ref_site = interpolate_model_output_to_obs_sites(None, var)
else:
raise ValueError("Invalid run_type={}".format(run_type))
parameter.test_name_yrs = test_ds.get_name_yrs_attr(season)
parameter.ref_name_yrs = "AERONET (2006-2015)"

parameter.output_file = (
f"{parameter.ref_name}-{parameter.var_id}-{season}-global"
)
aerosol_aeronet_plot.plot(test, test_site, ref_site, parameter)
if run_type == "model_vs_model":
ref_ds = Dataset(parameter, data_type="ref")

parameter.ref_name_yrs = utils.general.get_name_and_yrs(
parameter, ref_ds, season
)

ds_ref = ref_ds.get_climo_dataset(var_key, season)
ref_site_arr = interpolate_model_output_to_obs_sites(
ds_ref[var_key], var_key
)
elif run_type == "model_vs_obs":
ref_site_arr = interpolate_model_output_to_obs_sites(None, var_key)
else:
raise ValueError("Invalid run_type={}".format(run_type))

parameter.output_file = (
f"{parameter.ref_name}-{parameter.var_id}-{season}-global"
)

metrics_dict = {
"max": da_test.max().item(),
"min": da_test.min().item(),
"mean": spatial_avg(ds_test, var_key, axis=["X", "Y"]),
}
aerosol_aeronet_plot.plot(
parameter, da_test, test_site_arr, ref_site_arr, metrics_dict
)

return parameter


def interpolate_model_output_to_obs_sites(
var: Optional[TransientVariable], var_id: str
):
da_var: xr.DataArray | None, var_key: str
) -> np.ndarray:
"""Interpolate model outputs (on regular lat lon grids) to observational sites
:param var: Input model variable, var_id: name of the variable
:type var: TransientVariable or NoneType, var_id: str
:raises IOError: Invalid variable input
:return: interpolated values over all observational sites
:rtype: 1-D numpy.array
# TODO: Add test coverage for this function.
Parameters
----------
da_var : xr.DataArray | None
An optional input model variable dataarray.
var_key : str
The key of the variable.
Returns
-------
np.ndarray
The interpolated values over all observational sites.
Raises
------
IOError
If the variable key is invalid.
"""
logger.info(
"Interpolate model outputs (on regular lat lon grids) to observational sites"
)
if var_id == "AODABS":

if var_key == "AODABS":
aeronet_file = os.path.join(
e3sm_diags.INSTALL_PATH, "aerosol_aeronet/aaod550_AERONET_2006-2015.txt"
)
var_header = "aaod"
elif var_id == "AODVIS":
elif var_key == "AODVIS":
aeronet_file = os.path.join(
e3sm_diags.INSTALL_PATH, "aerosol_aeronet/aod550_AERONET_2006-2015.txt"
)
Expand All @@ -102,22 +140,24 @@ def interpolate_model_output_to_obs_sites(

data_obs = pd.read_csv(aeronet_file, dtype=object, sep=",")

lonloc = np.array(data_obs["lon"].astype(float))
latloc = np.array(data_obs["lat"].astype(float))
obsloc = np.array(data_obs[var_header].astype(float))
# sitename = np.array(data_obs["site"].astype(str))
nsite = len(obsloc)
lon_loc = np.array(data_obs["lon"].astype(float))
lat_loc = np.array(data_obs["lat"].astype(float))
obs_loc = np.array(data_obs[var_header].astype(float))

# express lonloc from 0 to 360
lonloc[lonloc < 0.0] = lonloc[lonloc < 0.0] + 360.0
num_sites = len(obs_loc)

if var is not None:
f_intp = interpolate.RectBivariateSpline(
var.getLatitude()[:], var.getLongitude()[:], var
)
var_intp = np.zeros(nsite)
for i in range(nsite):
var_intp[i] = f_intp(latloc[i], lonloc[i])
# Express lon_loc from 0 to 360.
lon_loc[lon_loc < 0.0] = lon_loc[lon_loc < 0.0] + 360.0

if da_var is not None:
lat = xc.get_dim_coords(da_var, axis="Y")
lon = xc.get_dim_coords(da_var, axis="X")
f_intp = interpolate.RectBivariateSpline(lat.values, lon.values, da_var.values)

var_intp = np.zeros(num_sites)
for i in range(num_sites):
var_intp[i] = f_intp(lat_loc[i], lon_loc[i])

return var_intp
return obsloc

return obs_loc
2 changes: 1 addition & 1 deletion e3sm_diags/driver/zonal_mean_2d_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
DEFAULT_PLEVS,
ZonalMean2dParameter,
)
from e3sm_diags.plot.cartopy.zonal_mean_2d_plot import plot as plot_func
from e3sm_diags.plot.zonal_mean_2d_plot import plot as plot_func

logger = custom_logger(__name__)

Expand Down
114 changes: 114 additions & 0 deletions e3sm_diags/plot/aerosol_aeronet_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import matplotlib
import numpy as np
import xarray as xr

from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.plot.lat_lon_plot import _add_colormap
from e3sm_diags.plot.utils import _save_plot

matplotlib.use("Agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402

logger = custom_logger(__name__)

# Plot scatter plot
# Position and sizes of subplot axes in page coordinates (0 to 1)
# (left, bottom, width, height) in page coordinates
PANEL_CFG = [
(0.09, 0.40, 0.72, 0.30),
(0.19, 0.2, 0.62, 0.30),
]
# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates.
BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)


def plot(
parameter: CoreParameter,
da_test: xr.DataArray,
test_site_arr: np.ndarray,
ref_site_arr: np.ndarray,
metrics_dict: MetricsDict,
):
"""Plot the test variable's metrics generated for the aerosol_aeronet set.
Parameters
----------
parameter : CoreParameter
The CoreParameter object containing plot configurations.
da_test : xr.DataArray
The test data.
test_site : np.ndarray
The array containing values for the test site.
ref_site : np.ndarray
The array containing values for the ref site.
metrics_dict : MetricsDict
The metrics.
"""
fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
fig.suptitle(parameter.var_id, x=0.5, y=0.97)

# Add the colormap subplot for test data.
min = metrics_dict["min"]
mean = metrics_dict["mean"]
max = metrics_dict["max"]

_add_colormap(
0,
da_test,
fig,
parameter,
parameter.test_colormap,
parameter.contour_levels,
title=(parameter.test_name_yrs, None, None), # type: ignore
metrics=(max, mean, min), # type: ignore
)

# Add the scatter plot.
ax = fig.add_axes(PANEL_CFG[1])
ax.set_title(f"{parameter.var_id} from AERONET sites")

# Define 1:1 line, and x, y axis limits.
if parameter.var_id == "AODVIS":
x1 = np.arange(0.01, 3.0, 0.1)
y1 = np.arange(0.01, 3.0, 0.1)
plt.xlim(0.03, 1)
plt.ylim(0.03, 1)
else:
x1 = np.arange(0.0001, 1.0, 0.01)
y1 = np.arange(0.0001, 1.0, 0.01)
plt.xlim(0.001, 0.3)
plt.ylim(0.001, 0.3)

plt.loglog(x1, y1, "-k", linewidth=0.5)
plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5)
plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5)

corr = np.corrcoef(ref_site_arr, test_site_arr)
xmean = np.mean(ref_site_arr)
ymean = np.mean(test_site_arr)
ax.text(
0.3,
0.9,
f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}",
horizontalalignment="right",
verticalalignment="top",
transform=ax.transAxes,
)

# Configure axis ticks.
plt.tick_params(axis="both", which="major")
plt.tick_params(axis="both", which="minor")

# Configure axis labels
plt.xlabel(f"ref: {parameter.ref_name_yrs}")
plt.ylabel(f"test: {parameter.test_name_yrs}")

plt.loglog(ref_site_arr, test_site_arr, "kx", markersize=3.0, mfc="none")

# Configure legend.
plt.legend(frameon=False, prop={"size": 5})

_save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING)
Loading

0 comments on commit 8bf49f3

Please sign in to comment.