Skip to content

Commit

Permalink
Fixing African easterly wave density plots in TC analysis (#851)
Browse files Browse the repository at this point in the history
* read in aew data without subsetting

* exclude 6hourly data cross year bounds

* use year range from stitch file name instead of data in file

* fixtests

* capture input data errors
  • Loading branch information
chengzhuzhang authored Oct 23, 2024
1 parent 6b0a245 commit 7b05425
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
49 changes: 31 additions & 18 deletions e3sm_diags/driver/tc_analysis_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def run_diag(parameter: TCAnalysisParameter) -> TCAnalysisParameter:
test_data_path,
"aew_hist_{}_{}_{}.nc".format(test_name, test_start_yr, test_end_yr),
)
test_aew_hist = cdms2.open(test_aew_file)(
"density", lat=(0, 35, "ccb"), lon=(-180, 0, "ccb"), squeeze=1
)
test_aew_hist = cdms2.open(test_aew_file)("density", squeeze=1)

test_data = collections.OrderedDict()
ref_data = collections.OrderedDict()
Expand Down Expand Up @@ -134,9 +132,8 @@ def run_diag(parameter: TCAnalysisParameter) -> TCAnalysisParameter:
"density", lat=(-60, 60, "ccb"), squeeze=1
)
ref_aew_file = os.path.join(reference_data_path, "aew_hist_ERA5_2010_2014.nc")
ref_aew_hist = cdms2.open(ref_aew_file)(
"density", lat=(0, 35, "ccb"), lon=(180, 360, "ccb"), squeeze=1
)
ref_aew_hist = cdms2.open(ref_aew_file)("density", squeeze=1)

ref_data["cyclone_density"] = ref_cyclones_hist
ref_data["cyclone_num_years"] = 40 # type: ignore
ref_data["aew_density"] = ref_aew_hist
Expand All @@ -163,13 +160,40 @@ def generate_tc_metrics_from_te_stitch_file(te_stitch_file: str) -> Dict[str, An
"""
logger.info("\nGenerating TC Metrics from TE Stitch Files")
logger.info("============================================")
if not os.path.exists(te_stitch_file):
raise FileNotFoundError(f"The file {te_stitch_file} does not exist.")

with open(te_stitch_file) as f:
lines = f.readlines()
lines_orig = f.readlines()

if not lines_orig:
raise ValueError(f"The file {te_stitch_file} is empty.")

line_ind = []
data_start_year = int(te_stitch_file.split(".")[-2].split("_")[-2])
data_end_year = int(te_stitch_file.split(".")[-2].split("_")[-1])
for i in range(0, np.size(lines_orig)):
if lines_orig[i][0] == "s":
year = int(lines_orig[i].split("\t")[2])

if year <= data_end_year:
line_ind.append(i)

# Remove excessive time points cross year bounds from 6 hourly data
end_ind = line_ind[-1]
lines = lines_orig[0:end_ind]

# Calculate number of storms and max length
num_storms, max_len = _calc_num_storms_and_max_len(lines)
# Parse variables from TE stitch file
te_stitch_vars = _get_vars_from_te_stitch(lines, max_len, num_storms)
# Add year info
te_stitch_vars["year_start"] = data_start_year
te_stitch_vars["year_end"] = data_end_year
te_stitch_vars["num_years"] = data_end_year - data_start_year + 1
logger.info(
f"TE Start Year: {te_stitch_vars['year_start']}, TE End Year: {te_stitch_vars['year_end']}, Total Years: {te_stitch_vars['num_years']}"
)

# Use E3SM land-sea mask
mask_path = os.path.join(e3sm_diags.INSTALL_PATH, "acme_ne30_ocean_land_mask.nc")
Expand Down Expand Up @@ -246,15 +270,11 @@ def _get_vars_from_te_stitch(
vars_dict = {k: np.empty((max_len, num_storms)) * np.nan for k in keys}

index = 0
year_start = int(lines[0].split("\t")[2])
year_end = year_start

for line in lines:
line_split = line.split("\t")
if line[0] == "s":
index = index + 1
year = int(line_split[2])
year_end = max(year, year_start)
k = 0
else:
k = k + 1
Expand All @@ -265,13 +285,6 @@ def _get_vars_from_te_stitch(
vars_dict["yearmc"][k - 1, index - 1] = float(line_split[6])
vars_dict["monthmc"][k - 1, index - 1] = float(line_split[7])

vars_dict["year_start"] = year_start # type: ignore
vars_dict["year_end"] = year_end # type: ignore
vars_dict["num_years"] = year_end - year_start + 1 # type: ignore
logger.info(
f"TE Start Year: {vars_dict['year_start']}, TE End Year: {vars_dict['year_end']}, Total Years: {vars_dict['num_years']}"
)

return vars_dict


Expand Down
6 changes: 0 additions & 6 deletions tests/e3sm_diags/drivers/test_tc_analysis_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ def test_correct_output(self):
"vsmc": np.array([[1.94, np.nan]]),
"yearmc": np.array([[1, np.nan]]),
"monthmc": np.array([[1, np.nan]]),
"year_start": 90,
"year_end": 90,
"num_years": 1,
}
result = _get_vars_from_te_stitch(lines, max_len, num_storms)

Expand All @@ -70,9 +67,6 @@ def test_correct_output(self):
np.array_equal(result["vsmc"], expected["vsmc"])
np.array_equal(result["yearmc"], expected["yearmc"])
np.array_equal(result["monthmc"], expected["monthmc"])
self.assertEqual(result["year_start"], expected["year_start"])
self.assertEqual(result["year_end"], expected["year_end"])
self.assertEqual(result["num_years"], expected["num_years"])


class TestDeriveMetricsPerBasin(TestCase):
Expand Down

0 comments on commit 7b05425

Please sign in to comment.