diff --git a/intake_esgf/catalog.py b/intake_esgf/catalog.py
index 5d62875..3310277 100644
--- a/intake_esgf/catalog.py
+++ b/intake_esgf/catalog.py
@@ -6,7 +6,7 @@
 from functools import partial
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
-from typing import Any, Callable, Literal, Union
+from typing import Callable, Literal, Union
 
 import pandas as pd
 import requests
@@ -601,7 +601,6 @@ def to_dataset_dict(
         add_measures: bool = True,
         globus_endpoint: Union[str, None] = None,
         globus_path: Union[Path, None] = None,
-        operators: list[Any] = [],
     ) -> dict[str, xr.Dataset]:
         """Return the current search as a dictionary of datasets.
 
@@ -706,9 +705,6 @@
             ):
                 ds[key] = add_cell_measures(ds[key], self)
 
-        # If the user specifies operators, apply them now
-        for op in operators:
-            ds = op(ds)
         return ds
 
     def remove_incomplete(self, complete: Callable[[pd.DataFrame], bool]):
diff --git a/intake_esgf/operators.py b/intake_esgf/operators.py
deleted file mode 100644
index 1244fe6..0000000
--- a/intake_esgf/operators.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""A collection of common operators used in CMIP analysis."""
-
-from typing import Union
-
-import pandas as pd
-import xarray as xr
-
-from intake_esgf import IN_NOTEBOOK
-from intake_esgf.base import bar_format, get_cell_measure, get_search_criteria
-from intake_esgf.projects import get_likely_project, projects
-
-if IN_NOTEBOOK:
-    from tqdm import tqdm_notebook as tqdm
-else:
-    from tqdm import tqdm
-
-
-def global_sum(
-    dsd: Union[dict[str, xr.Dataset], xr.Dataset], quiet: bool = False
-) -> Union[dict[str, xr.Dataset], xr.Dataset]:
-    """Integrate the datasets globally in place with the proper cell measures.
-
-    Parameters
-    ----------
-    dsd
-        The dataset or dictionary of datasets to integrate.
-
-    """
-
-    def _global_sum(ds: xr.Dataset):
-        ds_sum = {}
-        ds_attrs = ds.attrs
-        for var, da in ds.items():
-            measure = get_cell_measure(var, ds)
-            if measure is None:
-                continue
-            attrs = da.attrs  # attributes get dropped, so we rem them
-            da = (da * measure).sum(dim=measure.dims)
-            da.attrs = attrs
-            da.attrs["units"] = f"({da.attrs['units']}) * ({measure.attrs['units']})"
-            ds_sum[var] = da
-        ds_sum = xr.Dataset(ds_sum)
-        ds_sum.attrs = ds_attrs
-        return ds_sum
-
-    if isinstance(dsd, xr.Dataset):
-        return _global_sum(dsd)
-    for key, ds in tqdm(
-        dsd.items(),
-        disable=quiet,
-        bar_format=bar_format,
-        unit="dataset",
-        unit_scale=False,
-        desc="Global sum",
-        ascii=False,
-        total=len(dsd),
-    ):
-        dsd[key] = _global_sum(ds)
-    return dsd
-
-
-def global_mean(
-    dsd: Union[dict[str, xr.Dataset], xr.Dataset], quiet: bool = False
-) -> Union[dict[str, xr.Dataset], xr.Dataset]:
-    """Compute a area-weighted global mean of the datasets in place.
-
-    Parameters
-    ----------
-    dsd
-        The dataset or dictionary of datasets to average.
-
-    """
-
-    def _global_mean(ds: xr.Dataset):
-        ds_mean = {}
-        ds_attrs = ds.attrs
-        for var, da in ds.items():
-            measure = get_cell_measure(var, ds)
-            if measure is None:
-                continue
-            attrs = da.attrs  # attributes get dropped, so we rem them
-            da = da.weighted(measure.fillna(0)).mean(dim=measure.dims)
-            da.attrs = attrs
-            ds_mean[var] = da
-        ds_mean = xr.Dataset(ds_mean)
-        ds_mean.attrs = ds_attrs
-        return ds_mean
-
-    if isinstance(dsd, xr.Dataset):
-        return _global_mean(dsd)
-    for key, ds in tqdm(
-        dsd.items(),
-        disable=quiet,
-        bar_format=bar_format,
-        unit="dataset",
-        unit_scale=False,
-        desc="Global mean",
-        ascii=False,
-        total=len(dsd),
-    ):
-        dsd[key] = _global_mean(ds)
-    return dsd
-
-
-def ensemble_mean(
-    dsd: dict[str, xr.Dataset],
-    include_std: bool = False,
-    quiet: bool = False,
-) -> dict[str, xr.Dataset]:
-    """Compute the ensemble mean of the input dictionary of datasets.
-
-    This routine intelligently combines the input data across where the only difference
-    in the facets is the `variant_label`.
-
-    Parameters
-    ----------
-    dsd
-        The dictionary of datasets.
-    include_std
-        Enable to include the standard deviation in the output.
-    quiet
-        Enable to silence the progress bar.
-
-    """
-    # parse facets out of the dataset attributes
-    df = []
-    for key, ds in dsd.items():
-        project_id = get_likely_project(ds.attrs)
-        project = projects[project_id]
-        df.append(get_search_criteria(ds, project_id))
-        df[-1]["key"] = key
-    df = pd.DataFrame(df)
-    # now groupby everything but the variant_label and compute the mean/std
-    variant_facet = project.variant_facet()
-    grp_cols = [c for c in list(df.columns) if c not in [variant_facet, "key"]]
-    out = {}
-    for _, grp in tqdm(
-        df.groupby(grp_cols),
-        disable=quiet,
-        bar_format=bar_format,
-        unit="dataset",
-        unit_scale=False,
-        desc="Ensemble mean",
-        ascii=False,
-    ):
-        ds = xr.concat([dsd[key] for key in grp["key"].to_list()], dim="variant")
-        ds.attrs[variant_facet] = grp[variant_facet].to_list()
-        row = grp.iloc[0]
-        out[row["key"].replace(row[variant_facet], "mean")] = ds.mean(
-            dim="variant", keep_attrs=True
-        )
-        if include_std:
-            out[row["key"].replace(row[variant_facet], "std")] = ds.std(
-                dim="variant", keep_attrs=True
-            )
-    return out
diff --git a/intake_esgf/tests/test_operators.py b/intake_esgf/tests/test_operators.py
deleted file mode 100644
index 925dfba..0000000
--- a/intake_esgf/tests/test_operators.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from functools import partial
-
-import pytest
-
-import intake_esgf.operators as ops
-from intake_esgf import ESGFCatalog
-
-
-# just to make these tests run faster
-def trim_time(dsd):
-    for key, ds in dsd.items():
-        dsd[key] = ds.sel(time=slice("1990-01-01", "2000-01-01"))
-    return dsd
-
-
-def test_global_mean():
-    cat = ESGFCatalog().search(
-        experiment_id=["historical"],
-        source_id="CanESM5",
-        variant_label="r1i1p1f1",
-        variable_id=["gpp", "fgco2"],
-        frequency="mon",
-    )
-    dsd = cat.to_dataset_dict(ignore_facets=["table_id"])
-    dsd = trim_time(dsd)
-    dsd = ops.global_mean(dsd)
-    assert set(["fgco2", "gpp"]) == set(dsd.keys())
-
-
-@pytest.mark.skip(reason="Temporary while we rework to_dataset_dict()")
-def test_ensemble_mean():
-    """Run a test on composition of operators.
-
-    Operators may be locally defined, but we expect that the only argument taken is a
-    dictionary of datasets or a dataset. If a function has arguments, you can use
-    `partial` from `functools` to resolve the arugments and then pass the function into
-    the operators.
-
-    """
-    cat = ESGFCatalog().search(
-        experiment_id="historical",
-        source_id=["CanESM5"],
-        variant_label=["r1i1p1f1", "r2i1p1f1"],
-        variable_id=["gpp"],
-        frequency="mon",
-    )
-    ensemble_mean = partial(ops.ensemble_mean, include_std=True)
-    dsd = cat.to_dataset_dict(
-        ignore_facets=["institution_id", "table_id"],
-        operators=[trim_time, ops.global_mean, ensemble_mean],
-    )
-    assert set(["mean", "std"]) == set(dsd.keys())
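
Note: with the `operators=` keyword removed, the composition previously done inside the catalog (the deleted `for op in operators: ds = op(ds)` loop) now happens on the caller's side. A minimal sketch of that pattern, not part of the patch, reusing the search facets and the `trim_time` helper from the deleted test:

# Sketch only: apply post-processing callables to the dictionary returned by
# to_dataset_dict() now that the operators= keyword is gone.
import xarray as xr

from intake_esgf import ESGFCatalog


def trim_time(dsd: dict[str, xr.Dataset]) -> dict[str, xr.Dataset]:
    # keep only 1990-2000 so the example runs quickly
    for key, ds in dsd.items():
        dsd[key] = ds.sel(time=slice("1990-01-01", "2000-01-01"))
    return dsd


cat = ESGFCatalog().search(
    experiment_id="historical",
    source_id="CanESM5",
    variant_label="r1i1p1f1",
    variable_id=["gpp", "fgco2"],
    frequency="mon",
)
dsd = cat.to_dataset_dict(ignore_facets=["table_id"])

# what operators=[trim_time, ...] used to do inside the catalog
for op in [trim_time]:
    dsd = op(dsd)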
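
The deleted `global_mean` operator can likewise be kept as a local helper. A sketch assuming `get_cell_measure` is still importable from `intake_esgf.base` (this patch does not touch that module); `apply_to_dict` is an illustrative name, not part of the library:

# Sketch of a user-side area-weighted mean mirroring the deleted global_mean.
# Assumes intake_esgf.base.get_cell_measure remains available.
import xarray as xr

from intake_esgf.base import get_cell_measure


def global_mean(ds: xr.Dataset) -> xr.Dataset:
    out = {}
    for var, da in ds.items():
        measure = get_cell_measure(var, ds)
        if measure is None:
            continue
        attrs = da.attrs  # the weighted mean drops attributes, so keep a copy
        da = da.weighted(measure.fillna(0)).mean(dim=measure.dims)
        da.attrs = attrs
        out[var] = da
    ds_mean = xr.Dataset(out)
    ds_mean.attrs = ds.attrs
    return ds_mean


def apply_to_dict(dsd: dict[str, xr.Dataset]) -> dict[str, xr.Dataset]:
    # apply the reduction to every dataset returned by to_dataset_dict()
    return {key: global_mean(ds) for key, ds in dsd.items()}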
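
The deleted `ensemble_mean` grouped datasets by every facet except the variant label before reducing; identifying which keys form an ensemble is now up to the caller. A rough sketch of just the reduction step, using the same `variant` dimension and `keep_attrs=True` as the removed code; `ensemble_stats` is an illustrative name:

# Sketch of the reduction formerly done by ensemble_mean: concatenate the
# members the caller has identified along a new "variant" dimension.
import xarray as xr


def ensemble_stats(members: list[xr.Dataset]) -> tuple[xr.Dataset, xr.Dataset]:
    ens = xr.concat(members, dim="variant")
    return (
        ens.mean(dim="variant", keep_attrs=True),
        ens.std(dim="variant", keep_attrs=True),
    )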