From 7b52e8314da6cea3bea5bec51c8ee8c3dc5778de Mon Sep 17 00:00:00 2001
From: tomvothecoder
Date: Mon, 9 Sep 2024 14:41:02 -0700
Subject: [PATCH] Add `.load()` to `_get_dataset_with_source_vars()` to improve performance

---
Reviewer note (free-text area, ignored when the patch is applied):
  * Adds a `timeit` script that times the same arithmetic on three DataArrays
    opened lazily, pre-loaded with `.load(scheduler="sync")`, and opened with
    `chunks={"time": "auto"}`, plus two variants using `.values` / `.data`.
    The timings the author recorded are kept in the script's trailing string.
  * Calls `ds.load(scheduler="sync")` on the merged dataset in
    `_get_dataset_with_source_vars()` before it is returned.
  * Re-adds `return ds_climo` in `get_climo_dataset()` after a blank line.
    NOTE(review): line breaks and indentation in this patch were reconstructed
    from the diff markers and hunk line counts; the dedent of that `return`
    is presumed -- confirm against the upstream commit.

 .../667-arm_diags/debug/667_debug_perf.py     | 104 ++++++++++++++++++
 e3sm_diags/driver/utils/dataset_xr.py         |   4 +-
 2 files changed, 107 insertions(+), 1 deletion(-)
 create mode 100644 auxiliary_tools/cdat_regression_testing/667-arm_diags/debug/667_debug_perf.py

diff --git a/auxiliary_tools/cdat_regression_testing/667-arm_diags/debug/667_debug_perf.py b/auxiliary_tools/cdat_regression_testing/667-arm_diags/debug/667_debug_perf.py
new file mode 100644
index 0000000000..d9dbdbc15d
--- /dev/null
+++ b/auxiliary_tools/cdat_regression_testing/667-arm_diags/debug/667_debug_perf.py
@@ -0,0 +1,104 @@
+# %%
+import timeit
+
+setup_code = """
+import xarray as xr
+
+AIR_DENS = 1.225  # standard air density 1.225kg/m3
+
+a1 = xr.open_dataarray("qa/667-arms-diags/a1.nc")
+a2 = xr.open_dataarray("qa/667-arms-diags/a2.nc")
+a3 = xr.open_dataarray("qa/667-arms-diags/a3.nc")
+"""
+
+setup_code2 = """
+import xarray as xr
+
+AIR_DENS = 1.225  # standard air density 1.225kg/m3
+
+a1 = xr.open_dataarray("qa/667-arms-diags/a1.nc")
+a2 = xr.open_dataarray("qa/667-arms-diags/a2.nc")
+a3 = xr.open_dataarray("qa/667-arms-diags/a3.nc")
+
+a1.load(scheduler="sync")
+a2.load(scheduler="sync")
+a3.load(scheduler="sync")
+"""
+
+setup_code3 = """
+import xarray as xr
+
+AIR_DENS = 1.225  # standard air density 1.225kg/m3
+
+a1_chunked = xr.open_dataarray("qa/667-arms-diags/a1.nc", chunks={"time": "auto"})
+a2_chunked = xr.open_dataarray("qa/667-arms-diags/a2.nc", chunks={"time": "auto"})
+a3_chunked = xr.open_dataarray("qa/667-arms-diags/a3.nc", chunks={"time": "auto"})
+"""
+
+code_statement1 = """
+with xr.set_options(keep_attrs=True):
+    var = (a1 + a2 + a3) * AIR_DENS / 1e6
+"""
+
+
+code_statement2 = """
+with xr.set_options(keep_attrs=True):
+    var = (a1_chunked + a2_chunked + a3_chunked) * AIR_DENS / 1e6
+"""
+
+code_statement3 = """
+var_data = (a1.values + a2.values + a3.values) * AIR_DENS / 1e6
+var_new = xr.DataArray(
+    var_data,
+    dims=a1.dims,
+    coords=a1.coords,
+    name="a_num",
+    attrs={"units": "/cm3", "long_name": "aerosol number concentration"},
+)
+"""
+
+code_statement4 = """
+var_data2 = (a1.data + a2.data + a3.data) * AIR_DENS / 1e6
+var_new2 = xr.DataArray(
+    name="a_num", data=var_data2, dims=a1.dims, coords=a1.coords, attrs=a1.attrs
+)
+var_new2.attrs.update(
+    {"units": "/cm3", "long_name": "aerosol number concentration"}
+)
+"""
+
+
+def run_timeit(code_statement: str, setup_code: str) -> float:
+    elapsed_time = timeit.repeat(
+        code_statement, setup=setup_code, globals=globals(), repeat=3, number=1
+    )
+
+    return min(elapsed_time)
+
+
+elapsed_time_xarray = run_timeit(code_statement1, setup_code)
+print(f"1. Elapsed time (Xarray non-chunked): {elapsed_time_xarray} seconds")
+
+elapsed_time_xarray_load = run_timeit(code_statement1, setup_code2)
+print(
+    f"2. Elapsed time (Xarray non-chunked with .load()): {elapsed_time_xarray_load} seconds"
+)
+elapsed_time_xarray_chunked = run_timeit(code_statement2, setup_code3)
+print(f"3. Elapsed time (Xarray chunked): {elapsed_time_xarray_chunked} seconds")
+
+elapsed_time_numpy_1 = run_timeit(code_statement3, setup_code)
+print(f"4. Elapsed time (numpy .values): {elapsed_time_numpy_1} seconds")
+
+elapsed_time_numpy_2 = run_timeit(code_statement4, setup_code)
+print(f"5. Elapsed time (numpy .data): {elapsed_time_numpy_2} seconds")
+
+
+"""
+Results
+----------
+1. Elapsed time (Xarray non-chunked): 6.540755605790764 seconds
+2. Elapsed time (Xarray non-chunked with .load()): 0.17097265785560012 seconds
+3. Elapsed time (Xarray chunked): 0.1452920027077198 seconds
+4. Elapsed time (numpy .values): 6.418793010059744 seconds
+5. Elapsed time (numpy .data): 7.334999438840896 seconds
+"""
diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py
index d75ca0d61f..10588378fb 100644
--- a/e3sm_diags/driver/utils/dataset_xr.py
+++ b/e3sm_diags/driver/utils/dataset_xr.py
@@ -417,7 +417,8 @@ def get_climo_dataset(self, var: str, season: ClimoFreq) -> xr.Dataset:
         elif self.is_time_series:
             ds = self.get_time_series_dataset(var)
             ds_climo = climo(ds, self.var, season).to_dataset()
-            return ds_climo
+
+        return ds_climo
 
     def _get_climo_dataset(self, season: str) -> xr.Dataset:
         """Get the climatology dataset for the variable and season.
@@ -1054,6 +1055,7 @@ def _get_dataset_with_source_vars(self, vars_to_get: Tuple[str, ...]) -> xr.Data
 
         ds = xr.merge(datasets)
         ds = squeeze_time_dim(ds)
+        ds.load(scheduler="sync")
 
         return ds
 