Enable file locks in MetDataSource.open_dataset #213

Merged: 11 commits, Jul 5, 2024
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,19 @@
# Changelog

## v0.52.1

### Breaking changes

- Remove `lock=False` as a default keyword argument to `xr.open_mfdataset` in the `MetDataSource.open_dataset` method. This reverts a change from [v0.44.1](https://github.com/contrailcirrus/pycontrails/releases/tag/v0.44.1) and prevents segmentation faults when using recent versions of [netCDF4](https://pypi.org/project/netCDF4/) (1.7.0 and above). Users who still need the old behavior can pass `lock=False` explicitly, as sketched after this list.
- GOES imagery is now loaded from a temporary file on disk rather than directly from memory when using `GOES.get` without a cachestore.
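
If downstream code relied on the previous behavior, the old default can be restored by passing `lock=False` explicitly to `xr.open_mfdataset`. A minimal sketch (the file paths are placeholders); note that with netCDF4 1.7.0 and above this risks the segmentation faults described above:

```python
import xarray as xr

# Sketch: opt back into the pre-v0.52.1 behavior by disabling xarray's
# file locking explicitly (not recommended with netCDF4 >= 1.7.0).
ds = xr.open_mfdataset(
    ["era5_t0.nc", "era5_t1.nc"],  # placeholder paths
    engine="netcdf4",
    lock=False,
)
```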

### Internals

- Remove upper limits on netCDF4 and numpy versions.
- Remove h5netcdf dependency.
- Update doctests with numpy 2 scalar representation (see [NEP 51](https://numpy.org/neps/nep-0051-scalar-representation.html)). Doctests will now fail when run with numpy 1; a short illustration of the new representation follows this list.
- Run certain tests in `test_ecmwf.py` and `test_met.py` using the single-threaded dask scheduler to prevent tests from hanging while waiting for a lock that is never released. (This issue was [encountered previously](https://github.com/contrailcirrus/pycontrails/pull/68), and removing `lock=False` in `MetDataSource.open_dataset` reverts the fix.)
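
A minimal doctest-style illustration of the NumPy 2 scalar representation the updated doctests now expect:

```python
>>> import numpy as np
>>> np.float64(0.5)  # NumPy >= 2 prints the scalar type (NEP 51)
np.float64(0.5)
>>> np.float64(0.5).item()  # plain Python float; identical under NumPy 1 and 2
0.5
```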

## v0.52.0

### Breaking changes
9 changes: 8 additions & 1 deletion Makefile
@@ -133,6 +133,13 @@ main-test-status:
DOCS_DIR = docs
DOCS_BUILD_DIR = docs/_build

# Check that numpy 2 is being used
ensure-numpy-2:
[[ $$(python -c 'import numpy as np; print(np.__version__)') == 2.* ]] \
|| ( \
echo -e "$(COLOR_YELLOW)NumPy 2 required for doctests$(END_COLOR)"; \
exit 1)

# Check for GCP credentials
ensure-gcp-credentials:
python -c 'from google.cloud import storage; storage.Client()' \
@@ -163,7 +170,7 @@ ensure-era5-cached:
cache-era5-gcp: ensure-era5-cached
gcloud storage cp -r -n .doc-test-cache/* gs://contrails-301217-unit-test/doc-test-cache/

doctest: ensure-era5-cached ensure-gcp-credentials
doctest: ensure-numpy-2 ensure-era5-cached ensure-gcp-credentials
pytest --doctest-modules \
--ignore-glob=pycontrails/ext/* \
pycontrails -vv
10 changes: 5 additions & 5 deletions pycontrails/core/coordinates.py
@@ -59,27 +59,27 @@ def slice_domain(
>>> # Call with request as np.array
>>> request = np.linspace(-20, 20, 100)
>>> slice_domain(domain, request)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with request as tuple
>>> request = -20, 20
>>> slice_domain(domain, request)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with a buffer
>>> request = -16, 13
>>> buffer = 4, 7
>>> slice_domain(domain, request, buffer)
slice(640, 801, None)
slice(np.int64(640), np.int64(801), None)

>>> # Call with request as a single number
>>> request = -20
>>> slice_domain(domain, request)
slice(640, 641, None)
slice(np.int64(640), np.int64(641), None)

>>> request = -19.9
>>> slice_domain(domain, request)
slice(640, 642, None)
slice(np.int64(640), np.int64(642), None)

"""
# if the length of domain coordinates is <= 2, return the whole domain
8 changes: 4 additions & 4 deletions pycontrails/core/flight.py
@@ -384,7 +384,7 @@ def max_distance_gap(self) -> float:
... time=pd.date_range('2021-01-01T12', '2021-01-01T14', periods=200),
... )
>>> fl.max_distance_gap
7391.27...
np.float64(7391.27...)
"""
if self.attrs["crs"] != "EPSG:4326":
raise NotImplementedError("Only implemented for EPSG:4326 CRS.")
@@ -415,7 +415,7 @@ def length(self) -> float:
... time=pd.date_range('2021-01-01T12', '2021-01-01T14', periods=200),
... )
>>> fl.length
1436924.67...
np.float64(1436924.67...)
"""
if self.attrs["crs"] != "EPSG:4326":
raise NotImplementedError("Only implemented for EPSG:4326 CRS.")
@@ -1520,11 +1520,11 @@ def length_met(self, key: str, threshold: float = 1.0) -> float:

>>> # Length (in meters) of waypoints whose temperature exceeds 236K
>>> fl.length_met("air_temperature", threshold=236)
4132178.159...
np.float64(4132178.159...)

>>> # Proportion (with respect to distance) of waypoints whose temperature exceeds 236K
>>> fl.proportion_met("air_temperature", threshold=236)
0.663552...
np.float64(0.663552...)
"""
if key not in self.data:
raise KeyError(f"Column {key} does not exist in data.")
4 changes: 2 additions & 2 deletions pycontrails/core/interpolation.py
@@ -621,11 +621,11 @@ class EmissionsProfileInterpolator:
>>> epi = EmissionsProfileInterpolator(xp, fp)
>>> # Interpolate a single value
>>> epi.interp(5)
0.150000...
np.float64(0.150000...)

>>> # Interpolate a single value on a logarithmic scale
>>> epi.log_interp(5)
1.105171...
np.float64(1.105171...)

>>> # Demonstrate speed up compared with xarray.DataArray interpolation
>>> import time, xarray as xr
2 changes: 1 addition & 1 deletion pycontrails/core/polygon.py
@@ -536,7 +536,7 @@ def multipolygon_to_geojson(
else:
shape = len(ring.coords), 3
coords = np.empty(shape)
coords[:, :2] = ring.coords
coords[:, :2] = np.asarray(ring.coords)
coords[:, 2] = altitude

poly_coords.append(coords.tolist())
2 changes: 1 addition & 1 deletion pycontrails/core/vector.py
@@ -2057,7 +2057,7 @@ def vector_to_lon_lat_grid(
[2.97, 0.12, 1.33, ..., 3.54, 0.74, 2.59]])

>>> da.sum().item() == vector["foo"].sum()
True
np.True_

"""
df = vector.select(("longitude", "latitude", *agg), copy=False).dataframe
4 changes: 0 additions & 4 deletions pycontrails/datalib/_met_utils/metsource.py
@@ -35,9 +35,6 @@
#: Whether to open multi-file datasets in parallel
OPEN_IN_PARALLEL: bool = False

#: Whether to use file locking when opening multi-file datasets
OPEN_WITH_LOCK: bool = False


def parse_timesteps(time: TimeInput | None, freq: str | None = "1h") -> list[datetime]:
"""Parse time input into set of time steps.
@@ -741,5 +738,4 @@ def open_dataset(
xr_kwargs.setdefault("engine", NETCDF_ENGINE)
xr_kwargs.setdefault("chunks", DEFAULT_CHUNKS)
xr_kwargs.setdefault("parallel", OPEN_IN_PARALLEL)
xr_kwargs.setdefault("lock", OPEN_WITH_LOCK)
return xr.open_mfdataset(disk_paths, **xr_kwargs)
13 changes: 10 additions & 3 deletions pycontrails/datalib/goes.py
@@ -16,7 +16,7 @@

import datetime
import enum
import io
import tempfile
from collections.abc import Iterable

import numpy as np
@@ -535,7 +535,7 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
da_dict = {}
for rpath, init_bytes in data.items():
channel = _extract_channel_from_rpath(rpath)
ds = xr.open_dataset(io.BytesIO(init_bytes), engine="h5netcdf")
ds = _load_via_tempfile(init_bytes)

da = ds["CMI"]
da = da.expand_dims(band_id=ds["band_id"].values)
@@ -551,7 +551,7 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
da = xr.concat(da_dict.values(), dim="band_id")

else:
ds = xr.open_dataset(io.BytesIO(data), engine="h5netcdf")
ds = _load_via_tempfile(data)
da = ds["CMI"]
da = da.expand_dims(band_id=ds["band_id"].values)

@@ -564,6 +564,13 @@ def _get_without_cache(self, time: datetime.datetime) -> xr.DataArray:
return da


def _load_via_tempfile(data: bytes) -> xr.Dataset:
"""Load xarray dataset via temporary file."""
with tempfile.NamedTemporaryFile(buffering=0) as tmp:
tmp.write(data)
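# buffering=0 writes the bytes straight to disk, and load_dataset reads
# eagerly, so the dataset stays usable after the temporary file is deleted
# when the context manager exits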
return xr.load_dataset(tmp.name)


def _concat_c02(ds1: XArrayType, ds2: XArrayType) -> XArrayType:
"""Concatenate two datasets with C01 and C02 data."""
# Average the C02 data to the C01 resolution
18 changes: 9 additions & 9 deletions pycontrails/models/cocip/cocip_uncertainty.py
@@ -92,12 +92,12 @@ class CocipUncertaintyParams(CocipParams):
>>> distr = scipy.stats.uniform(loc=0.4, scale=0.2)
>>> params = CocipUncertaintyParams(seed=123, initial_wake_vortex_depth_uncertainty=distr)
>>> params.initial_wake_vortex_depth
0.41076420
np.float64(0.41076420...)

>>> # Once seeded, calling the class again gives a new value
>>> params = CocipUncertaintyParams(initial_wake_vortex_depth=distr)
>>> params.initial_wake_vortex_depth
0.43526372
np.float64(0.43526372...)

>>> # To retain the default value, set the uncertainty to None
>>> params = CocipUncertaintyParams(rf_lw_enhancement_factor_uncertainty=None)
@@ -212,7 +212,7 @@ def uncertainty_params(self) -> dict[str, rv_frozen]:

return out

def rvs(self, size: None | int = None) -> dict[str, float | npt.NDArray[np.float64]]:
def rvs(self, size: None | int = None) -> dict[str, np.float64 | npt.NDArray[np.float64]]:
"""Call each distribution's `rvs` method to generate random parameters.

Seed calls to `rvs` with class variable `rng`.
@@ -247,12 +247,12 @@ def rvs(self, size: None | int = None) -> dict[str, float | npt.NDArray[np.float
[7.9063961e-04, 3.0336906e-03, 7.7571563e-04, 2.0577813e-02,
9.4205803e-01, 4.3379897e-03, 3.6786550e-03, 2.4747452e-02]],
dtype=float32),
'initial_wake_vortex_depth': 0.39805019708566847,
'nvpm_ei_n_enhancement_factor': 0.9371878437312526,
'rf_lw_enhancement_factor': 1.1017491252832377,
'rf_sw_enhancement_factor': 0.99721639115012,
'sedimentation_impact_factor': 0.5071779847244678,
'wind_shear_enhancement_exponent': 0.34100931239701004}
'initial_wake_vortex_depth': np.float64(0.39805019708566847),
'nvpm_ei_n_enhancement_factor': np.float64(0.9371878437312526),
'rf_lw_enhancement_factor': np.float64(1.1017491252832377),
'rf_sw_enhancement_factor': np.float64(0.99721639115012),
'sedimentation_impact_factor': np.float64(0.5071779847244678),
'wind_shear_enhancement_exponent': np.float64(0.34100931239701004)}
"""
return {
param: distr.rvs(size=size, random_state=self.rng)
3 changes: 2 additions & 1 deletion pycontrails/models/cocip/output_formats.py
@@ -32,7 +32,6 @@

from pycontrails.core.met import MetDataArray, MetDataset
from pycontrails.core.vector import GeoVectorDataset, vector_to_lon_lat_grid
from pycontrails.datalib.goes import GOES, extract_goes_visualization
from pycontrails.models.cocip.contrail_properties import contrail_edges, plume_mass_per_distance
from pycontrails.models.cocip.radiative_forcing import albedo
from pycontrails.models.humidity_scaling import HumidityScaling
@@ -2125,6 +2124,8 @@ def compare_cocip_with_goes(
File path of saved CoCiP-GOES image if ``path_write_img`` is provided.
"""

from pycontrails.datalib.goes import GOES, extract_goes_visualization

try:
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
2 changes: 1 addition & 1 deletion pycontrails/physics/geo.py
@@ -886,7 +886,7 @@ def spatial_bounding_box(
>>> lon = rng.uniform(-180, 180, size=30)
>>> lat = rng.uniform(-90, 90, size=30)
>>> spatial_bounding_box(lon, lat)
(-168.0, -77.0, 155.0, 82.0)
(np.float64(-168.0), np.float64(-77.0), np.float64(155.0), np.float64(82.0))
"""
lon_min = max(np.floor(np.min(longitude) - buffer), -180.0)
lon_max = min(np.ceil(np.max(longitude) + buffer), 179.99)
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -38,7 +38,7 @@ readme = { file = "README.md", content-type = "text/markdown" }

dependencies = [
"dask>=2022.3",
"numpy>=1.22,<2.0.0",
"numpy>=1.22",
"overrides>=6.1",
"pandas>=2.2",
"scipy>=1.10",
@@ -99,7 +99,7 @@ ecmwf = [
"cfgrib>=0.9",
"eccodes>=1.4",
"ecmwf-api-client>=1.6",
"netcdf4>=1.6.1,<1.7.0",
"netcdf4>=1.6.1",
"platformdirs>=3.0",
"requests>=2.25",
"lxml>=5.1.0",
@@ -128,7 +128,6 @@ sat = [
"geojson>=3.1",
"google-cloud-bigquery>=3.23",
"google-cloud-bigquery-storage>=2.25",
"h5netcdf>=1.2",
"pillow>=10.3",
"pyproj>=3.5",
"rasterio>=1.3",
13 changes: 13 additions & 0 deletions tests/unit/conftest.py
@@ -7,6 +7,7 @@
from datetime import datetime
from typing import Any

import dask
import numpy as np
import pandas as pd
import pytest
@@ -442,3 +443,15 @@ def polygon_bug() -> MetDataArray:
path = get_static_path("polygon-bug.nc")
da = xr.open_dataarray(path)
return MetDataArray(da)


@pytest.fixture()
def _dask_single_threaded():
"""Run test using single-threaded dask scheduler.

As of v0.52.1, using the default multi-threaded scheduler can cause
some tests to hang while waiting to acquire a lock that is never released.
This fixture can be used to run those tests using a single-threaded scheduler.
"""
with dask.config.set(scheduler="single-threaded"):
yield
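
As a usage sketch (the test name and body are hypothetical), a test opts into the fixture with the standard pytest marker, which is how the markers added to `test_ecmwf.py` and `test_met.py` below work:

```python
import pytest

@pytest.mark.usefixtures("_dask_single_threaded")
def test_open_met_without_hanging(met_ecmwf_pl_path: str) -> None:
    # Hypothetical test body: any dask work triggered here runs on the
    # single-threaded scheduler configured by the fixture, so it cannot
    # hang waiting on a lock held by another worker thread.
    ...
```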
10 changes: 8 additions & 2 deletions tests/unit/test_ecmwf.py
@@ -296,6 +296,7 @@ def test_ERA5_repr() -> None:
assert "Dataset:" in out


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_cachestore(met_ecmwf_pl_path: str, override_cache: DiskCacheStore) -> None:
"""Test ERA5 cachestore input."""

@@ -342,6 +343,7 @@ def test_ERA5_cachestore(met_ecmwf_pl_path: str, override_cache: DiskCacheStore)
assert pre_init_size == post_init_size


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_pressure_levels(met_ecmwf_pl_path: str, override_cache: DiskCacheStore) -> None:
"""Test ERA5 pressure_level parsing."""
times = [datetime(2019, 5, 31, 5, 0, 0)]
@@ -432,6 +434,7 @@ def test_ERA5_hash() -> None:
assert era5.hash != era52.hash


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_paths_with_time(
met_ecmwf_pl_path: str,
met_ecmwf_sl_path: str,
@@ -465,6 +468,7 @@ def test_ERA5_paths_with_time(
assert override_cache.size > 0


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_paths_without_time(
met_ecmwf_pl_path: str,
override_cache: DiskCacheStore,
@@ -498,6 +502,7 @@ def test_ERA5_paths_without_time(
assert pre_init_size == post_init_size


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_paths_with_error(
met_ecmwf_pl_path: str,
override_cache: DiskCacheStore,
@@ -614,6 +619,7 @@ def test_ERA5_set_met_source_metadata(product_type: str, variables: str) -> None
assert ds.attrs["product"] == product_type.split("_")[0]


@pytest.mark.usefixtures("_dask_single_threaded")
def test_ERA5_met_source_open_metdataset(met_ecmwf_pl_path: str) -> None:
"""Test the met_source attribute on the MetDataset arising from ERA5."""
era5 = ERA5(
@@ -807,7 +813,7 @@ def test_model_level_era5_ensemble_mars_request() -> None:
}


def test_model_level_era5_set_metadata(met_ecmwf_pl_path: str) -> None:
def test_model_level_era5_set_metadata() -> None:
"""Test metadata setting."""
era5 = ERA5ModelLevel(time=(datetime(2000, 1, 1), datetime(2000, 1, 2)), variables=["t", "q"])
ds = xr.Dataset()
@@ -1134,7 +1140,7 @@ def test_model_level_hres_mars_request() -> None:
)


def test_model_level_hres_set_metadata(met_ecmwf_pl_path: str) -> None:
def test_model_level_hres_set_metadata() -> None:
"""Test metadata setting."""
hres = HRESModelLevel(time=datetime(2000, 1, 1), variables=["t", "q"])
ds = xr.Dataset()
1 change: 1 addition & 0 deletions tests/unit/test_met.py
@@ -1013,6 +1013,7 @@ def test_met_copy(met_ecmwf_pl_path: str) -> None:
assert round(mds2["t"].data[5][5][2][0].item()) == 235


@pytest.mark.usefixtures("_dask_single_threaded")
def test_met_wrap_longitude_chunks(met_ecmwf_pl_path: str, override_cache: DiskCacheStore) -> None:
"""Check that the wrap_longitude method increments longitudinal chunks."""
# Using the xr.open_dataset method