Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CoCiP-grid data access patterns to reduce duplicate chunk downloads #171

Merged
merged 9 commits into from
Apr 8, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ The Unterstrasser (2016) parameterization can be used in CoCiP by setting a new
- Adds a parameter to `CoCipParams`, `unterstrasser_ice_survival_fraction`, that activates the Unterstrasser (2016) survival parameterization when set to `True`. This is disabled by default, and only implemented for `CoCiP`. `CoCiPGrid` will produce an error if run with `unterstrasser_ice_survival_fraction=True`.
- Modifies `CoCiPGrid` so that setting `compute_atr_20` (defined in `CoCipParams`) to `True` adds `global_yearly_mean_rf` and `atr20` to CoCiP-grid output.
- Replaces `pycontrails.datalib.GOES` ash convention label "MIT" with "SEVIRI"
- Modifies meteorology time step selection logic in `CoCiPGrid` to reduce duplicate chunk downloads when reading from remote zarr stores.
- Updates unit tests for xarray v2024.03.0, which introduced changes to netCDF decoding that slightly alter decoded values. Note that some unit tests will fail for earlier xarray versions.
- Updates `RegularGridInterpolator` to fall back on legacy scipy implementations of tensor-product spline methods when using scipy versions 1.13.0 and later.

## v0.50.0

Expand Down
4 changes: 2 additions & 2 deletions pycontrails/core/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,8 +1356,8 @@ def length_met(self, key: str, threshold: float = 1.0) -> float:
>>> # Intersect and attach
>>> fl["air_temperature"] = fl.intersect_met(met['air_temperature'])
>>> fl["air_temperature"]
array([235.94658, 235.95767, 235.96873, ..., 234.59918, 234.60388,
234.60846], dtype=float32)
array([235.94657007, 235.95766965, 235.96873412, ..., 234.59917962,
234.60387402, 234.60845312])

