diff --git a/pcmdi_metrics/mjo/lib/__init__.py b/pcmdi_metrics/mjo/lib/__init__.py index 5f4352ff6..c1587b174 100644 --- a/pcmdi_metrics/mjo/lib/__init__.py +++ b/pcmdi_metrics/mjo/lib/__init__.py @@ -2,9 +2,7 @@ from .debug_chk_plot import debug_chk_plot # noqa from .dict_merge import dict_merge # noqa from .lib_mjo import ( # noqa - Remove_dailySeasonalCycle, calculate_ewr, - decorate_2d_array_axes, generate_axes_and_decorate, get_daily_ano_segment, interp2commonGrid, @@ -13,7 +11,6 @@ space_time_spectrum, subSliceSegment, taper, - unit_conversion, write_netcdf_output, ) from .mjo_metric_calc import mjo_metric_ewr_calculation # noqa diff --git a/pcmdi_metrics/mjo/lib/debug_chk_plot.py b/pcmdi_metrics/mjo/lib/debug_chk_plot.py index 656f545e9..75db53cd1 100644 --- a/pcmdi_metrics/mjo/lib/debug_chk_plot.py +++ b/pcmdi_metrics/mjo/lib/debug_chk_plot.py @@ -4,29 +4,15 @@ import matplotlib.pyplot as plt from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter +from pcmdi_metrics.io import get_latitude, get_longitude + def debug_chk_plot(d_seg_x_ano, Power, OEE, segment_year, daSeaCyc, segment_ano_year): os.makedirs("debug", exist_ok=True) - """ FIX ME --- - x = vcs.init() - x.plot(d_seg_x_ano) - x.png('debug/d_seg_x_ano.png') - - x.clear() - x.plot(Power) - x.png('debug/power.png') - - x.clear() - x.plot(OEE) - x.png('debug/OEE.png') - """ - print("type(segment_year)", type(segment_year)) print("segment_year.shape:", segment_year.shape) - print(segment_year.getAxis(0)) - print(segment_year.getAxis(1)) - print(segment_year.getAxis(2)) + plot_map(segment_year[0], "debug/segment.png") print("type(daSeaCyc)", type(daSeaCyc)) @@ -35,16 +21,14 @@ def debug_chk_plot(d_seg_x_ano, Power, OEE, segment_year, daSeaCyc, segment_ano_ print("type(segment_ano_year)", type(segment_ano_year)) print("segment_ano_year.shape:", segment_ano_year.shape) - print(segment_ano_year.getAxis(0)) - print(segment_ano_year.getAxis(1)) - print(segment_ano_year.getAxis(2)) + plot_map(segment_ano_year[0], "debug/segment_ano.png") def plot_map(data, filename): fig = plt.figure(figsize=(10, 6)) - lons = data.getLongitude() - lats = data.getLatitude() + lons = get_longitude(data) + lats = get_latitude(data) ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=180)) im = ax.contourf(lons, lats, data, transform=ccrs.PlateCarree(), cmap="viridis") ax.coastlines() diff --git a/pcmdi_metrics/mjo/lib/lib_mjo.py b/pcmdi_metrics/mjo/lib/lib_mjo.py index 7e3cb58dd..53a2678b4 100644 --- a/pcmdi_metrics/mjo/lib/lib_mjo.py +++ b/pcmdi_metrics/mjo/lib/lib_mjo.py @@ -2,116 +2,113 @@ Code written by Jiwoo Lee, LLNL. Feb. 2019 Inspired by Daehyun Kim and Min-Seop Ahn's MJO metrics. +Code update history +2024-05 converted to use xcdat as base building block (Jiwoo Lee) + Reference: Ahn, MS., Kim, D., Sperber, K.R. et al. Clim Dyn (2017) 49: 4023. https://doi.org/10.1007/s00382-017-3558-4 """ -import cdms2 -import cdtime -import cdutil -import MV2 +from typing import Union + import numpy as np +import xarray as xr from scipy import signal -import pcmdi_metrics +from pcmdi_metrics.io import base, get_time_key, select_subset +from pcmdi_metrics.utils import create_target_grid, regrid -def interp2commonGrid(d, dlat, debug=False): - """ - input - - d: cdms array - - dlat: resolution (i.e. grid distance) in degree - output - - d2: d interpolated to dlat resolution grid - """ - nlat = int(180 / dlat) - grid = cdms2.createGaussianGrid(nlat, xorigin=0.0, order="yx") - d2 = d.regrid(grid, regridTool="regrid2", mkCyclic=True) - d2 = d2(latitude=(-10, 10)) +def interp2commonGrid(ds, data_var, dlat, dlon=None, debug=False): + if dlon is None: + dlon = dlat + + # Generate grid + grid = create_target_grid( + target_grid_resolution=f"{dlat}x{dlon}", grid_type="uniform" + ) + + # Regrid + ds_regrid = regrid(ds, data_var, grid) + ds_regrid_subset = select_subset(ds_regrid, lat=(-10, 10)) + if debug: - print("debug: d2.shape:", d2.shape) - return d2 + print( + "debug: ds_regrid_subset[data_var] shape:", ds_regrid_subset[data_var].shape + ) + return ds_regrid_subset -def subSliceSegment(d, year, mon, day, length): + +def subSliceSegment( + ds: Union[xr.Dataset, xr.DataArray], year: int, mon: int, day: int, length: int +) -> Union[xr.Dataset, xr.DataArray]: """ - Note: From given cdms array (3D: time and spatial 2D) + Note: From given array (3D: time and spatial 2D) Subslice to get segment with given length starting from given time. input - - d: cdms array + - ds: xarray dataset or dataArray - year: segment starting year (integer) - mon: segment starting month (integer) - day: segement starting day (integer) - length: segment length (integer) """ - tim = d.getTime() - comTim = tim.asComponentTime() - h = comTim[0].hour - m = comTim[0].minute - s = comTim[0].second - cptime = cdtime.comptime(year, mon, day, h, m, s) # start date of segment - n = comTim.index(cptime) # time dimension index of above start date - d2 = d.subSlice((n, n + length)) # slie 180 time steps starting from above index - return d2 - - -def Remove_dailySeasonalCycle(d, d_cyc): - """ - Note: Remove daily seasonal cycle - input - - d: cdms array - - d_cyc: numpy array - output - - d2: cdms array - """ - d2 = MV2.subtract(d, d_cyc) - # Preserve Axes - for i in range(len(d.shape)): - d2.setAxis(i, d.getAxis(i)) - # Preserve variable id (How to preserve others?) - d2.id = d.id - return d2 + time_key = get_time_key(ds) + n = list(ds[time_key].values).index( + ds.sel(time=f"{year:04}-{mon:02}-{day:02}")[time_key] + ) + + return ds.isel( + time=slice(n, n + length) + ) # slie 180 time steps starting from above index -def get_daily_ano_segment(d_seg): + +def get_daily_ano_segment(d_seg: xr.Dataset, data_var: str) -> xr.Dataset: """ Note: 1. Get daily time series (3D: time and spatial 2D) 2. Meridionally average (2D: time and spatial, i.e., longitude) 3. Get anomaly by removing time mean of the segment input - - d_seg: cdms2 data + - d_seg: xarray dataset + - data_var: name of variable output - - d_seg_x_ano: 2d array + - d_seg_x_ano: xarray dataset that contains 2d output array """ - cdms2.setAutoBounds("on") # sub region - d_seg = d_seg(latitude=(-10, 10)) + d_seg = select_subset(d_seg, lat=(-10, 10)) + # Get meridional average (3d (t, y, x) to 2d (t, y)) - d_seg_x = cdutil.averager(d_seg, axis="y", weights="weighted") + d_seg_x = d_seg.spatial.average(data_var, axis=["Y"]) + # Get time-average in the segment on each longitude grid - d_seg_x_ave = cdutil.averager(d_seg_x, axis="t") + d_seg_x_ave = d_seg_x.temporal.average(data_var) + # Remove time mean for each segment - d_seg_x_ano = MV2.subtract(d_seg_x, d_seg_x_ave) + d_seg_x_ano = d_seg.copy() + d_seg_x_ano[data_var] = d_seg_x[data_var] - d_seg_x_ave[data_var] + return d_seg_x_ano -def space_time_spectrum(d_seg_x_ano): +def space_time_spectrum(d_seg_x_ano: xr.Dataset, data_var: str) -> np.ndarray: """ input - - d: 2d cdms MV2 array (t (time), n (space)) + - d: xarray dataset that contains 2d DataArray (t (time), n (space)) named as `data_var` + - data_var: name of the 2d DataArray output - - p: 2d array for power + - p: 2d numpy array for power NOTE: Below code taken from https://github.com/CDAT/wk/blob/2b953281c7a4c5d0ac2d79fcc3523113e31613d5/WK/process.py#L188 """ # Number of grid in longitude axis, and timestep for each segment - NTSub = d_seg_x_ano.shape[0] # NTSub - NL = d_seg_x_ano.shape[1] # NL + NTSub = d_seg_x_ano[data_var].shape[0] # NTSub + NL = d_seg_x_ano[data_var].shape[1] # NL # Tapering - d_seg_x_ano = taper(d_seg_x_ano) + d_seg_x_ano[data_var] = taper(d_seg_x_ano[data_var]) # Power sepctrum analysis - EE = np.fft.fft2(d_seg_x_ano, axes=(1, 0)) / float(NL) / float(NTSub) + EE = np.fft.fft2(d_seg_x_ano[data_var], axes=(1, 0)) / float(NL) / float(NTSub) # Now the array EE(n,t) contains the (complex) space-time spectrum. """ Create array PEE(NL+1,NT/2+1) which contains the (real) power spectrum. @@ -136,48 +133,18 @@ def taper(data): """ Note: taper first and last 45 days with cosine window, using scipy.signal function input - - data: cdms 2d array (t, n) t: time, n: space (meridionally averaged) + - data: 2d array (t, n) t: time, n: space (meridionally averaged) output: - data: tapered data """ - window = signal.tukey(len(data)) + window = signal.windows.tukey(len(data)) data2 = data.copy() for i in range(0, len(data)): - data2[i] = MV2.multiply(data[i][:], window[i]) + data2[i] = np.multiply(data[i][:], window[i]) return data2 -def decorate_2d_array_axes( - a, y, x, a_id=None, y_id=None, x_id=None, y_units=None, x_units=None -): - """ - Note: Decorate array with given axes - input - - a: 2d cdms MV2 or numpy array to decorate axes - - y: list of numbers to be axis 0 - - x: list of numbers to be axis 1 - - a_id: id of variable a - - y_id, x_id: id of axes, string - - y_units, x_units: units of axes - output - - return the given array, a, with axes attached - """ - y = MV2.array(y) - x = MV2.array(x) - # Create the frequencies axis - Y = cdms2.createAxis(y) - Y.id = y_id - Y.units = y_units - # Create the wave numbers axis - X = cdms2.createAxis(x) - X.id = x_id - X.units = x_units - # Makes it an MV2 with axis and id (id come sfrom orignal data id) - a = MV2.array(a, axes=(Y, X), id=a_id) - return a - - -def generate_axes_and_decorate(Power, NT, NL): +def generate_axes_and_decorate(Power, NT: int, NL: int) -> xr.DataArray: """ Note: Generates axes for the decoration input @@ -185,44 +152,49 @@ def generate_axes_and_decorate(Power, NT, NL): - NT: integer, number of time step - NL: integer, number of spatial grid output - - Power: decorated 2d cdms array - - ff: frequency axis - - ss: wavenumber axis + - xr.DataArray that contains Power 2d DataArray that has frequency and zonalwavenumber axes """ # frequency ff = [] for t in range(0, NT + 1): ff.append(float(t - NT / 2) / float(NT)) - ff = MV2.array(ff) - ff.id = "frequency" - ff.units = "cycles per day" + ff = np.array(ff) + # wave number ss = [] for n in range(0, NL + 1): ss.append(float(n) - float(NL / 2)) - ss = MV2.array(ss) - ss.id = "zonalwavenumber" - ss.units = "-" - # Decoration - Power = decorate_2d_array_axes( + ss = np.array(ss) + + # Add name attributes to x and y coordinates + x_coords = xr.IndexVariable( + "zonalwavenumber", ss, attrs={"name": "zonalwavenumber", "units": "-"} + ) + y_coords = xr.IndexVariable( + "frequency", ff, attrs={"name": "frequency", "units": "cycles per day"} + ) + + # Create an xarray DataArray + da = xr.DataArray( Power, - ff, - ss, - a_id="power", - y_id=ff.id, - x_id=ss.id, - y_units=ff.units, - x_units=ss.units, + coords={"frequency": y_coords, "zonalwavenumber": x_coords}, + dims=["frequency", "zonalwavenumber"], + name="power", ) - return Power, ff, ss + + return da -def output_power_spectra(NL, NT, Power, ff, ss): +def output_power_spectra(NL: int, NT: int, Power, debug: bool = False) -> xr.DataArray: """ Below code taken and modified from Daehyun Kim's Fortran code (MSD/level_2/sample/stps/stps.sea.f.sample) """ # The corresponding frequencies, ff, and wavenumbers, ss, are:- PEE = Power + + ff = Power.frequency + ss = Power.zonalwavenumber + OEE = np.zeros((21, 11)) for n in range(int(NL / 2), int(NL / 2) + 1 + 10): nn = n - int(NL / 2) @@ -231,35 +203,49 @@ def output_power_spectra(NL, NT, Power, ff, ss): OEE[tt, nn] = PEE[t, n] a = list((ff[i] for i in range(int(NT / 2) - 10, int(NT / 2) + 1 + 10))) b = list((ss[i] for i in range(int(NL / 2), int(NL / 2) + 1 + 10))) - a = MV2.array(a) - b = MV2.array(b) - # Decoration - OEE = decorate_2d_array_axes( + a = np.array(a) + b = np.array(b) + + # Add name attributes to x and y coordinates + x_coords = xr.IndexVariable( + "zonalwavenumber", b, attrs={"name": "zonalwavenumber", "units": "-"} + ) + y_coords = xr.IndexVariable( + "frequency", a, attrs={"name": "frequency", "units": "cycles per day"} + ) + + # Create an xarray DataArray + OEE = xr.DataArray( OEE, - a, - b, - a_id="power", - y_id=ff.id, - x_id=ss.id, - y_units=ff.units, - x_units=ss.units, + coords={"frequency": y_coords, "zonalwavenumber": x_coords}, + dims=["frequency", "zonalwavenumber"], + name="power", ) + # Transpose for visualization - OEE = MV2.transpose(OEE, (1, 0)) - OEE.id = "power" - return OEE + if debug: + print("before transpose, OEE.shape:", OEE.shape) + transposed_OEE = OEE.transpose() + + if debug: + print("after transpose, transposed_OEE.shape:", transposed_OEE.shape) + return transposed_OEE -def write_netcdf_output(d, fname): + # return OEE + + +def write_netcdf_output(da: xr.DataArray, fname): """ Note: write array in a netcdf file input - - d: array - - fname: string. directory path and name of the netcd file, without .nc + - d: xr.DataArray object + - fname: string of filename. Directory path that includes file name without .nc + output + - None """ - fo = cdms2.open(fname + ".nc", "w") - fo.write(d) - fo.close() + ds = xr.Dataset({da.name: da}) + ds.to_netcdf(fname + ".nc") def calculate_ewr(OEE): @@ -270,34 +256,23 @@ def calculate_ewr(OEE): where x for frequency and y for wavenumber. Actual ranges of frequency and wavenumber have been checked and applied. """ - east_power_domain = OEE(zonalwavenumber=(1, 3), frequency=(0.0166667, 0.0333333)) - west_power_domain = OEE(zonalwavenumber=(1, 3), frequency=(-0.0333333, -0.0166667)) - eastPower = np.average(np.array(east_power_domain)) - westPower = np.average(np.array(west_power_domain)) + east_power_domain = OEE.sel( + zonalwavenumber=slice(1, 3), frequency=slice(0.016, 0.034) + ) + west_power_domain = OEE.sel( + zonalwavenumber=slice(1, 3), frequency=slice(-0.034, -0.016) + ) + eastPower = np.average(east_power_domain) + westPower = np.average(west_power_domain) ewr = eastPower / westPower return ewr, eastPower, westPower -def unit_conversion(data, UnitsAdjust): - """ - Convert unit following given tuple using MV2 - input: - - data: cdms array - - UnitsAdjust: tuple with 4 elements - e.g.: (True, 'multiply', 86400., 'mm d-1'): e.g., kg m-2 s-1 to mm d-1 - (False, 0, 0, 0): no unit conversion - """ - if UnitsAdjust[0]: - data = getattr(MV2, UnitsAdjust[1])(data, UnitsAdjust[2]) - data.units = UnitsAdjust[3] - return data - - def mjo_metrics_to_json( outdir, json_filename, result_dict, model=None, run=None, cmec_flag=False ): # Open JSON - JSON = pcmdi_metrics.io.base.Base(outdir, json_filename) + JSON = base.Base(outdir, json_filename) # Dict for JSON if model is None and run is None: result_dict_to_json = result_dict diff --git a/pcmdi_metrics/mjo/lib/mjo_metric_calc.py b/pcmdi_metrics/mjo/lib/mjo_metric_calc.py index 88b1a16e2..0b38f28b6 100644 --- a/pcmdi_metrics/mjo/lib/mjo_metric_calc.py +++ b/pcmdi_metrics/mjo/lib/mjo_metric_calc.py @@ -1,12 +1,11 @@ import os +from datetime import datetime -import cdms2 -import cdtime -import MV2 import numpy as np +import xarray as xr +from pcmdi_metrics.io import get_latitude, get_longitude, get_time_key, xcdat_open from pcmdi_metrics.mjo.lib import ( - Remove_dailySeasonalCycle, calculate_ewr, generate_axes_and_decorate, get_daily_ano_segment, @@ -14,9 +13,9 @@ output_power_spectra, space_time_spectrum, subSliceSegment, - unit_conversion, write_netcdf_output, ) +from pcmdi_metrics.utils import adjust_units from .debug_chk_plot import debug_chk_plot from .plot_wavenumber_frequency_power import plot_power @@ -34,38 +33,49 @@ def mjo_metric_ewr_calculation( degX, UnitsAdjust, inputfile, - var, - startYear, - endYear, - segmentLength, - dir_paths, - season="NDJFMA", + data_var: str, + startYear: int, + endYear: int, + segmentLength: int, + dir_paths: str, + season: str = "NDJFMA", ): # Open file to read daily dataset if debug: - print("debug: open file") - f = cdms2.open(inputfile) - d = f[var] - tim = d.getTime() - comTim = tim.asComponentTime() - lat = d.getLatitude() - lon = d.getLongitude() + print(f"debug: open file: {inputfile}") + + ds = xcdat_open(inputfile) + + lat = get_latitude(ds) + lon = get_longitude(ds) # Get starting and ending year and month if debug: print("debug: check time") - first_time = comTim[0] - last_time = comTim[-1] + + time_key = get_time_key(ds) + + # Get first time step date + first_time_year = ds[time_key][0].item().year + first_time_month = ds[time_key][0].item().month + first_time_day = ds[time_key][0].item().day + first_time = datetime(first_time_year, first_time_month, first_time_day) + + # Get last time step date + last_time_year = ds[time_key][-1].item().year + last_time_month = ds[time_key][-1].item().month + last_time_day = ds[time_key][-1].item().day + last_time = datetime(last_time_year, last_time_month, last_time_day) if season == "NDJFMA": # Adjust years to consider only when continuous NDJFMA is available - if first_time > cdtime.comptime(startYear, 11, 1): + if first_time > datetime(startYear, 11, 1): startYear += 1 - if last_time < cdtime.comptime(endYear, 4, 30): + if last_time < datetime(endYear, 4, 30): endYear -= 1 # Number of grids for 2d fft input - NL = len(d.getLongitude()) # number of grid in x-axis (longitude) + NL = len(lon.values) # number of grid in x-axis (longitude) if cmmGrid: NL = int(360 / degX) NT = segmentLength # number of time step for each segment (need to be an even number) @@ -84,39 +94,67 @@ def mjo_metric_ewr_calculation( elif season == "MJJASO": mon = 5 numYear = endYear - startYear + 1 + day = 1 + # Store each year's segment in a dictionary: segment[year] segment = {} segment_ano = {} - daSeaCyc = MV2.zeros((NT, d.shape[1], d.shape[2])) + + daSeaCyc = xr.DataArray( + np.zeros((NT, ds[data_var].shape[1], ds[data_var].shape[2])), + dims=["day", "lat", "lon"], + coords={"day": np.arange(NT), "lat": lat, "lon": lon}, + ) + daSeaCyc_values = daSeaCyc.values.copy() + + if debug: + print("debug: before year loop: daSeaCyc.shape:", daSeaCyc.shape) + + # Loop over years for year in range(startYear, endYear): print(year) - segment[year] = subSliceSegment(d, year, mon, day, NT) + segment[year] = subSliceSegment(ds, year, mon, day, NT) # units conversion - segment[year] = unit_conversion(segment[year], UnitsAdjust) + segment[year][data_var] = adjust_units(segment[year][data_var], UnitsAdjust) + if debug: + print( + "debug: year, segment[year][data_var].shape:", + year, + segment[year][data_var].shape, + ) # Get climatology of daily seasonal cycle - daSeaCyc = MV2.add(MV2.divide(segment[year], float(numYear)), daSeaCyc) + daSeaCyc_values = ( + segment[year][data_var].values / float(numYear) + ) + daSeaCyc_values + + daSeaCyc.values = daSeaCyc_values + + if debug: + print("debug: after year loop: daSeaCyc.shape:", daSeaCyc.shape) + # Remove daily seasonal cycle from each segment if numYear > 1: + # Loop over years for year in range(startYear, endYear): - segment_ano[year] = Remove_dailySeasonalCycle(segment[year], daSeaCyc) + # Remove daily Seasonal Cycle + segment_ano[year] = segment[year].copy() + segment_ano[year][data_var].values = ( + segment[year][data_var].values - daSeaCyc.values + ) else: segment_ano[year] = segment[year] - # Assign lat/lon to arrays - daSeaCyc.setAxis(1, lat) - daSeaCyc.setAxis(2, lon) - segment_ano[year].setAxis(1, lat) - segment_ano[year].setAxis(2, lon) - - """ Space-time power spectra - - Handle each segment (i.e. each year) separately. - 1. Get daily time series (3D: time and spatial 2D) - 2. Meridionally average (2D: time and spatial, i.e., longitude) - 3. Get anomaly by removing time mean of the segment - 4. Proceed 2-D FFT to get power. - Then get multi-year averaged power after the year loop. - """ + + # ----------------------------------------------------------------- + # Space-time power spectra + # ----------------------------------------------------------------- + # Handle each segment (i.e. each year) separately. + # 1. Get daily time series (3D: time and spatial 2D) + # 2. Meridionally average (2D: time and spatial, i.e., longitude) + # 3. Get anomaly by removing time mean of the segment + # 4. Proceed 2-D FFT to get power. + # Then get multi-year averaged power after the year loop. + # ----------------------------------------------------------------- # Define array for archiving power from each year segment Power = np.zeros((numYear, NT + 1, NL + 1), np.float) @@ -129,23 +167,30 @@ def mjo_metric_ewr_calculation( d_seg = segment_ano[year] # Regrid: interpolation to common grid if cmmGrid: - d_seg = interp2commonGrid(d_seg, degX, debug=debug) + d_seg = interp2commonGrid(d_seg, data_var, degX, debug=debug) # Subregion, meridional average, and remove segment time mean - d_seg_x_ano = get_daily_ano_segment(d_seg) + d_seg_x_ano = get_daily_ano_segment(d_seg, data_var) # Compute space-time spectrum if debug: print("debug: compute space-time spectrum") - Power[n, :, :] = space_time_spectrum(d_seg_x_ano) + Power[n, :, :] = space_time_spectrum(d_seg_x_ano, data_var) # Multi-year averaged power Power = np.average(Power, axis=0) + # Generates axes for the decoration - Power, ff, ss = generate_axes_and_decorate(Power, NT, NL) + Power = generate_axes_and_decorate(Power, NT, NL) + # Output for wavenumber-frequency power spectra - OEE = output_power_spectra(NL, NT, Power, ff, ss) + OEE = output_power_spectra(NL, NT, Power) + + if debug: + print("OEE:", OEE) + print("OEE.shape:", OEE.shape) # E/W ratio ewr, eastPower, westPower = calculate_ewr(OEE) + print("ewr: ", ewr) print("east power: ", eastPower) print("west power: ", westPower) @@ -166,9 +211,11 @@ def mjo_metric_ewr_calculation( os.makedirs(dir_paths["graphics"], exist_ok=True) fout = os.path.join(dir_paths["graphics"], output_filename) if model == "obs": - title = f"OBS ({run})\n{var.capitalize()}, {season} {startYear}-{endYear}" + title = ( + f"OBS ({run})\n{data_var.capitalize()}, {season} {startYear}-{endYear}" + ) else: - title = f"{mip.upper()}: {model} ({run})\n{var.capitalize()}, {season} {startYear}-{endYear}" + title = f"{mip.upper()}: {model} ({run})\n{data_var.capitalize()}, {season} {startYear}-{endYear}" if cmmGrid: title += ", common grid (2.5x2.5deg)" @@ -186,8 +233,13 @@ def mjo_metric_ewr_calculation( # Debug checking plot if debug and plot: debug_chk_plot( - d_seg_x_ano, Power, OEE, segment[year], daSeaCyc, segment_ano[year] + d_seg_x_ano, + Power, + OEE, + segment[year][data_var], + daSeaCyc, + segment_ano[year][data_var], ) - f.close() + ds.close() return metrics_result diff --git a/pcmdi_metrics/mjo/lib/plot_wavenumber_frequency_power.py b/pcmdi_metrics/mjo/lib/plot_wavenumber_frequency_power.py index d60683fd3..564021507 100644 --- a/pcmdi_metrics/mjo/lib/plot_wavenumber_frequency_power.py +++ b/pcmdi_metrics/mjo/lib/plot_wavenumber_frequency_power.py @@ -1,15 +1,15 @@ import copy import os -import cdms2 import matplotlib.cm import matplotlib.pyplot as plt +import xarray as xr from matplotlib.patches import Rectangle -def plot_power(d, title, fout, ewr=None): - y = d.getAxis(0)[:] - x = d.getAxis(1)[:] +def plot_power(d: xr.DataArray, title: str, fout: str, ewr=None): + x = d["frequency"] + y = d["zonalwavenumber"] # adjust font size SMALL_SIZE = 8 @@ -87,8 +87,8 @@ def plot_power(d, title, fout, ewr=None): currentAxis = plt.gca() currentAxis.add_patch( Rectangle( - (0.0166667, 1), - 0.0333333 - 0.0166667, + (0.016, 1), + 0.034 - 0.016, 2, edgecolor="black", ls="--", @@ -97,8 +97,8 @@ def plot_power(d, title, fout, ewr=None): ) currentAxis.add_patch( Rectangle( - (-0.0333333, 1), - 0.0333333 - 0.0166667, + (-0.034, 1), + 0.034 - 0.016, 2, edgecolor="black", ls="--", @@ -132,8 +132,9 @@ def plot_power(d, title, fout, ewr=None): imgdir = "." - f = cdms2.open(os.path.join(datadir, ncfile)) - d = f("power") + ds = xr.open_dataset(os.path.join(datadir, ncfile)) + d = ds["power"] + fout = os.path.join(imgdir, pngfilename) plot_power(d, title, fout, ewr=ewr) diff --git a/pcmdi_metrics/mjo/lib/post_process_plot.py b/pcmdi_metrics/mjo/scripts/post_process_plot.py similarity index 93% rename from pcmdi_metrics/mjo/lib/post_process_plot.py rename to pcmdi_metrics/mjo/scripts/post_process_plot.py index bea7ca873..44aced650 100644 --- a/pcmdi_metrics/mjo/lib/post_process_plot.py +++ b/pcmdi_metrics/mjo/scripts/post_process_plot.py @@ -1,7 +1,7 @@ import glob import os -import cdms2 +import xarray as xr from lib_mjo import calculate_ewr from plot_wavenumber_frequency_power import plot_power @@ -48,10 +48,9 @@ def main(): ncfile = ( "_".join([mip, model, exp, run, "mjo", period, "cmmGrid"]) + ".nc" ) - f = cdms2.open(os.path.join(datadir, ncfile)) - d = f("power") + ds = xr.open_dataset(os.path.join(datadir, ncfile)) + d = ds["power"] d_runs.append(d) - f.close() title = ( mip.upper() + ": " @@ -69,6 +68,7 @@ def main(): fout = os.path.join(imgdir, pngfilename) # plot plot_power(d, title, fout, ewr) + ds.close() except Exception: print(model, run, "cannnot load") pass diff --git a/pcmdi_metrics/mjo/lib/post_process_plot_ensemble_mean.py b/pcmdi_metrics/mjo/scripts/post_process_plot_ensemble_mean.py similarity index 91% rename from pcmdi_metrics/mjo/lib/post_process_plot_ensemble_mean.py rename to pcmdi_metrics/mjo/scripts/post_process_plot_ensemble_mean.py index 83f9012c0..df9b432bd 100644 --- a/pcmdi_metrics/mjo/lib/post_process_plot_ensemble_mean.py +++ b/pcmdi_metrics/mjo/scripts/post_process_plot_ensemble_mean.py @@ -1,8 +1,8 @@ import glob import os -import cdms2 -import MV2 +import numpy as np +import xarray as xr from lib_mjo import calculate_ewr from plot_wavenumber_frequency_power import plot_power @@ -62,18 +62,21 @@ def main(): ) + ".nc" ) - f = cdms2.open(os.path.join(datadir, ncfile)) - d = f("power") + + ds = xr.open_dataset(os.path.join(datadir, ncfile)) + d = ds["power"] + d_runs.append(d) - f.close() + except Exception as err: print(model, run, "cannnot load:", err) pass + if run == runs_list[-1]: num_runs = len(d_runs) # ensemble mean - d_avg = MV2.average(d_runs, axis=0) - d_avg.setAxisList(d.getAxisList()) + d_avg = np.average(d_runs, axis=0) + # d_avg.setAxisList(d.getAxisList()) title = ( mip.upper() + ": " diff --git a/pcmdi_metrics/sea_ice/sea_ice_driver.py b/pcmdi_metrics/sea_ice/sea_ice_driver.py index 64471a6c7..b62da14c6 100644 --- a/pcmdi_metrics/sea_ice/sea_ice_driver.py +++ b/pcmdi_metrics/sea_ice/sea_ice_driver.py @@ -228,14 +228,14 @@ end_year = meyear real_clim = { - "arctic": {"model_mean": None}, - "ca": {"model_mean": None}, - "na": {"model_mean": None}, - "np": {"model_mean": None}, - "antarctic": {"model_mean": None}, - "sp": {"model_mean": None}, - "sa": {"model_mean": None}, - "io": {"model_mean": None}, + "arctic": {"model_mean": {}}, + "ca": {"model_mean": {}}, + "na": {"model_mean": {}}, + "np": {"model_mean": {}}, + "antarctic": {"model_mean": {}}, + "sp": {"model_mean": {}}, + "sa": {"model_mean": {}}, + "io": {"model_mean": {}}, } real_mean = { "arctic": {"model_mean": 0}, @@ -309,7 +309,15 @@ "%(model_version)": model, "%(realization)": run, } - test_data_full_path = os.path.join(test_data_path, filename_template) + test_data_tmp = lib.replace_multi(test_data_path, tags) + if "*" in test_data_tmp: + # Get the most recent version for last wildcard + ind = test_data_tmp.split("/")[::-1].index("*") + tmp1 = "/".join(test_data_tmp.split("/")[0:-ind]) + globbed = glob.glob(tmp1) + globbed.sort() + test_data_tmp = globbed[-1] + test_data_full_path = os.path.join(test_data_tmp, filename_template) test_data_full_path = lib.replace_multi(test_data_full_path, tags) test_data_full_path = glob.glob(test_data_full_path) test_data_full_path.sort() @@ -365,16 +373,7 @@ # Running sum of all realizations for rgn in clims: real_clim[rgn][run] = clims[rgn] - if real_clim[rgn]["model_mean"] is None: - real_clim[rgn]["model_mean"] = clims[rgn] - else: - real_clim[rgn]["model_mean"][var] = ( - real_clim[rgn]["model_mean"][var] + clims[rgn][var] - ) real_mean[rgn][run] = means[rgn] - real_mean[rgn]["model_mean"] = ( - real_mean[rgn]["model_mean"] + means[rgn] - ) print("\n-------------------------------------------") print("Calculating model regional average metrics \nfor ", model) @@ -382,12 +381,12 @@ for rgn in real_clim: print(rgn) # Get model mean - real_clim[rgn]["model_mean"][var] = real_clim[rgn]["model_mean"][ - var - ] / len(list_of_runs) - real_mean[rgn]["model_mean"] = real_mean[rgn]["model_mean"] / len( - list_of_runs + datalist = [real_clim[rgn][r][var].data for r in list_of_runs] + real_clim[rgn]["model_mean"][var] = np.nanmean( + np.array(datalist), axis=0 ) + datalist = [real_mean[rgn][r] for r in list_of_runs] + real_mean[rgn]["model_mean"] = np.nanmean(np.array(datalist)) for run in real_clim[rgn]: # Set up metrics dictionary diff --git a/pcmdi_metrics/utils/__init__.py b/pcmdi_metrics/utils/__init__.py index d870b03c1..dc4a935f0 100644 --- a/pcmdi_metrics/utils/__init__.py +++ b/pcmdi_metrics/utils/__init__.py @@ -1,3 +1,4 @@ +from .adjust_units import adjust_units from .custom_season import ( custom_season_average, custom_season_departure, diff --git a/pcmdi_metrics/utils/adjust_units.py b/pcmdi_metrics/utils/adjust_units.py new file mode 100644 index 000000000..ee88d3d4a --- /dev/null +++ b/pcmdi_metrics/utils/adjust_units.py @@ -0,0 +1,27 @@ +import xarray as xr + + +def adjust_units(da: xr.DataArray, adjust_tuple: tuple) -> xr.DataArray: + """Convert unit following information in the given tuple + + Parameters + ---------- + da : xr.DataArray + input data array + adjust_tuple : tuple with at least 3 elements (4th element is optional for units) + e.g.: (True, 'multiply', 86400., 'mm d-1'): e.g., kg m-2 s-1 to mm d-1 + (False, 0, 0, 0): no unit conversion + + Returns + ------- + xr.DataArray + data array that contains converted values and attributes + """ + action_dict = {"multiply": "*", "divide": "/", "add": "+", "subtract": "-"} + if adjust_tuple[0]: + print("Converting units by ", adjust_tuple[1], adjust_tuple[2]) + cmd = " ".join(["da", str(action_dict[adjust_tuple[1]]), str(adjust_tuple[2])]) + da = eval(cmd) + if len(adjust_tuple) > 3: + da.assign_attrs(units=adjust_tuple[3]) + return da diff --git a/pcmdi_metrics/utils/grid.py b/pcmdi_metrics/utils/grid.py index 4de4d677a..968cbd055 100644 --- a/pcmdi_metrics/utils/grid.py +++ b/pcmdi_metrics/utils/grid.py @@ -17,6 +17,7 @@ def create_target_grid( lon1: float = 0.0, lon2: float = 360.0, target_grid_resolution: str = "2.5x2.5", + grid_type: str = "uniform", ) -> xr.Dataset: """Generate a uniform grid for given latitude/longitude ranges and resolution @@ -32,6 +33,8 @@ def create_target_grid( Starting latitude, by default 360. target_grid_resolution : str, optional grid resolution in degree for lat and lon, by default "2.5x2.5" + grid_type : str, optional + type of the grid ('uniform' or 'gaussian'), by default "uniform" Returns ------- @@ -46,11 +49,11 @@ def create_target_grid( Global uniform grid: - >>> t_grid = create_target_grid(-90, 90, 0, 360, target_grid="5x5") + >>> grid = create_target_grid(-90, 90, 0, 360, target_grid="5x5") Regional uniform grid: - >>> t_grid = create_target_grid(30, 50, 100, 150, target_grid="0.5x0.5") + >>> grid = create_target_grid(30, 50, 100, 150, target_grid="0.5x0.5") """ # generate target grid res = target_grid_resolution.split("x") @@ -60,10 +63,33 @@ def create_target_grid( start_lon = lon1 + lon_res / 2.0 end_lat = lat2 - lat_res / 2 end_lon = lon2 - lon_res / 2 - t_grid = xc.create_uniform_grid( - start_lat, end_lat, lat_res, start_lon, end_lon, lon_res - ) - return t_grid + + if grid_type == "uniform": + grid = xc.create_uniform_grid( + start_lat, end_lat, lat_res, start_lon, end_lon, lon_res + ) + elif grid_type == "gaussian": + nlat = int(180 / lat_res) + grid = xc.create_gaussian_grid(nlat) + + # If the longitude values include 0 and 360, then remove 360 to avoid having repeating grid + if 0 in grid.lon.values and 360 in grid.lon.values: + min_lon = grid.lon.values[0] # 0 + # max_lon = grid.lon.values[-1] # 360 + second_max_lon = grid.lon.values[-2] # 360-dlat + grid = grid.sel(lon=slice(min_lon, second_max_lon)) + + # Reverse latitude if needed + if grid.lat.values[0] > grid.lat.values[-1]: + grid = grid.isel(lat=slice(None, None, -1)) + + grid = grid.sel(lat=slice(start_lat, end_lat), lon=slice(start_lon, end_lon)) + else: + raise ValueError( + f"grid_type {grid_type} is undefined. Please use either 'uniform' or 'gaussian'" + ) + + return grid def __haversine(lat1, lon1, lat2, lon2):