Skip to content

Commit

Permalink
CDAT Migration Phase 2: Refactor qbo set (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 11, 2024
1 parent 60aa46c commit bff1359
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
27 changes: 27 additions & 0 deletions e3sm_diags/driver/qbo_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import scipy.fftpack
import xarray as xr
import xcdat as xc
<<<<<<< HEAD
from scipy.signal import detrend
=======
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))

from e3sm_diags.driver.utils.dataset_xr import Dataset
from e3sm_diags.driver.utils.io import _get_output_dir, _write_to_netcdf
Expand All @@ -26,9 +29,12 @@
# The region will always be 5S5N
REGION = "5S5N"

<<<<<<< HEAD
# Target power spectral vertical level for the wavelet diagnostic.
POW_SPEC_LEV = 20.0

=======
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))

class MetricsDict(TypedDict):
qbo: xr.DataArray
Expand All @@ -37,8 +43,11 @@ class MetricsDict(TypedDict):
period_new: np.ndarray
psd_x_new: np.ndarray
amplitude_new: np.ndarray
<<<<<<< HEAD
wave_period: np.ndarray
wavelet: np.ndarray
=======
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
name: str


Expand Down Expand Up @@ -96,6 +105,7 @@ def run_diag(parameter: QboParameter) -> QboParameter:
x_ref, ref_dict["period_new"]
)

<<<<<<< HEAD
# Diagnostic 4: calculate the Wavelet
test_dict["wave_period"], test_dict["wavelet"] = _calculate_wavelet(
test_dict["qbo"]
Expand All @@ -104,6 +114,8 @@ def run_diag(parameter: QboParameter) -> QboParameter:
ref_dict["qbo"]
)

=======
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
parameter.var_id = var_key
parameter.output_file = "qbo_diags"
parameter.main_title = (
Expand All @@ -123,7 +135,15 @@ def run_diag(parameter: QboParameter) -> QboParameter:

# Write the metrics to .json files.
test_dict["name"] = test_ds._get_test_name()
<<<<<<< HEAD
ref_dict["name"] = ref_ds._get_ref_name()
=======

try:
ref_dict["name"] = ref_ds._get_ref_name()
except AttributeError:
ref_dict["name"] = parameter.ref_name
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))

_save_metrics_to_json(parameter, test_dict, "test") # type: ignore
_save_metrics_to_json(parameter, ref_dict, "ref") # type: ignore
Expand Down Expand Up @@ -152,7 +172,11 @@ def _save_metrics_to_json(
metrics_dict[key] = metrics_dict[key].tolist() # type: ignore

with open(abs_path, "w") as outfile:
<<<<<<< HEAD
json.dump(metrics_dict, outfile, default=str)
=======
json.dump(metrics_dict, outfile)
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))

logger.info("Metrics saved in: {}".format(abs_path))

Expand Down Expand Up @@ -355,6 +379,7 @@ def deseason(xraw):
# i.e., get the difference between this month's value and it's "usual" value
x_deseasoned[month_index] = xraw[month_index] - xclim[month]
return x_deseasoned
<<<<<<< HEAD


def _calculate_wavelet(var: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -419,3 +444,5 @@ def _get_psd_from_wavelet(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
psd = np.mean(np.square(np.abs(cwtmatr)), axis=1)

return (period, psd)
=======
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
60 changes: 60 additions & 0 deletions tests/e3sm_diags/driver/utils/test_dataset_xr.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,69 @@ def test_returns_climo_dataset_using_test_file_variable_ref_name_and_season_nest

xr.testing.assert_identical(result, expected)

<<<<<<< HEAD
@pytest.mark.xfail(
reason="Need to figure out why to create dummy incorrect time scalar variable with Xarray."
)
=======
def test_returns_climo_dataset_with_derived_variable(self):
    """Derive ``PRECT`` from a reference dataset containing only ``pr``.

    Writes a minimal single-timestep ``pr`` dataset (units ``mm/s``) to
    disk, loads it through ``Dataset.get_climo_dataset``, and checks that
    the result contains the derived ``PRECT`` variable (``pr * 3600 * 24``,
    units ``mm/day``) with the scalar time dimension squeezed and dropped.
    """
    # We will derive the "PRECT" variable using the "pr" variable.
    ds_pr = xr.Dataset(
        coords={
            **spatial_coords,
            # Single mid-month timestep; cftime object dtype matches how
            # climatology files store non-standard calendars.
            "time": xr.DataArray(
                dims="time",
                data=np.array(
                    [
                        cftime.DatetimeGregorian(
                            2000, 1, 16, 12, 0, 0, 0, has_year_zero=False
                        ),
                    ],
                    dtype=object,
                ),
                attrs={
                    "axis": "T",
                    "long_name": "time",
                    "standard_name": "time",
                    "bounds": "time_bnds",
                },
            ),
        },
        data_vars={
            **spatial_bounds,
            # FIX: the original wrapped this DataArray in a second, redundant
            # xr.DataArray(...) call — a no-op constructor that only obscures
            # the literal. Build it once.
            "pr": xr.DataArray(
                data=np.array(
                    [
                        [[1.0, 1.0], [1.0, 1.0]],
                    ]
                ),
                dims=["time", "lat", "lon"],
                attrs={"units": "mm/s"},
            ),
        },
    )

    parameter = _create_parameter_object(
        "ref", "climo", self.data_path, "2000", "2001"
    )
    parameter.ref_file = "pr_200001_200112.nc"
    ds_pr.to_netcdf(f"{self.data_path}/{parameter.ref_file}")

    ds = Dataset(parameter, data_type="ref")

    result = ds.get_climo_dataset("PRECT", season="ANN")

    # Expected: same dataset with the scalar time removed and pr converted
    # to PRECT in mm/day (mm/s * 86400 s/day), original pr dropped.
    expected = ds_pr.copy()
    expected = expected.squeeze(dim="time").drop_vars("time")
    expected["PRECT"] = expected["pr"] * 3600 * 24
    expected["PRECT"].attrs["units"] = "mm/day"
    expected = expected.drop_vars("pr")

    xr.testing.assert_identical(result, expected)

@pytest.mark.xfail
>>>>>>> c7ef34e9 (CDAT Migration Phase 2: Refactor `qbo` set (#826))
def test_returns_climo_dataset_using_derived_var_directly_from_dataset_and_replaces_scalar_time_var(
self,
):
Expand Down

0 comments on commit bff1359

Please sign in to comment.