Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable file locks in MetDataSource.open_dataset #213

Merged
merged 11 commits into from
Jul 5, 2024
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Changelog

## v0.52.1

### Breaking changes

- Remove `lock=False` as a default keyword argument to `xr.open_mfdataset` in the `MetDataSource.open_dataset` method. This reverts a change from [v0.44.1](https://github.com/contrailcirrus/pycontrails/releases/tag/v0.44.1) and prevents segmentation faults when using recent versions of [netCDF4](https://pypi.org/project/netCDF4/) (1.7.0 and above).
- GOES imagery is now loaded from a temporary file on disk rather than directly from memory when using `GOES.get` without a cachestore.

### Internals

- Remove upper limits on netCDF4 and numpy versions.
- Remove h5netcdf dependency.
- Update doctests with numpy 2 scalar representation (see [NEP 51](https://numpy.org/neps/nep-0051-scalar-representation.html)).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So will the doctests now fail with numpy<2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I added a check in the Makefile (d39d70e) to make this requirement explicit.

Trying to keep doctests compatible with numpy 1 and 2 seems tricky. You can get the same output for scalars by using print, but that doesn't work for scalars wrapped in some other object:

>>> np.float64(1.0)
1.0              # numpy 1
np.float64(1.0)  # numpy 2
>>> print(np.float64(1.0))
1.0  # numpy 1
1.0  # numpy 2
>>> print((np.float64(1.0), np.float64(2.0)))
(1.0, 2.0)                          # numpy 1
(np.float64(1.0), np.float64(2.0))  # numpy 2

Happy to discuss this more if you don't think this seems like the right solution!


## v0.52.0

### Breaking changes
Expand Down
10 changes: 5 additions & 5 deletions pycontrails/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,27 @@ def slice_domain(
>>> # Call with request as np.array
>>> request = np.linspace(-20, 20, 100)
>>> slice_domain(domain, request)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with request as tuple
>>> request = -20, 20
>>> slice_domain(domain, request)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with a buffer
>>> request = -16, 13
>>> buffer = 4, 7
>>> slice_domain(domain, request, buffer)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with request as a single number
>>> request = -20
>>> slice_domain(domain, request)
slice(640, 641, None)
slice(np.int64(640), np.int64(641), None)

>>> request = -19.9
>>> slice_domain(domain, request)
slice(640, 642, None)
slice(np.int64(640), np.int64(642), None)

"""
# if the length of domain coordinates is <= 2, return the whole domain
Expand Down
8 changes: 4 additions & 4 deletions pycontrails/core/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def max_distance_gap(self) -> float:
... time=pd.date_range('2021-01-01T12', '2021-01-01T14', periods=200),
... )
>>> fl.max_distance_gap
7391.27...
np.float64(7391.27...)
"""
if self.attrs["crs"] != "EPSG:4326":
raise NotImplementedError("Only implemented for EPSG:4326 CRS.")
Expand Down Expand Up @@ -415,7 +415,7 @@ def length(self) -> float:
... time=pd.date_range('2021-01-01T12', '2021-01-01T14', periods=200),
... )
>>> fl.length
1436924.67...
np.float64(1436924.67...)
"""
if self.attrs["crs"] != "EPSG:4326":
raise NotImplementedError("Only implemented for EPSG:4326 CRS.")
Expand Down Expand Up @@ -1520,11 +1520,11 @@ def length_met(self, key: str, threshold: float = 1.0) -> float:

>>> # Length (in meters) of waypoints whose temperature exceeds 236K
>>> fl.length_met("air_temperature", threshold=236)
4132178.159...
np.float64(4132178.159...)

>>> # Proportion (with respect to distance) of waypoints whose temperature exceeds 236K
>>> fl.proportion_met("air_temperature", threshold=236)
0.663552...
np.float64(0.663552...)
"""
if key not in self.data:
raise KeyError(f"Column {key} does not exist in data.")
Expand Down
4 changes: 2 additions & 2 deletions pycontrails/core/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,11 @@ class EmissionsProfileInterpolator:
>>> epi = EmissionsProfileInterpolator(xp, fp)
>>> # Interpolate a single value
>>> epi.interp(5)
0.150000...
np.float64(0.150000...)

>>> # Interpolate a single value on a logarithmic scale
>>> epi.log_interp(5)
1.105171...
np.float64(1.105171...)

>>> # Demonstrate speed up compared with xarray.DataArray interpolation
>>> import time, xarray as xr
Expand Down
2 changes: 1 addition & 1 deletion pycontrails/core/polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def multipolygon_to_geojson(
else:
shape = len(ring.coords), 3
coords = np.empty(shape)
coords[:, :2] = ring.coords
coords[:, :2] = np.asarray(ring.coords)
coords[:, 2] = altitude

poly_coords.append(coords.tolist())
Expand Down
2 changes: 1 addition & 1 deletion pycontrails/core/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,7 @@ def vector_to_lon_lat_grid(
[2.97, 0.12, 1.33, ..., 3.54, 0.74, 2.59]])

>>> da.sum().item() == vector["foo"].sum()
True
np.True_

"""
df = vector.select(("longitude", "latitude", *agg), copy=False).dataframe
Expand Down
4 changes: 0 additions & 4 deletions pycontrails/datalib/_met_utils/metsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
#: Whether to open multi-file datasets in parallel
OPEN_IN_PARALLEL: bool = False

#: Whether to use file locking when opening multi-file datasets
OPEN_WITH_LOCK: bool = False


def parse_timesteps(time: TimeInput | None, freq: str | None = "1h") -> list[datetime]:
"""Parse time input into set of time steps.
Expand Down Expand Up @@ -741,5 +738,4 @@ def open_dataset(
xr_kwargs.setdefault("engine", NETCDF_ENGINE)
xr_kwargs.setdefault("chunks", DEFAULT_CHUNKS)
xr_kwargs.setdefault("parallel", OPEN_IN_PARALLEL)
xr_kwargs.setdefault("lock", OPEN_WITH_LOCK)
return xr.open_mfdataset(disk_paths, **xr_kwargs)
14 changes: 11 additions & 3 deletions pycontrails/datalib/goes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import datetime
import enum
import io
import tempfile
from collections.abc import Iterable

import numpy as np
Expand Down Expand Up @@ -535,7 +535,7 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
da_dict = {}
for rpath, init_bytes in data.items():
channel = _extract_channel_from_rpath(rpath)
ds = xr.open_dataset(io.BytesIO(init_bytes), engine="h5netcdf")
ds = _load_via_tempfile(init_bytes)

da = ds["CMI"]
da = da.expand_dims(band_id=ds["band_id"].values)
Expand All @@ -551,7 +551,7 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
da = xr.concat(da_dict.values(), dim="band_id")

else:
ds = xr.open_dataset(io.BytesIO(data), engine="h5netcdf")
ds = _load_via_tempfile(data)
da = ds["CMI"]
da = da.expand_dims(band_id=ds["band_id"].values)

Expand All @@ -564,6 +564,14 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
return da


def _load_via_tempfile(data: bytes) -> xr.Dataset:
    """Load an xarray dataset by round-tripping ``data`` through a temporary file.

    Parameters
    ----------
    data : bytes
        Raw netCDF file contents.

    Returns
    -------
    xr.Dataset
        Dataset fully loaded into memory (``xr.load_dataset`` reads eagerly),
        so the temporary file can be removed before returning.
    """
    # NamedTemporaryFile cannot be reopened by name on Windows while the
    # handle is still open, so write into a temporary directory and hand
    # xarray an ordinary path instead. Forward slashes are accepted by
    # Python file APIs on all platforms.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = f"{tmpdir}/data.nc"
        with open(path, "wb") as f:
            f.write(data)
        return xr.load_dataset(path)
thabbott marked this conversation as resolved.
Show resolved Hide resolved


def _concat_c02(ds1: XArrayType, ds2: XArrayType) -> XArrayType:
"""Concatenate two datasets with C01 and C02 data."""
# Average the C02 data to the C01 resolution
Expand Down
18 changes: 9 additions & 9 deletions pycontrails/models/cocip/cocip_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ class CocipUncertaintyParams(CocipParams):
>>> distr = scipy.stats.uniform(loc=0.4, scale=0.2)
>>> params = CocipUncertaintyParams(seed=123, initial_wake_vortex_depth_uncertainty=distr)
>>> params.initial_wake_vortex_depth
0.41076420
np.float64(0.41076420...)

>>> # Once seeded, calling the class again gives a new value
>>> params = CocipUncertaintyParams(initial_wake_vortex_depth=distr)
>>> params.initial_wake_vortex_depth
0.43526372
np.float64(0.43526372...)

>>> # To retain the default value, set the uncertainty to None
>>> params = CocipUncertaintyParams(rf_lw_enhancement_factor_uncertainty=None)
Expand Down Expand Up @@ -212,7 +212,7 @@ def uncertainty_params(self) -> dict[str, rv_frozen]:

return out

def rvs(self, size: None | int = None) -> dict[str, float | npt.NDArray[np.float64]]:
def rvs(self, size: None | int = None) -> dict[str, np.float64 | npt.NDArray[np.float64]]:
"""Call each distribution's `rvs` method to generate random parameters.

Seed calls to `rvs` with class variable `rng`.
Expand Down Expand Up @@ -247,12 +247,12 @@ def rvs(self, size: None | int = None) -> dict[str, float | npt.NDArray[np.float
[7.9063961e-04, 3.0336906e-03, 7.7571563e-04, 2.0577813e-02,
9.4205803e-01, 4.3379897e-03, 3.6786550e-03, 2.4747452e-02]],
dtype=float32),
'initial_wake_vortex_depth': 0.39805019708566847,
'nvpm_ei_n_enhancement_factor': 0.9371878437312526,
'rf_lw_enhancement_factor': 1.1017491252832377,
'rf_sw_enhancement_factor': 0.99721639115012,
'sedimentation_impact_factor': 0.5071779847244678,
'wind_shear_enhancement_exponent': 0.34100931239701004}
'initial_wake_vortex_depth': np.float64(0.39805019708566847),
'nvpm_ei_n_enhancement_factor': np.float64(0.9371878437312526),
'rf_lw_enhancement_factor': np.float64(1.1017491252832377),
'rf_sw_enhancement_factor': np.float64(0.99721639115012),
'sedimentation_impact_factor': np.float64(0.5071779847244678),
'wind_shear_enhancement_exponent': np.float64(0.34100931239701004)}
"""
return {
param: distr.rvs(size=size, random_state=self.rng)
Expand Down
2 changes: 1 addition & 1 deletion pycontrails/physics/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def spatial_bounding_box(
>>> lon = rng.uniform(-180, 180, size=30)
>>> lat = rng.uniform(-90, 90, size=30)
>>> spatial_bounding_box(lon, lat)
(-168.0, -77.0, 155.0, 82.0)
(np.float64(-168.0), np.float64(-77.0), np.float64(155.0), np.float64(82.0))
"""
lon_min = max(np.floor(np.min(longitude) - buffer), -180.0)
lon_max = min(np.ceil(np.max(longitude) + buffer), 179.99)
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ readme = { file = "README.md", content-type = "text/markdown" }

dependencies = [
"dask>=2022.3",
"numpy>=1.22,<2.0.0",
"numpy>=1.22",
"overrides>=6.1",
"pandas>=2.2",
"scipy>=1.10",
Expand Down Expand Up @@ -99,7 +99,7 @@ ecmwf = [
"cfgrib>=0.9",
"eccodes>=1.4",
"ecmwf-api-client>=1.6",
"netcdf4>=1.6.1,<1.7.0",
"netcdf4>=1.6.1",
"platformdirs>=3.0",
"requests>=2.25",
"lxml>=5.1.0",
Expand Down Expand Up @@ -128,7 +128,6 @@ sat = [
"geojson>=3.1",
"google-cloud-bigquery>=3.23",
"google-cloud-bigquery-storage>=2.25",
"h5netcdf>=1.2",
"pillow>=10.3",
"pyproj>=3.5",
"rasterio>=1.3",
Expand Down