From 7eb0092ccca12dbc1ac36a774fab7b3e2fa1b919 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Tue, 30 Jul 2024 10:23:37 -0700 Subject: [PATCH] CDAT Migration Phase 2: Refactor `qbo` set (#826) --- .../664-qbo/debug/qa.py | 108 ++ .../664-qbo/regression_test.ipynb | 1696 +++++++++++++++++ .../664-qbo/regression_test_png.ipynb | 225 +++ .../cdat_regression_testing/664-qbo/run.cfg | 6 + .../664-qbo/run_script.py | 12 + e3sm_diags/driver/qbo_driver.py | 474 ++--- e3sm_diags/driver/utils/dataset_xr.py | 60 +- e3sm_diags/parameter/core_parameter.py | 8 +- e3sm_diags/plot/qbo_plot.py | 237 +++ .../driver/utils/test_dataset_xr.py | 3 + 10 files changed, 2581 insertions(+), 248 deletions(-) create mode 100644 auxiliary_tools/cdat_regression_testing/664-qbo/debug/qa.py create mode 100644 auxiliary_tools/cdat_regression_testing/664-qbo/regression_test.ipynb create mode 100644 auxiliary_tools/cdat_regression_testing/664-qbo/regression_test_png.ipynb create mode 100644 auxiliary_tools/cdat_regression_testing/664-qbo/run.cfg create mode 100644 auxiliary_tools/cdat_regression_testing/664-qbo/run_script.py create mode 100644 e3sm_diags/plot/qbo_plot.py diff --git a/auxiliary_tools/cdat_regression_testing/664-qbo/debug/qa.py b/auxiliary_tools/cdat_regression_testing/664-qbo/debug/qa.py new file mode 100644 index 000000000..3dd42e3a3 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/664-qbo/debug/qa.py @@ -0,0 +1,108 @@ +""" +Issue - The slice is excluding points for the ref file. +""" +# %% +import xarray as xr +import xcdat as xc + +from e3sm_diags.derivations.default_regions_xr import REGION_SPECS + +REGION = "5S5N" +region_slice = (-5.0, 5.0) + + +test_file = "/global/cfs/cdirs/e3sm/e3sm_diags/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr/U_005101_006012.nc" +ref_file = "/global/cfs/cdirs/e3sm/diagnostics/observations/Atm/time-series/ERA-Interim/ua_197901_201612.nc" + + +def _subset_on_region(ds: xr.Dataset) -> xr.Dataset: + """Subset the dataset by the region 5S5N (latitude). + + This function takes into account the CDAT subset flag, "ccb", which can + add new latitude coordinate points to the beginning and end. + + Parameters + ---------- + ds : xr.Dataset + The dataset. + + Returns + ------- + xr.Dataset + The dataset subsetted by the region. + """ + lat_slice = REGION_SPECS[REGION]["lat"] # type: ignore + + ds_new = ds.copy() + dim_key = xc.get_dim_keys(ds, "Y") + + # 1. Subset on the region_slice + # slice = -5.0, 5.0 + ds_new = ds_new.sel({dim_key: slice(*lat_slice)}) + + # 2. Add delta to first and last value + dim_bounds = ds_new.bounds.get_bounds(axis="Y") + # delta = 1.0 / 2 = 0.5 + delta = (dim_bounds[0][1].item() - dim_bounds[0][0].item()) / 2 + delta_slice = (lat_slice[0] - delta, lat_slice[1] + delta) + + # 3. Check if latitude slice value exists in original latitude. + # If it exists already, then don't add the coordinate point. + # If it does not exist, add the coordinate point. 
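+    # Illustrative note (assuming a uniform 1.0-degree latitude grid, as the
+    # inline comments below suggest): the first cell's bounds are (-5.0, -4.0),
+    # so delta = 1.0 / 2 = 0.5 and the widened slice becomes (-5.5, 5.5),
+    # mirroring CDAT's "ccb" endpoint behavior.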
+    # delta = 0.5
+    # delta slice = -5.5, 5.5
+    ds_list = [ds_new]
+
+    try:
+        ds.sel({dim_key: delta_slice[0]})
+    except KeyError:
+        ds_first_pt = ds_new.isel({dim_key: 0})
+        ds_first_pt[dim_key] = ds_first_pt[dim_key] - delta
+
+        ds_list.append(ds_first_pt)
+
+    try:
+        ds.sel({dim_key: delta_slice[-1]})
+    except KeyError:
+        ds_last_pt = ds_new.isel({dim_key: -1})
+        ds_last_pt[dim_key] = ds_last_pt[dim_key] + delta
+
+        ds_list.append(ds_last_pt)
+
+    ds_new = xr.concat(ds_list, dim=dim_key, data_vars="minimal", coords="minimal")
+    # Drop the stale Y-axis bounds by name and keep the result, since
+    # `drop_vars` returns a new dataset rather than mutating in place.
+    ds_new = ds_new.drop_vars(dim_bounds.name)
+
+    return ds_new
+
+
+ds_test = xc.open_dataset(test_file)
+ds_ref = xc.open_dataset(ref_file)
+
+# %%
+"""
+"ccb" flag is adding the bounds delta / 2 to the beginning and end coordinates.
+
+CDAT Expected: 10
+    array([-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
+xCDAT Result: 10
+    array([-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
+"""
+ds_test_reg = _subset_on_region(ds_test)
+ds_test_reg.lat
+
+# %%
+"""
+"ccb" flag is adding the bounds delta / 2 to the beginning and end coordinates.
+
+CDAT Expected: 15
+    array([-4.9375, -4.5   , -3.75  , -3.    , -2.25  , -1.5   , -0.75  ,
+            0.    ,  0.75  ,  1.5   ,  2.25  ,  3.    ,  3.75  ,  4.5   ,
+            4.9375])
+xCDAT Result: 15
+    array([-4.875, -4.5  , -3.75 , -3.   , -2.25 , -1.5  , -0.75 ,  0.   ,  0.75 ,
+        1.5  ,  2.25 ,  3.   ,  3.75 ,  4.5  ,  4.875])
+"""
+ds_ref_reg = _subset_on_region(ds_ref)
+ds_ref_reg.lat
+
+# %%
diff --git a/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test.ipynb b/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test.ipynb
new file mode 100644
index 000000000..7340b1394
--- /dev/null
+++ b/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test.ipynb
@@ -0,0 +1,1696 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CDAT Migration Regression Testing Notebook (`.nc` files)\n",
+    "\n",
+    "This notebook is used to perform regression testing between the development and\n",
+    "production versions of a diagnostic set.\n",
+    "\n",
+    "## How it works\n",
+    "\n",
+    "It compares the relative differences (%) of the ref and test variables between\n",
+    "the dev and `main` branches.\n",
+    "\n",
+    "## How to use\n",
+    "\n",
+    "PREREQUISITE: The diagnostic set's netCDF files stored in two directories\n",
+    "(one for the dev branch and one for `main`).\n",
+    "\n",
+    "1. Make a copy of this notebook under `auxiliary_tools/cdat_regression_testing/`.\n",
+    "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" xarray netcdf4 dask pandas matplotlib-base ipykernel`\n",
+    "3. Run `mamba activate cdat_regression_test`\n",
+    "4. Update `SET_DIR` and `SET_NAME` in the copy of your notebook.\n",
+    "5. Run all cells IN ORDER.\n",
+    "6. 
Review results for any outstanding differences (>=1e-5 relative tolerance).\n",
+    "   - Debug these differences (e.g., bug in metrics functions, incorrect variable references, etc.)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup Code\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "\n",
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "from e3sm_diags.derivations.derivations import DERIVED_VARIABLES\n",
+    "\n",
+    "SET_NAME = \"qbo\"\n",
+    "SET_DIR = \"664-qbo\"\n",
+    "\n",
+    "DEV_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{SET_DIR}/{SET_NAME}/**\"\n",
+    "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.nc\"))\n",
+    "DEV_NUM_FILES = len(DEV_GLOB)\n",
+    "\n",
+    "MAIN_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/{SET_NAME}/**\"\n",
+    "MAIN_GLOB = sorted(glob.glob(MAIN_PATH + \"/*.nc\"))\n",
+    "MAIN_NUM_FILES = len(MAIN_GLOB)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _check_if_files_found():\n",
+    "    if DEV_NUM_FILES == 0 or MAIN_NUM_FILES == 0:\n",
+    "        raise IOError(\n",
+    "            \"No files found at DEV_PATH and/or MAIN_PATH. \"\n",
+    "            f\"Please check {DEV_PATH} and {MAIN_PATH}.\"\n",
+    "        )\n",
+    "\n",
+    "\n",
+    "def _check_if_matching_filecount():\n",
+    "    if DEV_NUM_FILES != MAIN_NUM_FILES:\n",
+    "        raise IOError(\n",
+    "            \"Number of files do not match at DEV_PATH and MAIN_PATH \"\n",
+    "            f\"({DEV_NUM_FILES} vs. {MAIN_NUM_FILES}).\"\n",
+    "        )\n",
+    "\n",
+    "    print(f\"Matching file count ({DEV_NUM_FILES} and {MAIN_NUM_FILES}).\")\n",
+    "\n",
+    "\n",
+    "def _check_if_missing_files():\n",
+    "    missing_count = 0\n",
+    "\n",
+    "    for fp_main in MAIN_GLOB:\n",
+    "        fp_dev = fp_main.replace(\"main\", SET_DIR)\n",
+    "\n",
+    "        if fp_dev not in DEV_GLOB:\n",
+    "            print(f\"No development file found to compare with {fp_main}!\")\n",
+    "            missing_count += 1\n",
+    "\n",
+    "    for fp_dev in DEV_GLOB:\n",
+    "        fp_main = fp_dev.replace(SET_DIR, \"main\")\n",
+    "\n",
+    "        if fp_main not in MAIN_GLOB:\n",
+    "            print(f\"No production file found to compare with {fp_dev}!\")\n",
+    "            missing_count += 1\n",
+    "\n",
+    "    print(f\"Number of files missing: {missing_count}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _get_relative_diffs():\n",
+    "    # We are mainly focusing on relative tolerance here (in percentage terms).\n",
+    "    atol = 0\n",
+    "    rtol = 1e-5\n",
+    "\n",
+    "    for fp_main in MAIN_GLOB:\n",
+    "        if \"test.nc\" in fp_main or \"ref.nc\" in fp_main:\n",
+    "            fp_dev = fp_main.replace(\"main\", SET_DIR)\n",
+    "\n",
+    "            print(\"Comparing:\")\n",
+    "            print(f\"    * {fp_dev}\")\n",
+    "            print(f\"    * {fp_main}\")\n",
+    "\n",
+    "            ds1 = xr.open_dataset(fp_dev)\n",
+    "            ds2 = xr.open_dataset(fp_main)\n",
+    "\n",
+    "            var_keys = [\"U\"]\n",
+    "            for key in var_keys:\n",
+    "                print(f\"    * var_key: {key}\")\n",
+    "\n",
+    "                dev_data = ds1[key].values\n",
+    "                main_data = ds2[key].values\n",
+    "\n",
+    "                if dev_data is None or main_data is None:\n",
+    "                    print(\"    * Could not find variable key in the dataset(s)\")\n",
+    "                    continue\n",
+    "\n",
+    "                try:\n",
+    "                    np.testing.assert_allclose(\n",
+    "                        dev_data,\n",
+    "                        main_data,\n",
+    "                        atol=atol,\n",
+    "                        rtol=rtol,\n",
+    "                    )\n",
+    "                except (KeyError, AssertionError) as e:\n",
+    "                    print(f\"    {e}\")\n",
+    "                else:\n",
+    "                    print(f\"    * All close and within relative tolerance ({rtol})\")"
+   ]
+  },
+  {
+   "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## 1. Check for matching and equal number of files\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "_check_if_files_found()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/_qbo_ref_unify.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/_qbo_test_unify.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags_qbo_test.nc']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "DEV_GLOB" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_level_ref.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_level_test.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc',\n", + " '/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_qbo_test.nc']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "MAIN_GLOB" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of files missing: 0\n" + ] + } + ], + "source": [ + "_check_if_missing_files()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Matching file count (4 and 4).\n" + ] + } + ], + "source": [ + "_check_if_matching_filecount()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's ignore `qbo_diags_level_ref.nc` and `qbo_diags_level_test.nc`.\n", + "\n", + "- Those files are just the Z dimension of the variable found in the `qbo_diags_qbo_ref.nc` and `qbo_diags_qbo_test.nc`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "MAIN_GLOB = [filename for filename in MAIN_GLOB if \"_level\" not in filename]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 Compare the netCDF files between branches\n", + "\n", + "- Compare \"ref\" and \"test\" files\n", + "- \"diff\" files are ignored because getting relative diffs for these does not make sense (relative diff will be above tolerance)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc\n", + " * var_key: U\n", + " \n", + "Not equal to tolerance rtol=1e-05, atol=0\n", + "\n", + "Mismatched elements: 4440 / 4440 (100%)\n", + "Max absolute difference: 64.15747321\n", + "Max relative difference: 1767.25842536\n", + " x: array([[-16.930712, -42.569729, -25.370665, ..., -2.711329, 
-2.530502,\n", + " -2.283684],\n", + " [ -2.24533 , -41.558418, -27.585657, ..., -2.62069 , -2.403724,...\n", + " y: array([[ -2.285837, -2.53099 , -2.710924, ..., -25.36748 , -42.5402 ,\n", + " -16.94262 ],\n", + " [ -2.126941, -2.409103, -2.624998, ..., -27.588021, -41.54002 ,...\n", + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags_qbo_test.nc\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_qbo_test.nc\n", + " * var_key: U\n", + " * All close and within relative tolerance (1e-05)\n" + ] + } + ], + "source": [ + "_get_relative_diffs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results\n", + "\n", + "- Reference file diffs are massive because the CDAT codebase does not correctly sort the data by the Z axis (`plev`). I opened an issue to address this on `main` here: https://github.com/E3SM-Project/e3sm_diags/issues/825\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Validation: Sorting the CDAT produced reference file by the Z axis in ascending fixes the issue. We can move forward with the changes in this PR.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import xcdat as xc\n", + "import xarray as xr\n", + "\n", + "ds_xc = xc.open_dataset(\n", + " \"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc\"\n", + ")\n", + "ds_cdat = xc.open_dataset(\n", + " \"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags_qbo_ref.nc\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'plev' (plev: 37)> Size: 296B\n",
+       "array([   1.,    2.,    3.,    5.,    7.,   10.,   20.,   30.,   50.,   70.,\n",
+       "        100.,  125.,  150.,  175.,  200.,  225.,  250.,  300.,  350.,  400.,\n",
+       "        450.,  500.,  550.,  600.,  650.,  700.,  750.,  775.,  800.,  825.,\n",
+       "        850.,  875.,  900.,  925.,  950.,  975., 1000.])\n",
+       "Coordinates:\n",
+       "  * plev     (plev) float64 296B 1.0 2.0 3.0 5.0 7.0 ... 925.0 950.0 975.0 1e+03
" + ], + "text/plain": [ + " Size: 296B\n", + "array([ 1., 2., 3., 5., 7., 10., 20., 30., 50., 70.,\n", + " 100., 125., 150., 175., 200., 225., 250., 300., 350., 400.,\n", + " 450., 500., 550., 600., 650., 700., 750., 775., 800., 825.,\n", + " 850., 875., 900., 925., 950., 975., 1000.])\n", + "Coordinates:\n", + " * plev (plev) float64 296B 1.0 2.0 3.0 5.0 7.0 ... 925.0 950.0 975.0 1e+03" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_xc[\"plev\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'plev' (plev: 37)> Size: 296B\n",
+       "array([1000.,  975.,  950.,  925.,  900.,  875.,  850.,  825.,  800.,  775.,\n",
+       "        750.,  700.,  650.,  600.,  550.,  500.,  450.,  400.,  350.,  300.,\n",
+       "        250.,  225.,  200.,  175.,  150.,  125.,  100.,   70.,   50.,   30.,\n",
+       "         20.,   10.,    7.,    5.,    3.,    2.,    1.])\n",
+       "Coordinates:\n",
+       "  * plev     (plev) float64 296B 1e+03 975.0 950.0 925.0 ... 5.0 3.0 2.0 1.0\n",
+       "Attributes:\n",
+       "    axis:           Z\n",
+       "    units:          hPa\n",
+       "    standard_name:  air_pressure\n",
+       "    long_name:      pressure\n",
+       "    positive:       down\n",
+       "    realtopology:   linear
" + ], + "text/plain": [ + " Size: 296B\n", + "array([1000., 975., 950., 925., 900., 875., 850., 825., 800., 775.,\n", + " 750., 700., 650., 600., 550., 500., 450., 400., 350., 300.,\n", + " 250., 225., 200., 175., 150., 125., 100., 70., 50., 30.,\n", + " 20., 10., 7., 5., 3., 2., 1.])\n", + "Coordinates:\n", + " * plev (plev) float64 296B 1e+03 975.0 950.0 925.0 ... 5.0 3.0 2.0 1.0\n", + "Attributes:\n", + " axis: Z\n", + " units: hPa\n", + " standard_name: air_pressure\n", + " long_name: pressure\n", + " positive: down\n", + " realtopology: linear" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_cdat[\"plev\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "ds_cdat = ds_cdat.sortby(\"plev\", ascending=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'plev' (plev: 37)> Size: 296B\n",
+       "array([   1.,    2.,    3.,    5.,    7.,   10.,   20.,   30.,   50.,   70.,\n",
+       "        100.,  125.,  150.,  175.,  200.,  225.,  250.,  300.,  350.,  400.,\n",
+       "        450.,  500.,  550.,  600.,  650.,  700.,  750.,  775.,  800.,  825.,\n",
+       "        850.,  875.,  900.,  925.,  950.,  975., 1000.])\n",
+       "Coordinates:\n",
+       "  * plev     (plev) float64 296B 1.0 2.0 3.0 5.0 7.0 ... 925.0 950.0 975.0 1e+03
" + ], + "text/plain": [ + " Size: 296B\n", + "array([ 1., 2., 3., 5., 7., 10., 20., 30., 50., 70.,\n", + " 100., 125., 150., 175., 200., 225., 250., 300., 350., 400.,\n", + " 450., 500., 550., 600., 650., 700., 750., 775., 800., 825.,\n", + " 850., 875., 900., 925., 950., 975., 1000.])\n", + "Coordinates:\n", + " * plev (plev) float64 296B 1.0 2.0 3.0 5.0 7.0 ... 925.0 950.0 975.0 1e+03" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_xc.plev" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "\nNot equal to tolerance rtol=1e-07, atol=0\n\nMismatched elements: 4440 / 4440 (100%)\nMax absolute difference: 0.18704352\nMax relative difference: 7.69507964\n x: array([[-16.930712, -42.569729, -25.370665, ..., -2.711329, -2.530502,\n -2.283684],\n [ -2.24533 , -41.558418, -27.585657, ..., -2.62069 , -2.403724,...\n y: array([[-16.94262 , -42.5402 , -25.36748 , ..., -2.710924, -2.53099 ,\n -2.285837],\n [ -2.284392, -41.54002 , -27.588021, ..., -2.624998, -2.409103,...", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtesting\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massert_allclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43mds_xc\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mU\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mds_cdat\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mU\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... 
skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/mambaforge/envs/e3sm_diags_dev_669/lib/python3.10/contextlib.py:79\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/e3sm_diags_dev_669/lib/python3.10/site-packages/numpy/testing/_private/utils.py:797\u001b[0m, in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict)\u001b[0m\n\u001b[1;32m 793\u001b[0m err_msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(remarks)\n\u001b[1;32m 794\u001b[0m msg \u001b[38;5;241m=\u001b[39m build_err_msg([ox, oy], err_msg,\n\u001b[1;32m 795\u001b[0m verbose\u001b[38;5;241m=\u001b[39mverbose, header\u001b[38;5;241m=\u001b[39mheader,\n\u001b[1;32m 796\u001b[0m names\u001b[38;5;241m=\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m'\u001b[39m), precision\u001b[38;5;241m=\u001b[39mprecision)\n\u001b[0;32m--> 797\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(msg)\n\u001b[1;32m 798\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[1;32m 799\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtraceback\u001b[39;00m\n", + "\u001b[0;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=1e-07, atol=0\n\nMismatched elements: 4440 / 4440 (100%)\nMax absolute difference: 0.18704352\nMax relative difference: 7.69507964\n x: array([[-16.930712, -42.569729, -25.370665, ..., -2.711329, -2.530502,\n -2.283684],\n [ -2.24533 , -41.558418, -27.585657, ..., -2.62069 , -2.403724,...\n y: array([[-16.94262 , -42.5402 , -25.36748 , ..., -2.710924, -2.53099 ,\n -2.285837],\n [ -2.284392, -41.54002 , -27.588021, ..., -2.624998, -2.409103,..." 
+    ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "np.testing.assert_allclose(ds_xc[\"U\"], ds_cdat[\"U\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Compare Maxes and Mins -- Really close\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "61.54721945135814 61.36017592984254\n",
+      "-66.54760399615296 -66.52449748057968\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(ds_xc[\"U\"].max().item(), ds_cdat[\"U\"].max().item())\n",
+    "print(ds_xc[\"U\"].min().item(), ds_cdat[\"U\"].min().item())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Compare Sum and Mean -- Really close\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-3.739846878096383 -3.745529874323115\n",
+      "-16604.92013874794 -16630.15264199463\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(ds_xc[\"U\"].mean().item(), ds_cdat[\"U\"].mean().item())\n",
+    "print(ds_xc[\"U\"].sum().item(), ds_cdat[\"U\"].sum().item())"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "cdat_regression_test",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test_png.ipynb b/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test_png.ipynb
new file mode 100644
index 000000000..8f2c4085d
--- /dev/null
+++ b/auxiliary_tools/cdat_regression_testing/664-qbo/regression_test_png.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CDAT Migration Regression Testing Notebook (`.png` files)\n",
+    "\n",
+    "This notebook is used to perform regression testing between the development and\n",
+    "production versions of a diagnostic set.\n",
+    "\n",
+    "## How to use\n",
+    "\n",
+    "PREREQUISITE: The diagnostic set's `.png` plots stored in two directories\n",
+    "(one for the dev branch and one for `main`).\n",
+    "\n",
+    "1. Make a copy of this notebook under `auxiliary_tools/cdat_regression_testing/`.\n",
+    "2. Run `mamba create -n cdat_regression_test -y -c conda-forge \"python<3.12\" xarray netcdf4 dask pandas matplotlib-base ipykernel`\n",
+    "3. Run `mamba activate cdat_regression_test`\n",
+    "4. Update `SET_DIR` and `SET_NAME` in the copy of your notebook.\n",
+    "5. 
Run all cells IN ORDER.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup Code\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "\n",
+    "from auxiliary_tools.cdat_regression_testing.utils import get_image_diffs\n",
+    "\n",
+    "SET_NAME = \"qbo\"\n",
+    "SET_DIR = \"664-qbo\"\n",
+    "\n",
+    "DEV_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{SET_DIR}/{SET_NAME}/**\"\n",
+    "DEV_GLOB = sorted(glob.glob(DEV_PATH + \"/*.png\"))\n",
+    "DEV_NUM_FILES = len(DEV_GLOB)\n",
+    "\n",
+    "MAIN_PATH = f\"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/{SET_NAME}/**\"\n",
+    "MAIN_GLOB = sorted(glob.glob(MAIN_PATH + \"/*.png\"))\n",
+    "MAIN_NUM_FILES = len(MAIN_GLOB)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _check_if_files_found():\n",
+    "    if DEV_NUM_FILES == 0 or MAIN_NUM_FILES == 0:\n",
+    "        raise IOError(\n",
+    "            \"No files found at DEV_PATH and/or MAIN_PATH. \"\n",
+    "            f\"Please check {DEV_PATH} and {MAIN_PATH}.\"\n",
+    "        )\n",
+    "\n",
+    "\n",
+    "def _check_if_matching_filecount():\n",
+    "    if DEV_NUM_FILES != MAIN_NUM_FILES:\n",
+    "        raise IOError(\n",
+    "            \"Number of files do not match at DEV_PATH and MAIN_PATH \"\n",
+    "            f\"({DEV_NUM_FILES} vs. {MAIN_NUM_FILES}).\"\n",
+    "        )\n",
+    "\n",
+    "    print(f\"Matching file count ({DEV_NUM_FILES} and {MAIN_NUM_FILES}).\")\n",
+    "\n",
+    "\n",
+    "def _check_if_missing_files():\n",
+    "    missing_count = 0\n",
+    "\n",
+    "    for fp_main in MAIN_GLOB:\n",
+    "        fp_dev = fp_main.replace(\"main\", SET_DIR)\n",
+    "\n",
+    "        if fp_dev not in DEV_GLOB:\n",
+    "            print(f\"No development file found to compare with {fp_main}!\")\n",
+    "            missing_count += 1\n",
+    "\n",
+    "    for fp_dev in DEV_GLOB:\n",
+    "        fp_main = fp_dev.replace(SET_DIR, \"main\")\n",
+    "\n",
+    "        if fp_main not in MAIN_GLOB:\n",
+    "            print(f\"No production file found to compare with {fp_dev}!\")\n",
+    "            missing_count += 1\n",
+    "\n",
+    "    print(f\"Number of files missing: {missing_count}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Check for matching and equal number of files\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_check_if_files_found()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "No development file found to compare with /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags.png!\n",
+      "No development file found to compare with /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags.png!\n",
+      "Number of files missing: 2\n"
+     ]
+    }
+   ],
+   "source": [
+    "_check_if_missing_files()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "OSError",
+     "evalue": "Number of files do not match at DEV_PATH and MAIN_PATH (2 vs. 
1).", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[16], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43m_check_if_matching_filecount\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[13], line 11\u001b[0m, in \u001b[0;36m_check_if_matching_filecount\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_check_if_matching_filecount\u001b[39m():\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m DEV_NUM_FILES \u001b[38;5;241m!=\u001b[39m MAIN_NUM_FILES:\n\u001b[0;32m---> 11\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIOError\u001b[39;00m(\n\u001b[1;32m 12\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNumber of files do not match at DEV_PATH and MAIN_PATH \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mDEV_NUM_FILES\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m vs. \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mMAIN_NUM_FILES\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m).\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMatching file count (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mDEV_NUM_FILES\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mMAIN_NUM_FILES\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m).\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mOSError\u001b[0m: Number of files do not match at DEV_PATH and MAIN_PATH (2 vs. 1)." 
+ ] + } + ], + "source": [ + "_check_if_matching_filecount()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 Compare the plots between branches\n", + "\n", + "- Compare \"ref\" and \"test\" files\n", + "- \"diff\" files are ignored because getting relative diffs for these does not make sense (relative diff will be above tolerance)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparing:\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/main/qbo/QBO-ERA-Interim/qbo_diags.png\n", + " * /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim/qbo_diags.png\n", + " * Difference path /global/cfs/cdirs/e3sm/www/cdat-migration-fy24/664-qbo/qbo/QBO-ERA-Interim_diff/qbo_diags.png\n" + ] + } + ], + "source": [ + "for main_path, dev_path in zip(MAIN_GLOB, DEV_GLOB):\n", + " print(\"Comparing:\")\n", + " print(f\" * {main_path}\")\n", + " print(f\" * {dev_path}\")\n", + "\n", + " get_image_diffs(dev_path, main_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results\n", + "\n", + "All plots are identical\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cdat_regression_test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/auxiliary_tools/cdat_regression_testing/664-qbo/run.cfg b/auxiliary_tools/cdat_regression_testing/664-qbo/run.cfg new file mode 100644 index 000000000..1d56fcb15 --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/664-qbo/run.cfg @@ -0,0 +1,6 @@ +[#] +sets = ["qbo"] +case_id = "QBO-ERA-Interim" +variables = ["U"] +ref_name = "ERA-Interim" +reference_name = "ERA-Interim" diff --git a/auxiliary_tools/cdat_regression_testing/664-qbo/run_script.py b/auxiliary_tools/cdat_regression_testing/664-qbo/run_script.py new file mode 100644 index 000000000..c9007cbee --- /dev/null +++ b/auxiliary_tools/cdat_regression_testing/664-qbo/run_script.py @@ -0,0 +1,12 @@ +# %% +# python -m auxiliary_tools.cdat_regression_testing.664-qbo.run_script +from auxiliary_tools.cdat_regression_testing.base_run_script import run_set + +SET_NAME = "qbo" +SET_DIR = "664-qbo" +CFG_PATH: str | None = None +# CFG_PATH: str | None = "/global/u2/v/vo13/E3SM-Project/e3sm_diags/auxiliary_tools/cdat_regression_testing/664-qbo/run.cfg" +MULTIPROCESSING = False + +# %% +run_set(SET_NAME, SET_DIR, CFG_PATH, MULTIPROCESSING) diff --git a/e3sm_diags/driver/qbo_driver.py b/e3sm_diags/driver/qbo_driver.py index 7891eb7cc..d05952691 100644 --- a/e3sm_diags/driver/qbo_driver.py +++ b/e3sm_diags/driver/qbo_driver.py @@ -2,111 +2,270 @@ import json import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Literal, Tuple, TypedDict -import cdutil import numpy as np import scipy.fftpack -from scipy.signal import detrend +import xarray as xr +import xcdat as xc -from e3sm_diags.derivations import default_regions -from e3sm_diags.driver import utils +from e3sm_diags.driver.utils.dataset_xr import Dataset +from e3sm_diags.driver.utils.io import _get_output_dir, _write_to_netcdf +from 
e3sm_diags.driver.utils.regrid import _subset_on_region from e3sm_diags.logger import custom_logger -from e3sm_diags.plot.cartopy.qbo_plot import plot +from e3sm_diags.metrics.metrics import spatial_avg +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.plot.qbo_plot import plot logger = custom_logger(__name__) if TYPE_CHECKING: from e3sm_diags.parameter.qbo_parameter import QboParameter +# The region will always be 5S5N +REGION = "5S5N" -def unify_plev(var): - """ - Given a data set with a z-axis (plev), - convert to the plev with units: hPa and make sure plev is in ascending order(same as model data) + +class MetricsDict(TypedDict): + qbo: xr.DataArray + psd_sum: np.ndarray + amplitude: np.ndarray + period_new: np.ndarray + psd_x_new: np.ndarray + amplitude_new: np.ndarray + name: str + + +def run_diag(parameter: QboParameter) -> QboParameter: + variables = parameter.variables + + test_ds = Dataset(parameter, data_type="test") + ref_ds = Dataset(parameter, data_type="ref") + + for var_key in variables: + logger.info(f"Variable={var_key}") + + ds_test = test_ds.get_time_series_dataset(var_key) + ds_ref = ref_ds.get_time_series_dataset(var_key) + + ds_test_region = _subset_on_region(ds_test, var_key, REGION) + ds_ref_region = _subset_on_region(ds_ref, var_key, REGION) + + # Convert plevs of test and ref for unified units and direction + ds_test_region = _unify_plev(ds_test_region, var_key) + ds_ref_region = _unify_plev(ds_ref_region, var_key) + + # Dictionaries to store information on the variable including the name, + # the averaged variable, and metrics. + test_dict: MetricsDict = {} # type: ignore + ref_dict: MetricsDict = {} # type: ignore + + # Diagnostic 1: average over longitude & latitude to produce time-height + # array of u field. + test_dict["qbo"] = _spatial_avg(ds_test_region, var_key) + ref_dict["qbo"] = _spatial_avg(ds_ref_region, var_key) + + # Diagnostic 2: calculate and plot the amplitude of wind variations with a 20-40 month period + test_dict["psd_sum"], test_dict["amplitude"] = _get_20to40month_fft_amplitude( + test_dict["qbo"] + ) + ref_dict["psd_sum"], ref_dict["amplitude"] = _get_20to40month_fft_amplitude( + ref_dict["qbo"] + ) + + # Diagnostic 3: calculate the Power Spectral Density. + # Pre-process data to average over lat,lon,height + x_test = _get_power_spectral_density(test_dict["qbo"]) + x_ref = _get_power_spectral_density(ref_dict["qbo"]) + + # Calculate the PSD and interpolate to period_new. Specify periods to + # plot. + test_dict["period_new"] = ref_dict["period_new"] = np.concatenate( + (np.arange(2.0, 33.0), np.arange(34.0, 100.0, 2.0)), axis=0 + ) + test_dict["psd_x_new"], test_dict["amplitude_new"] = _get_psd_from_deseason( + x_test, test_dict["period_new"] + ) + ref_dict["psd_x_new"], ref_dict["amplitude_new"] = _get_psd_from_deseason( + x_ref, ref_dict["period_new"] + ) + + parameter.var_id = var_key + parameter.output_file = "qbo_diags" + parameter.main_title = ( + f"QBO index, amplitude, and power spectral density for {var_key}" + ) + # Get the years of the data. + parameter.viewer_descr[var_key] = parameter.main_title + + parameter.test_yrs = f"{test_ds.start_yr}-{test_ds.end_yr}" + parameter.ref_yrs = f"{ref_ds.start_yr}-{ref_ds.end_yr}" + + # Write the averaged variables to netCDF. Save with data type as + # `qbo_test` and `qbo_ref` to match CDAT codebase for regression + # testing of the `.nc` files. 
+ _write_to_netcdf(parameter, test_dict["qbo"], var_key, "qbo_test") # type: ignore + _write_to_netcdf(parameter, ref_dict["qbo"], var_key, "qbo_ref") # type: ignore + + # Write the metrics to .json files. + test_dict["name"] = test_ds._get_test_name() + + try: + ref_dict["name"] = ref_ds._get_ref_name() + except AttributeError: + ref_dict["name"] = parameter.ref_name + + _save_metrics_to_json(parameter, test_dict, "test") # type: ignore + _save_metrics_to_json(parameter, ref_dict, "ref") # type: ignore + + # plot the results. + plot(parameter, test_dict, ref_dict) + + return parameter + + +def _save_metrics_to_json( + parameter: CoreParameter, + var_dict: Dict[str, str | np.ndarray], + dict_type: Literal["test", "ref"], +): + output_dir = _get_output_dir(parameter) + filename = parameter.output_file + f"_{dict_type}.json" + abs_path = os.path.join(output_dir, filename) + + # Convert all metrics from `np.ndarray` to a Python list for serialization + # to `.json`. + metrics_dict = {k: v for k, v in var_dict.items() if k != "qbo"} + + for key in metrics_dict.keys(): + if key != "name": + metrics_dict[key] = metrics_dict[key].tolist() # type: ignore + + with open(abs_path, "w") as outfile: + json.dump(metrics_dict, outfile) + + logger.info("Metrics saved in: {}".format(abs_path)) + + +def _unify_plev(ds_region: xr.Dataset, var_key: str) -> xr.Dataset: + """Convert the Z-axis (plev) with units Pa to hPa. + + This function also orders the data by plev in ascending order (same as model + data). + + Parameters + ---------- + ds_region : xr.Dataset + The dataset for the region. + + Returns + ------- + xr.Dataset + The dataset for the region with a converted Z-axis. """ - var_plv = var.getLevel() - if var_plv.units == "Pa": - var_plv[:] = var_plv[:] / 100.0 # convert Pa to mb - var_plv.units = "hPa" - var.setAxis(1, var_plv) + ds_region_new = ds_region.copy() + # The dataset can have multiple Z axes (e.g., "lev", "ilev"), so get the + # Z axis from the variable directly. + z_axis = xc.get_dim_coords(ds_region[var_key], axis="Z") + z_dim = z_axis.name - # Make plev in ascending order - if var.getLevel()[0] > var.getLevel()[-1]: - var = var(lev=slice(-1, None, -1)) + if z_axis.attrs["units"] == "Pa": + ds_region_new[z_dim] = z_axis / 100.0 + ds_region_new[z_dim].attrs["units"] = "hPa" + ds_region_new = ds_region_new.sortby(z_dim, ascending=True) -def process_u_for_time_height(data_region): - # Average over longitude (i.e., each latitude's average in data_region) - data_lon_average = cdutil.averager(data_region, axis="x") - # Average over latitude (i.e., average for entire data_region) - data_lon_lat_average = cdutil.averager(data_lon_average, axis="y") - # Get data by vertical level - level_data = data_lon_lat_average.getAxis(1) - return data_lon_lat_average, level_data + return ds_region_new -def deseason(xraw): - # Calculates the deseasonalized data - months_per_year = 12 - # Create array to hold climatological values and deseasonalized data - # Create months_per_year x 1 array of zeros - xclim = np.zeros((months_per_year, 1)) - # Create array with same shape as xraw - x_deseasoned = np.zeros(xraw.shape) - # Iterate through all 12 months. 
-    for month in np.arange(months_per_year):
-        # `xraw[month::12]` will return the data for this month every year (12 months)
-        # (i.e., from month until the end of xraw, get every 12th month)
-        # Get the mean of this month, using data from every year, ignoring NaNs
-        xclim[month] = np.nanmean(xraw[month::months_per_year])
-    num_years = int(np.floor(len(x_deseasoned) / months_per_year))
-    # Iterate through all years in x_deseasoned (same number as in xraw)
-    for year in np.arange(num_years):
-        year_index = year * months_per_year
-        # Iterate through all months of the year
-        for month in np.arange(months_per_year):
-            month_index = year_index + month
-            # Subtract the month's mean over num_years from xraw's data for this month in this year
-            # i.e., get the difference between this month's value and it's "usual" value
-            x_deseasoned[month_index] = xraw[month_index] - xclim[month]
-    return x_deseasoned
+def _spatial_avg(ds: xr.Dataset, var_key: str) -> xr.DataArray:
+    """Process the U variable for time and height by averaging over lat and lon.
+
+    Diagnostic 1: average over longitude & latitude to produce time-height
+    array of u field.
+
+    Richter, J. H., Chen, C. C., Tang, Q., Xie, S., & Rasch, P. J. (2019).
+    Improved Simulation of the QBO in E3SMv1. Journal of Advances in Modeling
+    Earth Systems, 11(11), 3403-3418.
-def get_20to40month_fft_amplitude(qboN, levelN):
-    # Calculates the amplitude of wind variations in the 20 - 40 month period
+    U = "Monthly mean zonal mean zonal wind averaged between 5S and 5N as a
+    function of pressure and time" (p. 3406)
+
+    Source: https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2019MS001763
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset.
+    var_key : str
+        The key of the variable.
+
+    Returns
+    -------
+    xr.DataArray
+        The averaged variable.
+    """
+    var_avg = spatial_avg(ds, var_key, axis=["X", "Y"], as_list=False)
+
+    return var_avg  # type: ignore
+
+
+def _get_20to40month_fft_amplitude(
+    var_avg: xr.DataArray,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Calculates the amplitude of wind variations in the 20 - 40 month period.
+
+    Parameters
+    ----------
+    var_avg : xr.DataArray
+        The spatially averaged variable.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        The psd and amplitude arrays.
+    """
+    qboN_arr = np.squeeze(var_avg.values)
+
+    levelN = xc.get_dim_coords(var_avg, axis="Z")
     psd_sumN = np.zeros(levelN.shape)
     amplitudeN = np.zeros(levelN.shape)
     for ilev in np.arange(len(levelN)):
-        # `qboN[:, ilev]` returns the entire 0th dimension for ilev in the 1st dimension of the array.
-        y_input = deseason(np.squeeze(qboN[:, ilev]))
+        # `qboN[:, ilev]` returns the entire 0th dimension for ilev in the 1st
+        # dimension of the array.
+        y_input = deseason(np.squeeze(qboN_arr[:, ilev]))
         y = scipy.fftpack.fft(y_input)
         n = len(y)
+
         frequency = np.arange(n / 2) / n
         period = 1 / frequency
         values = y[0 : int(np.floor(n / 2))]
         fyy = values * np.conj(values)
-        # Choose the range 20 - 40 months that captures most QBOs (in nature)
+
+        # Choose the range 20 - 40 months that captures most QBOs (in nature).
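+        # Illustrative note: with monthly sampling, period = 1 / frequency is
+        # in months and fyy is the raw periodogram, so the sum below
+        # integrates the spectral power over the 20-40 month QBO band before
+        # it is converted to an amplitude.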
psd_sumN[ilev] = 2 * np.nansum(fyy[(period <= 40) & (period >= 20)]) amplitudeN[ilev] = np.sqrt(2 * psd_sumN[ilev]) * (frequency[1] - frequency[0]) + return psd_sumN, amplitudeN -def process_u_for_power_spectral_density(data_region): +def _get_power_spectral_density(var_avg: xr.DataArray): # Average over vertical levels and horizontal area (units: hPa) level_bottom = 22 level_top = 18 - # Average over lat and lon - data_lat_lon_average = cdutil.averager(data_region, axis="xy") + + z_dim = xc.get_dim_keys(var_avg, axis="Z") # Average over vertical try: - average = data_lat_lon_average(level=(level_top, level_bottom)) + average = var_avg.sel({z_dim: slice(level_top, level_bottom)}) except Exception: raise Exception( "No levels found between {}hPa and {}hPa".format(level_top, level_bottom) ) - x0 = np.nanmean(np.array(average), axis=1) + + x0 = np.nanmean(average.values, axis=1) + # x0 should now be 1D return x0 @@ -122,7 +281,7 @@ def ceil_log2(x): return np.ceil(np.log2(x)).astype("int") -def get_psd_from_deseason(xraw, period_new): +def _get_psd_from_deseason(xraw, period_new): x_deseasoned = deseason(xraw) # Sampling frequency: assumes frequency of sampling = 1 month @@ -143,6 +302,7 @@ def get_psd_from_deseason(xraw, period_new): amplitude0 = 2 * abs(x0[0 : int(NFFT0 / 2 + 1)]) # Calculate power spectral density as a function of frequency psd_x0 = amplitude0**2 / L0 + # Total spectral power # In the next code block, we will perform an interpolation using the period # (interpolating values of amplitude0_flipped and psd_x0_flipped from period0_flipped to period_new). @@ -156,170 +316,32 @@ def get_psd_from_deseason(xraw, period_new): period_new, period0_flipped[:-1], amplitude0_flipped[:-1] ) psd_x_new0 = np.interp(period_new, period0_flipped[:-1], psd_x0_flipped[:-1]) - return psd_x_new0, amplitude_new0 + return psd_x_new0, amplitude_new0 -def get_psd_from_wavelet(data): - """ - Return power spectral density using a complex Morlet wavelet spectrum of degree 6 - """ - deg = 6 - period = np.arange(1, 55 + 1) - freq = 1 / period - widths = deg / (2 * np.pi * freq) - cwtmatr = scipy.signal.cwt(data, scipy.signal.morlet2, widths=widths, w=deg) - psd = np.mean(np.square(np.abs(cwtmatr)), axis=1) - return (period, psd) - - -def run_diag(parameter: QboParameter) -> QboParameter: - variables = parameter.variables - # The region will always be 5S5N - region = "5S5N" - test_data = utils.dataset.Dataset(parameter, test=True) - ref_data = utils.dataset.Dataset(parameter, ref=True) - # Get the years of the data. - parameter.test_yrs = utils.general.get_yrs(test_data) - parameter.ref_yrs = utils.general.get_yrs(ref_data) - for variable in variables: - if parameter.print_statements: - logger.info("Variable={}".format(variable)) - test_var = test_data.get_timeseries_variable(variable) - ref_var = ref_data.get_timeseries_variable(variable) - qbo_region = default_regions.regions_specs[region]["domain"] # type: ignore - - test_region = test_var(qbo_region) - ref_region = ref_var(qbo_region) - - # Convert plevs of test and ref for unified units and direction - unify_plev(test_region) - unify_plev(ref_region) - - test = {} - ref = {} - - # Diagnostic 1: average over longitude & latitude to produce time-height array of u field: - # Richter, J. H., Chen, C. C., Tang, Q., Xie, S., & Rasch, P. J. (2019). Improved Simulation of the QBO in E3SMv1. Journal of Advances in Modeling Earth Systems, 11(11), 3403-3418. 
- # https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2019MS001763 - # U = "Monthly mean zonal mean zonal wind averaged between 5S and 5N as a function of pressure and time" (p. 3406) - test["qbo"], test["level"] = process_u_for_time_height(test_region) - ref["qbo"], ref["level"] = process_u_for_time_height(ref_region) - - # Diagnostic 2: calculate and plot the amplitude of wind variations with a 20-40 month period - test["psd_sum"], test["amplitude"] = get_20to40month_fft_amplitude( - np.squeeze(np.array(test["qbo"])), test["level"] - ) - ref["psd_sum"], ref["amplitude"] = get_20to40month_fft_amplitude( - np.squeeze(np.array(ref["qbo"])), ref["level"] - ) - - # Diagnostic 3: calculate the Power Spectral Density - # Pre-process data to average over lat,lon,height - x_test = process_u_for_power_spectral_density(test_region) - x_ref = process_u_for_power_spectral_density(ref_region) - # Calculate the PSD and interpolate to period_new. Specify periods to plot - period_new = np.concatenate( - (np.arange(2.0, 33.0), np.arange(34.0, 100.0, 2.0)), axis=0 - ) - test["psd_x_new"], test["amplitude_new"] = get_psd_from_deseason( - x_test, period_new - ) - test["period_new"] = period_new - ref["psd_x_new"], ref["amplitude_new"] = get_psd_from_deseason( - x_ref, period_new - ) - ref["period_new"] = period_new - - # Diagnostic 4: calculate the Wavelet - # Target vertical level - pow_spec_lev = 20.0 - - # Find the closest value for power spectral level in the list - # List of test case vertical levels - test_lev_list = list(test["level"]) - closest_lev = min(test_lev_list, key=lambda x: abs(x - pow_spec_lev)) - closest_index = test_lev_list.index(closest_lev) - # Grab target vertical level - test_data_avg = test["qbo"][:, closest_index] - - # List of reference case vertical levels - ref_lev_list = list(ref["level"]) - # Find the closest value for power spectral level in the list - closest_lev = min(ref_lev_list, key=lambda x: abs(x - pow_spec_lev)) - closest_index = ref_lev_list.index(closest_lev) - # Grab target vertical level - ref_data_avg = ref["qbo"][:, closest_index] - - # convert to anomalies - test_data_avg = test_data_avg - test_data_avg.mean() - ref_data_avg = ref_data_avg - ref_data_avg.mean() - - # Detrend the data - test_detrended_data = detrend(test_data_avg) - ref_detrended_data = detrend(ref_data_avg) - - test["wave_period"], test_wavelet = get_psd_from_wavelet(test_detrended_data) - ref["wave_period"], ref_wavelet = get_psd_from_wavelet(ref_detrended_data) - - # Get square root values of wavelet spectra - test["wavelet"] = np.sqrt(test_wavelet) - ref["wavelet"] = np.sqrt(ref_wavelet) - - parameter.var_id = variable - parameter.main_title = ( - "QBO index, amplitude, and power spectral density for {}".format(variable) - ) - parameter.viewer_descr[variable] = parameter.main_title - - test["name"] = utils.general.get_name(parameter, test_data) - ref["name"] = utils.general.get_name(parameter, ref_data) - - test_nc = {} - ref_nc = {} - for key in ["qbo", "level"]: - test_nc[key] = test[key] - ref_nc[key] = ref[key] - - test_json = {} - ref_json = {} - for key in test.keys(): - if key == "name": - test_json[key] = test[key] - ref_json[key] = ref[key] - elif key == "qbo": - continue - else: - test_json[key] = list(test[key]) - ref_json[key] = list(ref[key]) - - parameter.output_file = "qbo_diags" - # TODO: Check the below works properly by using ncdump on Cori - utils.general.save_transient_variables_to_netcdf( - parameter.current_set, test_nc, "test", parameter - ) - 
utils.general.save_transient_variables_to_netcdf(
-            parameter.current_set, ref_nc, "ref", parameter
-        )
-
-        # Saving the other data as json.
-        for dict_type in ["test", "ref"]:
-            json_output_file_name = os.path.join(
-                utils.general.get_output_dir(parameter.current_set, parameter),
-                parameter.output_file + "_{}.json".format(dict_type),
-            )
-            with open(json_output_file_name, "w") as outfile:
-                if dict_type == "test":
-                    json_dict = test_json
-                else:
-                    json_dict = ref_json
-                json.dump(json_dict, outfile, default=str)
-            # Get the file name that the user has passed in and display that.
-            json_output_file_name = os.path.join(
-                utils.general.get_output_dir(parameter.current_set, parameter),
-                parameter.output_file + "_{}.json".format(dict_type),
-            )
-            logger.info("Metrics saved in: {}".format(json_output_file_name))
-
-        plot(parameter, test, ref)
-    return parameter
+def deseason(xraw):
+    # Calculates the deseasonalized data
+    months_per_year = 12
+    # Create array to hold climatological values and deseasonalized data
+    # Create months_per_year x 1 array of zeros
+    xclim = np.zeros((months_per_year, 1))
+    # Create array with same shape as xraw
+    x_deseasoned = np.zeros(xraw.shape)
+    # Iterate through all 12 months.
+    for month in np.arange(months_per_year):
+        # `xraw[month::12]` will return the data for this month every year (12 months)
+        # (i.e., from month until the end of xraw, get every 12th month)
+        # Get the mean of this month, using data from every year, ignoring NaNs
+        xclim[month] = np.nanmean(xraw[month::months_per_year])
+    num_years = int(np.floor(len(x_deseasoned) / months_per_year))
+    # Iterate through all years in x_deseasoned (same number as in xraw)
+    for year in np.arange(num_years):
+        year_index = year * months_per_year
+        # Iterate through all months of the year
+        for month in np.arange(months_per_year):
+            month_index = year_index + month
+            # Subtract the month's mean over num_years from xraw's data for this month in this year
+            # i.e., get the difference between this month's value and its "usual" value
+            x_deseasoned[month_index] = xraw[month_index] - xclim[month]
+    return x_deseasoned
diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py
index 0e42a177a..86a30c0a1 100644
--- a/e3sm_diags/driver/utils/dataset_xr.py
+++ b/e3sm_diags/driver/utils/dataset_xr.py
@@ -31,6 +31,7 @@
 )
 from e3sm_diags.driver import LAND_FRAC_KEY, LAND_OCEAN_MASK_PATH, OCEAN_FRAC_KEY
 from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS, ClimoFreq, climo
+from e3sm_diags.driver.utils.regrid import HYBRID_SIGMA_KEYS
 from e3sm_diags.logger import custom_logger
 
 if TYPE_CHECKING:
@@ -269,8 +270,6 @@ def _get_ref_name(self) -> str:
                 "reference datasets."
             )
 
-        return self.parameter.ref_name
-
     def _get_global_attr_from_climo_dataset(
         self, attr: str, season: ClimoFreq
     ) -> str | None:
@@ -440,23 +439,7 @@ def _get_climo_dataset(self, season: str) -> xr.Dataset:
             )
 
         ds = squeeze_time_dim(ds)
-
-        # slat and slon are lat lon pair for staggered FV grid included in
-        # remapped files.
-        if "slat" in ds.dims:
-            ds = ds.drop_dims(["slat", "slon"])
-
-        all_vars = list(ds.data_vars.keys())
-        keep_bnds = [var for var in all_vars if "bnd" in var or "bounds" in var]
-        ds = ds[[self.var] + keep_bnds]
-
-        # NOTE: There seems to be an issue with `open_mfdataset()` and
-        # using the multiprocessing scheduler defined in e3sm_diags,
-        # resulting in timeouts and resource locking.
-        # To avoid this, we load the multi-file dataset into memory before
-        # performing downstream operations.
- # Related GH issue: https://github.com/pydata/xarray/issues/3781 - ds.load(scheduler="sync") + ds = self._subset_vars_and_load(ds) return ds @@ -792,6 +775,45 @@ def _get_matching_climo_src_vars( return None + def _subset_vars_and_load(self, ds: xr.Dataset) -> xr.Dataset: + """Subset for variables needed for processing and load into memory. + + Subsetting the dataset reduces its memory footprint. Loading is + necessary because there seems to be an issue with `open_mfdataset()` + and using the multiprocessing scheduler defined in e3sm_diags, + resulting in timeouts and resource locking. To avoid this, we load the + multi-file dataset into memory before performing downstream operations. + + Source: https://github.com/pydata/xarray/issues/3781 + + Parameters + ---------- + ds : xr.Dataset + The dataset. + + Returns + ------- + xr.Dataset + The dataset subsetted and loaded into memory. + """ + # slat and slon are lat lon pair for staggered FV grid included in + # remapped files. + if "slat" in ds.dims: + ds = ds.drop_dims(["slat", "slon"]) + + all_vars_keys = list(ds.data_vars.keys()) + hybrid_var_keys = set(list(sum(HYBRID_SIGMA_KEYS.values(), ()))) + keep_vars = [ + var + for var in all_vars_keys + if "bnd" in var or "bounds" in var or var in hybrid_var_keys + ] + ds = ds[[self.var] + keep_vars] + + ds.load(scheduler="sync") + + return ds + # -------------------------------------------------------------------------- # Time series related methods # -------------------------------------------------------------------------- diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index 7b876b3bf..5351a9cb5 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -277,7 +277,9 @@ def _set_param_output_attrs( self.output_file = output_file self.main_title = main_title - def _set_name_yrs_attrs(self, ds_test: Dataset, ds_ref: Dataset, season: ClimoFreq): + def _set_name_yrs_attrs( + self, ds_test: Dataset, ds_ref: Dataset, season: ClimoFreq | None + ): """Set the test_name_yrs and ref_name_yrs attributes. Parameters @@ -286,8 +288,8 @@ def _set_name_yrs_attrs(self, ds_test: Dataset, ds_ref: Dataset, season: ClimoFr The test dataset object used for setting ``self.test_name_yrs``. ds_ref : Dataset The ref dataset object used for setting ``self.ref_name_yrs``. - season : CLIMO_FREQ - The climatology frequency. + season : ClimoFreq | None + The optional climatology frequency. 
""" self.test_name_yrs = ds_test.get_name_yrs_attr(season) self.ref_name_yrs = ds_ref.get_name_yrs_attr(season) diff --git a/e3sm_diags/plot/qbo_plot.py b/e3sm_diags/plot/qbo_plot.py new file mode 100644 index 000000000..7847c42cb --- /dev/null +++ b/e3sm_diags/plot/qbo_plot.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from typing import List, Literal, TypedDict + +import matplotlib +import numpy as np +import xcdat as xc + +from e3sm_diags.logger import custom_logger +from e3sm_diags.parameter.qbo_parameter import QboParameter +from e3sm_diags.plot.utils import _save_plot + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # isort:skip # noqa: E402 + +logger = custom_logger(__name__) + +PANEL_CFG = [ + (0.075, 0.70, 0.6, 0.225), + (0.075, 0.425, 0.6, 0.225), + (0.725, 0.425, 0.2, 0.5), + (0.075, 0.075, 0.85, 0.275), +] + +LABEL_SIZE = 14 +CMAP = plt.cm.RdBu_r + + +class XAxis(TypedDict): + axis_range: list[int] + axis_scale: Literal["linear"] + label: str + data: np.ndarray + data_label: str | None + data2: np.ndarray | None + data2_label: str | None + + +class YAxis(TypedDict): + axis_range: list[int] + axis_scale: Literal["linear", "log"] + label: str + data: np.ndarray + data_label: str | None + data2: np.ndarray | None + data2_label: str | None + + +class ZAxis(TypedDict): + data: np.ndarray + + +def plot(parameter: QboParameter, test_dict, ref_dict): + fig = plt.figure(figsize=(14, 14)) + + test_z_axis = xc.get_dim_coords(test_dict["qbo"], axis="Z") + ref_z_axis = xc.get_dim_coords(ref_dict["qbo"], axis="Z") + + months = np.minimum(ref_dict["qbo"].shape[0], test_dict["qbo"].shape[0]) + x_test, y_test = np.meshgrid(np.arange(0, months), test_z_axis) + x_ref, y_ref = np.meshgrid(np.arange(0, months), ref_z_axis) + + color_levels0 = np.arange(-50, 51, 100.0 / 20.0) + + # Panel 0 (Top Left) + x: XAxis = dict( + axis_range=[0, months], + axis_scale="linear", + label=" ", + data=x_test, + data_label=None, + data2=None, + data2_label=None, + ) + y: YAxis = dict( + axis_range=[100, 1], + axis_scale="log", + label="hPa", + data=y_test, + data_label=None, + data2=None, + data2_label=None, + ) + z: ZAxis = dict(data=test_dict["qbo"].T[:, :months]) + title = "{} U [{}] 5S-5N ({})".format(test_dict["name"], "m/s", parameter.test_yrs) + _add_color_map( + 0, + fig, + "contourf", + title, + x, + y, + z=z, + plot_colors=CMAP, + color_levels=color_levels0, + color_ticks=[-50, -25, -5, 5, 25, 50], + ) + + # Panel 1 (Middle Left) + x = dict( + axis_range=[0, months], + axis_scale="linear", + label="month", + data=x_ref, + data_label=None, + data2=None, + data2_label=None, + ) + y = dict( + axis_range=[100, 1], + axis_scale="log", + label="hPa", + data=y_ref, + data_label=None, + data2=None, + data2_label=None, + ) + z = dict(data=ref_dict["qbo"].T[:, :months]) + title = "{} U [{}] 5S-5N ({})".format(ref_dict["name"], "m/s", parameter.ref_yrs) + _add_color_map( + 1, + fig, + "contourf", + title, + x, + y, + z=z, + plot_colors=CMAP, + color_levels=color_levels0, + color_ticks=[-50, -25, -5, 5, 25, 50], + ) + + # Panel 2 (Top/Middle Right) + x = dict( + axis_range=[0, 30], + axis_scale="linear", + label="Amplitude (m/s)", + data=test_dict["amplitude"][:], + data_label=test_dict["name"], + data2=ref_dict["amplitude"][:], + data2_label=ref_dict["name"], + ) + y = dict( + axis_range=[100, 1], + axis_scale="log", + label="Pressure (hPa)", + data=test_z_axis[:], + data_label=None, + data2=ref_z_axis[:], + data2_label=None, + ) + title = "QBO Amplitude \n (period = 20-40 months)" + 
+    _add_color_map(2, fig, "line", title, x, y)
+
+    # Panel 3 (Bottom)
+    x = dict(
+        axis_range=[0, 50],
+        axis_scale="linear",
+        label="Period (months)",
+        data=test_dict["period_new"],
+        data_label=test_dict["name"],
+        data2=ref_dict["period_new"],
+        data2_label=ref_dict["name"],
+    )
+    y = dict(
+        axis_range=[-1, 25],
+        axis_scale="linear",
+        label="Amplitude (m/s)",
+        data=test_dict["amplitude_new"],
+        data_label=None,
+        data2=ref_dict["amplitude_new"],
+        data2_label=None,
+    )
+    title = "QBO Spectral Density (Eq. 18-22 hPa zonal winds)"
+    _add_color_map(3, fig, "line", title, x, y)
+
+    plt.tight_layout()
+
+    # Figure title
+    fig.suptitle(parameter.main_title, x=0.5, y=0.97, fontsize=15)
+
+    # Save figure
+    _save_plot(fig, parameter, PANEL_CFG)
+
+    plt.close()
+
+
+def _add_color_map(
+    subplot_num: int,
+    fig: plt.Figure,
+    plot_type: Literal["contourf", "line"],
+    title: str,
+    x: XAxis,
+    y: YAxis,
+    z: ZAxis | None = None,
+    plot_colors: matplotlib.colors.Colormap | None = None,
+    color_levels: np.ndarray | None = None,
+    color_ticks: List[int] | None = None,
+):
+    # `x`, `y`, and `z` follow the XAxis, YAxis, and ZAxis TypedDict
+    # structures defined above.
+
+    # Create a new figure axis using the panel's (hard-coded) dimensions.
+    ax = fig.add_axes(PANEL_CFG[subplot_num])
+
+    # Plot either a contourf or line plot.
+    if plot_type == "contourf":
+        if z is None:
+            raise RuntimeError(f"Must set `z` arg to use plot_type={plot_type}.")
+
+        p1 = ax.contourf(
+            x["data"], y["data"], z["data"], color_levels, cmap=plot_colors
+        )
+        cbar = plt.colorbar(p1, ticks=color_ticks)
+        cbar.ax.tick_params(labelsize=LABEL_SIZE)
+
+    if plot_type == "line":
+        (p1,) = ax.plot(x["data"], y["data"], "-ok")
+        (p2,) = ax.plot(x["data2"], y["data2"], "--or")
+
+        plt.grid("on")
+        ax.legend(
+            (p1, p2),
+            (x["data_label"], x["data2_label"]),
+            loc="upper right",
+            fontsize=LABEL_SIZE,
+        )
+
+    ax.set_title(title, size=LABEL_SIZE, weight="demi")
+    ax.set_xlabel(x["label"], size=LABEL_SIZE)
+    ax.set_ylabel(y["label"], size=LABEL_SIZE)
+
+    plt.yscale(y["axis_scale"])
+    plt.ylim([y["axis_range"][0], y["axis_range"][1]])
+    plt.yticks(size=LABEL_SIZE)
+    plt.xscale(x["axis_scale"])
+    plt.xlim([x["axis_range"][0], x["axis_range"][1]])
+    plt.xticks(size=LABEL_SIZE)
diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
index 79371f18e..a329fef80 100644
--- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py
+++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py
@@ -766,6 +766,7 @@ def test_returns_climo_dataset_with_derived_variable(self):
         expected = expected.squeeze(dim="time").drop_vars("time")
         expected["PRECT"] = expected["pr"] * 3600 * 24
         expected["PRECT"].attrs["units"] = "mm/day"
+        expected = expected.drop_vars("pr")
 
         xr.testing.assert_identical(result, expected)
 
@@ -924,6 +925,7 @@ def test_returns_climo_dataset_using_source_variable_with_wildcard(self):
         result = ds.get_climo_dataset("bc_DDF", season="ANN")
 
         expected = ds_precst.squeeze(dim="time").drop_vars("time")
         expected["bc_DDF"] = expected["bc_a?DDF"] + expected["bc_c?DDF"]
+        expected = expected.drop_vars(["bc_a?DDF", "bc_c?DDF"])
 
         xr.testing.assert_identical(result, expected)
 
@@ -1502,6 +1504,7 @@ def test_returns_land_sea_mask_if_matching_vars_in_dataset(self):
         result = ds._get_land_sea_mask("ANN")
 
         expected = ds_climo.copy()
         expected = expected.squeeze(dim="time").drop_vars("time")
+        expected = expected.drop_vars("ts")
 
         xr.testing.assert_identical(result, expected)
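
The new `deseason` helper in `e3sm_diags/driver/qbo_driver.py` removes the monthly climatology from a monthly time series. Below is a minimal sketch of its behavior, assuming an environment with this branch of e3sm_diags installed; the synthetic series is illustrative and not from the patch. A pure seasonal cycle deseasonalizes to all zeros, because each month's climatological mean equals that month's value in every year.

```python
import numpy as np

# `deseason` is the module-level helper added to qbo_driver.py by this patch.
from e3sm_diags.driver.qbo_driver import deseason

# Two years of a pure seasonal cycle: every month repeats exactly across
# years, so each monthly climatological mean equals the monthly value itself.
xraw = np.tile(np.sin(2 * np.pi * np.arange(12) / 12), 2)

x_deseasoned = deseason(xraw)

# With no interannual variability, all anomalies should be ~0.
assert np.allclose(x_deseasoned, 0.0)
```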
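The keep-variable filter in `_subset_vars_and_load` can also be exercised standalone. This sketch uses a made-up toy dataset (the names `U`, `lat_bnds`, `hyam`, and `TS` are illustrative) and assumes `"hyam"` appears among the tuples in `HYBRID_SIGMA_KEYS`; only the target variable, bounds variables, and hybrid-sigma coefficients survive the subset.

```python
import numpy as np
import xarray as xr

# HYBRID_SIGMA_KEYS is the real mapping imported by the patch; the dataset
# below is a made-up example.
from e3sm_diags.driver.utils.regrid import HYBRID_SIGMA_KEYS

ds = xr.Dataset(
    data_vars={
        "U": ("lat", np.zeros(3)),
        "lat_bnds": (("lat", "bnds"), np.zeros((3, 2))),
        "hyam": ("lev", np.zeros(2)),  # assumed to be a hybrid-sigma key
        "TS": ("lat", np.zeros(3)),  # unrelated variable, should be dropped
    }
)

# Same filter as `_subset_vars_and_load`, with "U" as the target variable.
hybrid_var_keys = set(sum(HYBRID_SIGMA_KEYS.values(), ()))
keep_vars = [
    var
    for var in ds.data_vars
    if "bnd" in var or "bounds" in var or var in hybrid_var_keys
]
ds_sub = ds[["U"] + keep_vars]

print(list(ds_sub.data_vars))  # expected: ['U', 'lat_bnds', 'hyam']
```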
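Finally, the `PANEL_CFG` rectangles in `e3sm_diags/plot/qbo_plot.py` are `(left, bottom, width, height)` fractions of the figure passed to `fig.add_axes`. This illustrative sketch renders empty axes at those positions to visualize the four-panel layout; the output filename is arbitrary.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Copied from e3sm_diags/plot/qbo_plot.py; each tuple is
# (left, bottom, width, height) in figure-fraction coordinates.
PANEL_CFG = [
    (0.075, 0.70, 0.6, 0.225),  # Panel 0: top left (test contours)
    (0.075, 0.425, 0.6, 0.225),  # Panel 1: middle left (ref contours)
    (0.725, 0.425, 0.2, 0.5),  # Panel 2: top/middle right (amplitude)
    (0.075, 0.075, 0.85, 0.275),  # Panel 3: bottom (spectral density)
]

fig = plt.figure(figsize=(14, 14))
for i, rect in enumerate(PANEL_CFG):
    ax = fig.add_axes(rect)
    ax.set_title(f"Panel {i}")

fig.savefig("qbo_panel_layout.png")  # arbitrary output path
```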