Skip to content

Commit

Permalink
issue consistent ModuleNotFoundErrors for missing dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
zebengberg committed Sep 22, 2023
1 parent 023a172 commit 6d95528
Show file tree
Hide file tree
Showing 15 changed files with 231 additions and 95 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

- Use the experimental version number parameter `E` in `pycontrails.ecmwf.hres.get_forecast_filename`. Update the logic involved in setting the dissemination data stream indicator `S`.

### Internals

- Provide consistent `ModuleNotFoundError` messages when optional dependencies are not installed.
- Move the `synthetic_flight` module into the `pycontrails.ext` namespace.

## v0.47.1

### Fixes
Expand Down
54 changes: 35 additions & 19 deletions pycontrails/core/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
import logging
import os
import pathlib
Expand All @@ -14,20 +15,23 @@

from overrides import overrides

from pycontrails.utils import dependencies

# optional imports
if TYPE_CHECKING:
import google


@functools.cache
def _get_user_cache_dir() -> str:
try:
import platformdirs
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Using the pycontrails CacheStore requires the 'platformdirs' package. "
"This can be installed with 'pip install pycontrails[ecmwf]' or "
"'pip install platformdirs'."
) from e
dependencies.raise_module_not_found_error(
name="cache module",
package_name="platformdirs",
module_not_found_error=e,
)
return platformdirs.user_cache_dir("pycontrails")


Expand Down Expand Up @@ -468,10 +472,11 @@ def __init__(
try:
from google.cloud import storage
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"GCPCache requires the `google-cloud-storage` module, which can be installed "
"using `pip install pycontrails[gcp]`"
) from e
dependencies.raise_module_not_found_error(
name="GCPCacheStore class",
package_name="google-cloud-storage",
module_not_found_error=e,
)

if "https://" in cache_dir:
raise ValueError(
Expand Down Expand Up @@ -830,11 +835,12 @@ def _upload_with_progress(blob: Any, disk_path: str, timeout: int, chunk_size: i
try:
from tqdm.auto import tqdm
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Method `put` requires the `tqdm` module, which can be installed using "
"`pip install pycontrails[gcp]`. "
"Alternatively, set instance attribute `show_progress=False`."
) from e
dependencies.raise_module_not_found_error(
name="_upload_with_progress function",
package_name="tqdm",
module_not_found_error=e,
pycontrails_optional_package="gcp",
)

# minimal possible chunk_size to allow nice progress bar
blob.chunk_size = chunk_size
Expand All @@ -852,13 +858,23 @@ def _download_with_progress(

try:
from google.resumable_media.requests import ChunkedDownload
except ModuleNotFoundError as e:
dependencies.raise_module_not_found_error(
name="_download_with_progress function",
package_name="google-cloud-storage",
module_not_found_error=e,
pycontrails_optional_package="gcp",
)

try:
from tqdm.auto import tqdm
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Method `get` requires the `tqdm` and `google-cloud-storage` modules, "
"which can be installed using `pip install pycontrails[gcp]`. "
"Alternatively, set instance attribute `show_progress=False`."
) from e
dependencies.raise_module_not_found_error(
name="_download_with_progress function",
package_name="tqdm",
module_not_found_error=e,
pycontrails_optional_package="gcp",
)

blob = gcp_cache._bucket.get_blob(gcp_path)
url = blob._get_download_url(gcp_cache._client)
Expand Down
42 changes: 25 additions & 17 deletions pycontrails/core/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pycontrails.core.fuel import Fuel, JetA
from pycontrails.core.vector import AttrDict, GeoVectorDataset, VectorDataDict, VectorDataset
from pycontrails.physics import constants, geo, units
from pycontrails.utils import dependencies

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -533,10 +534,12 @@ def segment_azimuth(self) -> npt.NDArray[np.float_]:
try:
import pyproj
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"The 'segment_azimuth' method requires the 'pyproj' package. "
"Install with 'pip install pyproj'."
) from exc
dependencies.raise_module_not_found_error(
name="Flight.segment_azimuth method",
package_name="pyproj",
module_not_found_error=exc,
pycontrails_optional_package="pyproj",
)

geod = pyproj.Geod(a=constants.radius_earth)
az, *_ = geod.inv(lons1, lats1, lons2, lats2)
Expand Down Expand Up @@ -1001,10 +1004,12 @@ def _geodesic_interpolation(self, geodesic_threshold: float) -> pd.DataFrame | N
try:
import pyproj
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"The '_geodesic_interpolation' method requires the 'pyproj' package. "
"Install with 'pip install pyproj'."
) from exc
dependencies.raise_module_not_found_error(
name="Flight._geodesic_interpolation method",
package_name="pyproj",
module_not_found_error=exc,
pycontrails_optional_package="pyproj",
)

geod = pyproj.Geod(ellps="WGS84")
longitudes: list[float] = []
Expand Down Expand Up @@ -1138,11 +1143,11 @@ def to_traffic(self) -> "traffic.core.Flight":
try:
import traffic.core
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"This requires the 'traffic' module, which can be installed using "
"'pip install traffic'. See the installation documentation at "
"https://traffic-viz.github.io/installation.html for more information."
) from e
dependencies.raise_module_not_found_error(
name="Flight.to_traffic method",
package_name="traffic",
module_not_found_error=e,
)

