Skip to content

Commit

Permalink
Add initial refactored aerosol_aeronet_plot.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Feb 16, 2024
1 parent 97da49f commit c03e727
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 9 deletions.
15 changes: 13 additions & 2 deletions e3sm_diags/driver/aerosol_aeronet_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from e3sm_diags.driver import utils
from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.logger import custom_logger
from e3sm_diags.plot.cartopy import aerosol_aeronet_plot
from e3sm_diags.metrics.metrics import spatial_avg
from e3sm_diags.plot import aerosol_aeronet_plot

if TYPE_CHECKING:
from e3sm_diags.parameter.core_parameter import CoreParameter
Expand Down Expand Up @@ -55,6 +56,8 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:

for season in seasons:
ds_test = test_ds.get_climo_dataset(var_key, season)
da_test = ds_test[var_key]

test_site_arr = interpolate_model_output_to_obs_sites(
ds_test[var_key], var_key
)
Expand All @@ -81,7 +84,15 @@ def run_diag(parameter: CoreParameter) -> CoreParameter:
parameter.output_file = (
f"{parameter.ref_name}-{parameter.var_id}-{season}-global"
)
aerosol_aeronet_plot.plot(ds_test, test_site_arr, ref_site_arr, parameter)

metrics_dict = {
"max": da_test.max().item(),
"min": da_test.min().item(),
"mean": spatial_avg(ds_test, var_key, axis=["X", "Y"]),
}
aerosol_aeronet_plot.plot(
parameter, da_test, test_site_arr, ref_site_arr, metrics_dict
)

return parameter

Expand Down
117 changes: 117 additions & 0 deletions e3sm_diags/plot/aerosol_aeronet_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import matplotlib
import numpy as np
import xarray as xr

from e3sm_diags.driver.utils.type_annotations import MetricsDict
from e3sm_diags.logger import custom_logger
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.plot.lat_lon_plot import _add_colormap
from e3sm_diags.plot.utils import _save_plot

matplotlib.use("Agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402

logger = custom_logger(__name__)

MAIN_FONTSIZE = {"fontsize": 11.5}
SECONDARY_FONTSIZE = {"fontsize": 9.5}

# Plot scatter plot
# Position and sizes of subplot axes in page coordinates (0 to 1)
# (left, bottom, width, height) in page coordinates
PANEL_CFG = [
(0.09, 0.40, 0.72, 0.30),
(0.19, 0.2, 0.62, 0.30),
]
# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinatesz
BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)


def plot(
parameter: CoreParameter,
da_test: xr.DataArray,
test_site_arr: np.ndarray,
ref_site_arr: np.ndarray,
metrics_dict: MetricsDict,
):
"""Plot the test variable's metrics generated for the aerosol_aeronet set.
Parameters
----------
parameter : CoreParameter
The CoreParameter object containing plot configurations.
da_test : xr.DataArray
The test data.
test_site : np.ndarray
The array containing values for the test site.
ref_site : np.ndarray
The array containing values for the ref site.
metrics_dict : MetricsDict
The metrics.
"""
fig = plt.figure(figsize=parameter.figsize, dpi=parameter.dpi)
fig.suptitle(parameter.var_id, x=0.5, y=0.97)

# Add the first subplot for test data.
min = metrics_dict["min"]
mean = metrics_dict["mean"]
max = metrics_dict["max"]

_add_colormap(
0,
da_test,
fig,
parameter,
parameter.test_colormap,
parameter.contour_levels,
title=(parameter.test_name_yrs, None, None), # type: ignore
metrics=(max, mean, min), # type: ignore
panel_configs=PANEL_CFG,
)

ax = fig.add_axes(PANEL_CFG[1])
ax.set_title(f"{parameter.var_id} from AERONET sites")

# Define 1:1 line, and x, y axis limits.
if parameter.var_id == "AODVIS":
x1 = np.arange(0.01, 3.0, 0.1)
y1 = np.arange(0.01, 3.0, 0.1)
plt.xlim(0.03, 1)
plt.ylim(0.03, 1)
else:
x1 = np.arange(0.0001, 1.0, 0.01)
y1 = np.arange(0.0001, 1.0, 0.01)
plt.xlim(0.001, 0.3)
plt.ylim(0.001, 0.3)