>>> # Length (in meters) of waypoints whose temperature exceeds 236K
>>> fl.length_met("air_temperature", threshold=236)
Expand Down
40 changes: 39 additions & 1 deletion pycontrails/core/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def __init__(

self.grid = points
self.values = values
self.method = method
# TODO: consider supporting updated tensor-product spline methods
# see https://github.com/scipy/scipy/releases/tag/v1.13.0
self.method = _pick_method(scipy.__version__, method)
thabbott marked this conversation as resolved.
Show resolved Hide resolved
self.bounds_error = bounds_error
self.fill_value = fill_value

Expand Down Expand Up @@ -219,6 +221,42 @@ def _evaluate_linear(
raise ValueError(msg)


def _pick_method(scipy_version: str, method: str) -> str:
"""Select an interpolation method.

For scipy versions 1.13.0 and later, fall back on legacy implementations
of tensor-product spline methods. The default implementations in 1.13.0
and later are incompatible with this class.

Parameters
----------
scipy_version : str
scipy version (major.minor.patch)

method : str
Interpolation method. Passed into :class:`scipy.interpolate.RegularGridInterpolator`
as-is unless ``scipy_version`` is 1.13.0 or later and ``method`` is ``"slinear"``,
``"cubic"``, or ``"quintic"``. In this case, ``"_legacy"`` is appended to ``method``.

Returns
-------
str
Interpolation method adjusted for compatibility with this class.
"""
try:
version = scipy_version.split(".")
major = int(version[0])
minor = int(version[1])
except (IndexError, ValueError) as exc:
msg = f"Failed to parse major and minor version from {scipy_version}"
raise ValueError(msg) from exc

reimplemented_methods = ["slinear", "cubic", "quintic"]
if major > 1 or (major == 1 and minor >= 13) and method in reimplemented_methods:
return method + "_legacy"
return method


def _floatize_time(
time: npt.NDArray[np.datetime64], offset: np.datetime64
) -> npt.NDArray[np.float64]:
Expand Down
22 changes: 10 additions & 12 deletions pycontrails/core/met.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,10 @@ class MetDataset(MetBase):
>>> da = mda.data # Underlying `xarray` object

>>> # Check out a few values
>>> da[5:10, 5:10, 1, 1].values
array([[224.0896 , 224.41374, 224.75946, 225.16237, 225.60507],
[224.09457, 224.42038, 224.76526, 225.16817, 225.61089],
[224.10037, 224.42618, 224.77106, 225.17314, 225.61586],
[224.10617, 224.43282, 224.7777 , 225.17812, 225.62166],
[224.11115, 224.44028, 224.7835 , 225.18393, 225.62663]],
dtype=float32)
>>> da[5:8, 5:8, 1, 1].values
array([[224.08959005, 224.41374427, 224.75945349],
[224.09456429, 224.42037658, 224.76525676],
[224.10036756, 224.42617985, 224.77106004]])

>>> # Mean temperature over entire array
>>> da.mean().load().item()
Expand Down Expand Up @@ -1618,24 +1615,25 @@ def interpolate(

>>> # Interpolation at a grid point agrees with value
>>> mda.interpolate(1, 2, 300, np.datetime64('2022-03-01T14:00'))
array([241.91974], dtype=float32)
array([241.91972984])

>>> da = mda.data
>>> da.sel(longitude=1, latitude=2, level=300, time=np.datetime64('2022-03-01T14')).item()
241.91974
241.9197298421629

>>> # Interpolation off grid
>>> mda.interpolate(1.1, 2.1, 290, np.datetime64('2022-03-01 13:10'))
array([239.83794], dtype=float32)
array([239.83793798])

>>> # Interpolate along path
>>> longitude = np.linspace(1, 2, 10)
>>> latitude = np.linspace(2, 3, 10)
>>> level = np.linspace(200, 300, 10)
>>> time = pd.date_range("2022-03-01T14", periods=10, freq="5min")
>>> mda.interpolate(longitude, latitude, level, time)
array([220.44348, 223.089 , 225.7434 , 228.41643, 231.10858, 233.54858,
235.71506, 237.86479, 239.99275, 242.10793], dtype=float32)
array([220.44347694, 223.08900738, 225.74338924, 228.41642088,
231.10858599, 233.54857391, 235.71504913, 237.86478872,
239.99274623, 242.10792167])
"""
# Load if necessary
if not self.in_memory:
Expand Down
18 changes: 10 additions & 8 deletions pycontrails/core/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,12 +1694,14 @@ def intersect_met(

>>> # Intersect
>>> fl.intersect_met(met['air_temperature'], method='nearest')
array([231.6297 , 230.72604, 232.2432 , 231.88339, 231.0643 , 231.59073,
231.65126, 231.93065, 232.03345, 231.65955], dtype=float32)
array([231.62969892, 230.72604651, 232.24318771, 231.88338483,
231.06429438, 231.59073409, 231.65125393, 231.93064004,
232.03344087, 231.65954432])

>>> fl.intersect_met(met['air_temperature'], method='linear')
array([225.77794, 225.13908, 226.23122, 226.31831, 225.56102, 225.81192,
226.03194, 226.22057, 226.0377 , 225.63226], dtype=float32)
array([225.77794552, 225.13908414, 226.231218 , 226.31831528,
225.56102321, 225.81192149, 226.03192642, 226.22056121,
226.03770174, 225.63226188])

>>> # Interpolate and attach to `Flight` instance
>>> for key in met:
Expand All @@ -1708,11 +1710,11 @@ def intersect_met(
>>> # Show the final three columns of the dataframe
>>> fl.dataframe.iloc[:, -3:].head()
time air_temperature specific_humidity
0 2022-03-01 00:00:00 225.777939 0.000132
0 2022-03-01 00:00:00 225.777946 0.000132
1 2022-03-01 00:13:20 225.139084 0.000132
2 2022-03-01 00:26:40 226.231216 0.000107
3 2022-03-01 00:40:00 226.318314 0.000171
4 2022-03-01 00:53:20 225.561020 0.000109
2 2022-03-01 00:26:40 226.231218 0.000107
3 2022-03-01 00:40:00 226.318315 0.000171
4 2022-03-01 00:53:20 225.561022 0.000109

"""
# Override use_indices in certain situations
Expand Down
58 changes: 56 additions & 2 deletions pycontrails/models/cocipgrid/cocip_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr

import pycontrails
from pycontrails.core import models
Expand Down Expand Up @@ -311,23 +312,76 @@ def _maybe_downselect_met_rad(
``time_end``, new slices are selected from the larger ``self.met`` and
``self.rad`` data. The slicing only occurs in the time domain.

The end of currently-used ``met`` and ``rad`` will be used as the start
of newly-selected met slices when possible to avoid losing and re-loading
already-loaded met data.

If ``self.params["downselect_met"]`` is True, :func:`_downselect_met` has
already performed a spatial downselection of the met data.
"""
if met is None or time_end > met.indexes["time"].to_numpy()[-1]:

if met is None:
# idx is the first index at which self.met.variables["time"].to_numpy() >= time_end
idx = np.searchsorted(self.met.indexes["time"].to_numpy(), time_end)
sl = slice(max(0, idx - 1), idx + 1)
logger.debug("Select met slice %s", sl)
met = MetDataset(self.met.data.isel(time=sl), copy=False)

if rad is None or time_end > rad.indexes["time"].to_numpy()[-1]:
elif time_end > met.indexes["time"].to_numpy()[-1]:
current_times = met.indexes["time"].to_numpy()
all_times = self.met.indexes["time"].to_numpy()
# idx is the first index at which all_times >= time_end
idx = np.searchsorted(all_times, time_end)
sl = slice(max(0, idx - 1), idx + 1)

# case 1: cannot re-use end of current met as start of new met
if current_times[-1] != all_times[sl.start]:
logger.debug("Select met slice %s", sl)
met = MetDataset(self.met.data.isel(time=sl), copy=False)
# case 2: can re-use end of current met plus one step of new met
elif sl.start < all_times.size - 1:
sl = slice(sl.start + 1, sl.stop)
logger.debug("Reuse end of met and select met slice %s", sl)
met = MetDataset(
xr.concat((met.data.isel(time=[-1]), self.met.data.isel(time=sl)), dim="time"),
copy=False,
)
# case 3: can re-use end of current met and nothing else
else:
logger.debug("Reuse end of met")
met = MetDataset(met.data.isel(time=[-1]), copy=False)

if rad is None:
# idx is the first index at which self.rad.variables["time"].to_numpy() >= time_end
idx = np.searchsorted(self.rad.indexes["time"].to_numpy(), time_end)
sl = slice(max(0, idx - 1), idx + 1)
logger.debug("Select rad slice %s", sl)
rad = MetDataset(self.rad.data.isel(time=sl), copy=False)

elif time_end > rad.indexes["time"].to_numpy()[-1]:
current_times = rad.indexes["time"].to_numpy()
all_times = self.rad.indexes["time"].to_numpy()
# idx is the first index at which all_times >= time_end
idx = np.searchsorted(all_times, time_end)
sl = slice(max(0, idx - 1), idx + 1)

# case 1: cannot re-use end of current rad as start of new rad
if current_times[-1] != all_times[sl.start]:
logger.debug("Select rad slice %s", sl)
rad = MetDataset(self.rad.data.isel(time=sl), copy=False)
# case 2: can re-use end of current rad plus one step of new rad
elif sl.start < all_times.size - 1:
sl = slice(sl.start + 1, sl.stop)
logger.debug("Reuse end of rad and select rad slice %s", sl)
rad = MetDataset(
xr.concat((rad.data.isel(time=[-1]), self.rad.data.isel(time=sl)), dim="time"),
copy=False,
)
# case 3: can re-use end of current rad and nothing else
else:
logger.debug("Reuse end of rad")
rad = MetDataset(rad.data.isel(time=[-1]), copy=False)

return met, rad

def _attach_verbose_outputs_evolution(self, contrail_list: list[GeoVectorDataset]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pycontrails/models/issr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ISSR(Model):
>>> out1 = model.eval()
>>> issr1 = out1["issr"]
>>> issr1.proportion # Get proportion of values with ice supersaturation
0.114140
0.11414134603859523

>>> # Run with a lower threshold
>>> out2 = model.eval(rhi_threshold=0.95)
Expand Down
20 changes: 17 additions & 3 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def override_cache() -> DiskCacheStore:
def met_pcc_pl(met_ecmwf_pl_path: str, override_cache: DiskCacheStore) -> MetDataset:
"""Met data (pressure levels) for PCC algorithm testing.

The v2024.03.0 release of xarray changes the datatype of decoded variables
from float32 to float64. This breaks tests that expect variables as float32.
As a workaround, we manually convert variables to float32 after decoding.

Returns
-------
MetDataset
Expand All @@ -117,7 +121,9 @@ def met_pcc_pl(met_ecmwf_pl_path: str, override_cache: DiskCacheStore) -> MetDat
paths=met_ecmwf_pl_path,
cachestore=override_cache,
)
return era5.open_metdataset()
met = era5.open_metdataset()
met.data = met.data.astype("float32")
thabbott marked this conversation as resolved.
Show resolved Hide resolved
return met


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -200,11 +206,15 @@ def met_issr(met_ecmwf_pl_path: str) -> MetDataset:
the test `tests/unit/test_dtypes::test_issr_sac_grid_output[linear-ISSR]`
to hang. As a workaround, we open without dask parallelism.

The v2024.03.0 release of xarray changes the datatype of decoded variables
from float32 to float64. This breaks tests that expect variables as float32.
As a workaround, we manually convert variables to float32 after decoding.

Returns
-------
MetDataset
"""
ds = xr.open_dataset(met_ecmwf_pl_path)
ds = xr.open_dataset(met_ecmwf_pl_path).astype("float32")
ds = met.standardize_variables(ds, (met_var.AirTemperature, met_var.SpecificHumidity))
ds.attrs.update(provider="ECMWF", dataset="ERA5", product="reanalysis")
return MetDataset(ds)
Expand All @@ -217,9 +227,13 @@ def met_cocip1() -> MetDataset:
See ``met_cocip1_module_scope`` for a module-scoped version.

See tests/fixtures/cocip-met.py for domain and creation of the source data file.

The v2024.03.0 release of xarray changes the datatype of decoded variables
from float32 to float64. This breaks tests that expect variables as float32.
As a workaround, we manually convert variables to float32 after decoding.
"""
path = get_static_path("met-era5-cocip1.nc")
ds = xr.open_dataset(path)
ds = xr.open_dataset(path).astype("float32")
ds["air_pressure"] = ds["air_pressure"].astype("float32")
ds["altitude"] = ds["altitude"].astype("float32")
return MetDataset(ds, provider="ECMWF", dataset="ERA5", product="reanalysis")
Expand Down
Loading