return traffic.core.Flight(
self.to_dataframe(copy=True).rename(columns={"time": "timestamp"})
Expand Down Expand Up @@ -1767,11 +1772,14 @@ def fit_altitude(
"""
try:
import pwlf
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The 'fit_altitude' function requires the 'pwlf' package."
"This can be installed with 'pip install pwlf'."
except ModuleNotFoundError as e:
dependencies.raise_module_not_found_error(
name="fit_altitude function",
package_name="pwlf",
module_not_found_error=e,
pycontrails_optional_package="pwlf",
)

for i in range(1, max_segments):
m2 = pwlf.PiecewiseLinFit(elapsed_time, altitude_ft)
r = m2.fitfast(i, pop)
Expand Down
21 changes: 13 additions & 8 deletions pycontrails/core/met.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pycontrails.core.cache import CacheStore, DiskCacheStore
from pycontrails.core.met_var import AirPressure, Altitude, MetVariable
from pycontrails.physics import units
from pycontrails.utils import dependencies
from pycontrails.utils import temp as temp_module

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1942,18 +1943,22 @@ def to_polyhedra(
try:
from skimage import measure
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"This method requires the `skimage` module from scikit-learn, which can be "
"installed using `pip install pycontrails[vis]`"
) from e
dependencies.raise_module_not_found_error(
name="MetDataArray.to_polyhedra method",
package_name="scikit-image",
pycontrails_optional_package="vis",
module_not_found_error=e,
)

try:
import open3d as o3d
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"This method requires the `open3d` module, which can be installed "
"with `pip install pycontrails[open3d]` or `pip install open3d`."
) from e
dependencies.raise_module_not_found_error(
name="MetDataArray.to_polyhedra method",
package_name="open3d",
pycontrails_optional_package="open3d",
module_not_found_error=e,
)

if len(self.data["level"]) == 1:
raise ValueError(
Expand Down
21 changes: 17 additions & 4 deletions pycontrails/core/polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@
import numpy as np
import numpy.typing as npt

from pycontrails.utils import dependencies

try:
import cv2
except ModuleNotFoundError as exc:
dependencies.raise_module_not_found_error(
name="polygon module",
package_name="opencv-python",
module_not_found_error=exc,
pycontrails_optional_package="vis",
)

try:
import shapely
import shapely.geometry
import shapely.validation
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"This module requires the 'opencv-python' and 'shapely' packages. "
"These can be installed with 'pip install pycontrails[vis]'."
) from exc
dependencies.raise_module_not_found_error(
name="polygon module",
package_name="shapely",
module_not_found_error=exc,
pycontrails_optional_package="vis",
)


def buffer_and_clean(
Expand Down
10 changes: 7 additions & 3 deletions pycontrails/core/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pycontrails.core import coordinates, interpolation
from pycontrails.core import met as met_module
from pycontrails.physics import units
from pycontrails.utils import dependencies
from pycontrails.utils import json as json_module

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1386,9 +1387,12 @@ def transform_crs(
try:
import pyproj
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Transforming CRS requires the 'pyproj' module. Install with 'pip install pyproj'."
) from exc
dependencies.raise_module_not_found_error(
name="GeoVectorDataset.transform_crs method",
package_name="pyproj",
module_not_found_error=exc,
pycontrails_optional_package="pyproj",
)

transformer = pyproj.Transformer.from_crs(self.attrs["crs"], crs, always_xy=True)
lon, lat = transformer.transform(self["longitude"], self["latitude"])
Expand Down
11 changes: 7 additions & 4 deletions pycontrails/datalib/ecmwf/era5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TopNetSolarRadiation,
TopNetThermalRadiation,
)
from pycontrails.utils import dependencies
from pycontrails.utils.temp import temp_file

if TYPE_CHECKING:
Expand Down Expand Up @@ -497,10 +498,12 @@ def _download_file(self, times: list[datetime]) -> None:
try:
import cdsapi
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Some `ecmwf` module dependencies are missing. "
"Please install all required dependencies using `pip install -e .[ecmwf]`"
) from e
dependencies.raise_module_not_found_error(
name="ERA5._download_file method",
package_name="cdsapi",
module_not_found_error=e,
pycontrails_optional_package="ecmwf",
)

try:
self.cds = cdsapi.Client(url=self.url, key=self.key)
Expand Down
14 changes: 8 additions & 6 deletions pycontrails/datalib/ecmwf/hres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TopNetSolarRadiation,
TopNetThermalRadiation,
)
from pycontrails.utils.iteration import chunk_list
from pycontrails.utils import dependencies, iteration
from pycontrails.utils.temp import temp_file
from pycontrails.utils.types import DatetimeLike

Expand Down Expand Up @@ -271,10 +271,12 @@ def __init__(
try:
from ecmwfapi import ECMWFService
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"Some `ecmwf` module dependencies are missing. "
"Please install all required dependencies using `pip install -e .[ecmwf]`"
) from e
dependencies.raise_module_not_found_error(
name="HRES class",
package_name="ecmwf-api-client",
module_not_found_error=e,
pycontrails_optional_package="ecmwf",
)

# constants
# ERA5 now delays creating the server attribute until it is needed to download
Expand Down Expand Up @@ -596,7 +598,7 @@ def download_dataset(self, times: list[datetime]) -> None:

# download in sets of 24
if len(steps) > 24:
for _steps in chunk_list(steps, 24):
for _steps in iteration.chunk_list(steps, 24):
self._download_file(_steps)
elif len(steps) > 0:
self._download_file(steps)
Expand Down
Loading

0 comments on commit 6d95528

Please sign in to comment.