plt.loglog(x1, y1, "-k", linewidth=0.5)
plt.loglog(x1, y1 * 0.5, "--k", linewidth=0.5)
plt.loglog(x1 * 0.5, y1, "--k", linewidth=0.5)

corr = np.corrcoef(ref_site_arr, test_site_arr)
xmean = np.mean(ref_site_arr)
ymean = np.mean(test_site_arr)
ax.text(
0.3,
0.9,
f"Mean (test): {ymean:.3f} \n Mean (ref): {xmean:.3f}\n Corr: {corr[0, 1]:.2f}",
horizontalalignment="right",
verticalalignment="top",
transform=ax.transAxes,
)

# Configure axis ticks.
plt.tick_params(axis="both", which="major")
plt.tick_params(axis="both", which="minor")

# Configure axis labels
plt.xlabel(f"ref: {parameter.ref_name_yrs}")
plt.ylabel(f"test: {parameter.test_name_yrs}")

plt.loglog(ref_site_arr, test_site_arr, "kx", markersize=3.0, mfc="none")

# Configure legend.
plt.legend(frameon=False, prop={"size": 5})

_save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING)
12 changes: 7 additions & 5 deletions e3sm_diags/plot/lat_lon_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.plot.utils import (
DEFAULT_PANEL_CFG,
PanelConfig,
_add_colorbar,
_add_contour_plot,
_add_grid_res_info,
Expand Down Expand Up @@ -131,6 +132,7 @@ def _add_colormap(
contour_levels: List[float],
title: Tuple[str | None, str, str],
metrics: Tuple[float, ...],
panel_configs: PanelConfig = DEFAULT_PANEL_CFG,
):
"""Adds a colormap containing the variable data and metrics to the figure.
Expand Down Expand Up @@ -198,7 +200,7 @@ def _add_colormap(

# Get the figure Axes object using the projection above.
# --------------------------------------------------------------------------
ax = fig.add_axes(DEFAULT_PANEL_CFG[subplot_num], projection=projection)
ax = fig.add_axes(panel_configs[subplot_num], projection=projection)
ax.set_extent([lon_west, lon_east, lat_south, lat_north], crs=projection)
contour_plot = _add_contour_plot(
ax, parameter, var, lon, lat, color_map, ccrs.PlateCarree(), norm, c_levels
Expand Down Expand Up @@ -226,13 +228,13 @@ def _add_colormap(
_configure_x_and_y_axes(
ax, x_ticks, y_ticks, ccrs.PlateCarree(), parameter.current_set
)
_add_colorbar(fig, subplot_num, DEFAULT_PANEL_CFG, contour_plot, c_levels)
_add_colorbar(fig, subplot_num, panel_configs, contour_plot, c_levels)

# Add metrics text to the figure.
# --------------------------------------------------------------------------
_add_min_mean_max_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
_add_min_mean_max_text(fig, subplot_num, panel_configs, metrics)

if len(metrics) == 5:
_add_rmse_corr_text(fig, subplot_num, DEFAULT_PANEL_CFG, metrics)
_add_rmse_corr_text(fig, subplot_num, panel_configs, metrics)

_add_grid_res_info(fig, subplot_num, region_key, lat, lon, DEFAULT_PANEL_CFG)
_add_grid_res_info(fig, subplot_num, region_key, lat, lon, panel_configs)
5 changes: 3 additions & 2 deletions e3sm_diags/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates
DEFAULT_BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)
BorderPadding = Tuple[float, float, float, float]
DEFAULT_BORDER_PADDING: BorderPadding = (-0.06, -0.03, 0.13, 0.03)

# Sets that use the lat_lon formatter to configure the X and Y axes of the plot.
SETS_USING_LAT_LON_FORMATTER = [
Expand All @@ -56,7 +57,7 @@ def _save_plot(
fig: plt.Figure,
parameter: CoreParameter,
panel_configs: PanelConfig = DEFAULT_PANEL_CFG,
border_padding: Tuple[float, float, float, float] = DEFAULT_BORDER_PADDING,
border_padding: BorderPadding = DEFAULT_BORDER_PADDING,
):
"""Save the plot using the figure object and parameter configs.
Expand Down

0 comments on commit c03e727

Please sign in to comment.