From 6d9552850976bc253a54847ca11813e3cd78c712 Mon Sep 17 00:00:00 2001 From: Zeb Engberg Date: Fri, 22 Sep 2023 08:25:32 -0600 Subject: [PATCH] issue consistent ModuleNotFoundErrors for missing dependencies --- CHANGELOG.md | 5 ++ pycontrails/core/cache.py | 54 ++++++++++------ pycontrails/core/flight.py | 42 ++++++++----- pycontrails/core/met.py | 21 ++++--- pycontrails/core/polygon.py | 21 +++++-- pycontrails/core/vector.py | 10 ++- pycontrails/datalib/ecmwf/era5.py | 11 ++-- pycontrails/datalib/ecmwf/hres.py | 14 +++-- pycontrails/datalib/gfs/gfs.py | 31 +++++++--- pycontrails/ext/bada.py | 4 +- pycontrails/ext/cirium.py | 4 +- pycontrails/models/accf.py | 20 +++--- pycontrails/models/cocip/output_formats.py | 12 ++-- .../models/cocipgrid/cocip_time_handling.py | 15 ++--- pycontrails/utils/dependencies.py | 62 +++++++++++++++++++ 15 files changed, 231 insertions(+), 95 deletions(-) create mode 100644 pycontrails/utils/dependencies.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a988233ba..883761627 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - Use the experimental version number parameter `E` in `pycontrails.ecmwf.hres.get_forecast_filename`. Update the logic involved in setting the dissemination data stream indicator `S`. +### Internals + +- Provide consistent `ModuleNotFoundError` messages when optional dependencies are not installed. +- Move the `synthetic_flight` module into the `pycontrails.ext` namespace. + ## v0.47.1 ### Fixes diff --git a/pycontrails/core/cache.py b/pycontrails/core/cache.py index 7c9237ee5..44c354bda 100644 --- a/pycontrails/core/cache.py +++ b/pycontrails/core/cache.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools import logging import os import pathlib @@ -14,20 +15,23 @@ from overrides import overrides +from pycontrails.utils import dependencies + # optional imports if TYPE_CHECKING: import google +@functools.cache def _get_user_cache_dir() -> str: try: import platformdirs except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Using the pycontrails CacheStore requires the 'platformdirs' package. " - "This can be installed with 'pip install pycontrails[ecmwf]' or " - "'pip install platformdirs'." - ) from e + dependencies.raise_module_not_found_error( + name="cache module", + package_name="platformdirs", + module_not_found_error=e, + ) return platformdirs.user_cache_dir("pycontrails") @@ -468,10 +472,11 @@ def __init__( try: from google.cloud import storage except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "GCPCache requires the `google-cloud-storage` module, which can be installed " - "using `pip install pycontrails[gcp]`" - ) from e + dependencies.raise_module_not_found_error( + name="GCPCacheStore class", + package_name="google-cloud-storage", + module_not_found_error=e, + ) if "https://" in cache_dir: raise ValueError( @@ -830,11 +835,12 @@ def _upload_with_progress(blob: Any, disk_path: str, timeout: int, chunk_size: i try: from tqdm.auto import tqdm except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Method `put` requires the `tqdm` module, which can be installed using " - "`pip install pycontrails[gcp]`. " - "Alternatively, set instance attribute `show_progress=False`." - ) from e + dependencies.raise_module_not_found_error( + name="_upload_with_progress function", + package_name="tqdm", + module_not_found_error=e, + pycontrails_optional_package="gcp", + ) # minimal possible chunk_size to allow nice progress bar blob.chunk_size = chunk_size @@ -852,13 +858,23 @@ def _download_with_progress( try: from google.resumable_media.requests import ChunkedDownload + except ModuleNotFoundError as e: + dependencies.raise_module_not_found_error( + name="_download_with_progress function", + package_name="google-cloud-storage", + module_not_found_error=e, + pycontrails_optional_package="gcp", + ) + + try: from tqdm.auto import tqdm except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Method `get` requires the `tqdm` and `google-cloud-storage` modules, " - "which can be installed using `pip install pycontrails[gcp]`. " - "Alternatively, set instance attribute `show_progress=False`." - ) from e + dependencies.raise_module_not_found_error( + name="_download_with_progress function", + package_name="tqdm", + module_not_found_error=e, + pycontrails_optional_package="gcp", + ) blob = gcp_cache._bucket.get_blob(gcp_path) url = blob._get_download_url(gcp_cache._client) diff --git a/pycontrails/core/flight.py b/pycontrails/core/flight.py index a78389587..126344aaf 100644 --- a/pycontrails/core/flight.py +++ b/pycontrails/core/flight.py @@ -16,6 +16,7 @@ from pycontrails.core.fuel import Fuel, JetA from pycontrails.core.vector import AttrDict, GeoVectorDataset, VectorDataDict, VectorDataset from pycontrails.physics import constants, geo, units +from pycontrails.utils import dependencies logger = logging.getLogger(__name__) @@ -533,10 +534,12 @@ def segment_azimuth(self) -> npt.NDArray[np.float_]: try: import pyproj except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "The 'segment_azimuth' method requires the 'pyproj' package. " - "Install with 'pip install pyproj'." - ) from exc + dependencies.raise_module_not_found_error( + name="Flight.segment_azimuth method", + package_name="pyproj", + module_not_found_error=exc, + pycontrails_optional_package="pyproj", + ) geod = pyproj.Geod(a=constants.radius_earth) az, *_ = geod.inv(lons1, lats1, lons2, lats2) @@ -1001,10 +1004,12 @@ def _geodesic_interpolation(self, geodesic_threshold: float) -> pd.DataFrame | N try: import pyproj except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "The '_geodesic_interpolation' method requires the 'pyproj' package. " - "Install with 'pip install pyproj'." - ) from exc + dependencies.raise_module_not_found_error( + name="Flight._geodesic_interpolation method", + package_name="pyproj", + module_not_found_error=exc, + pycontrails_optional_package="pyproj", + ) geod = pyproj.Geod(ellps="WGS84") longitudes: list[float] = [] @@ -1138,11 +1143,11 @@ def to_traffic(self) -> "traffic.core.Flight": try: import traffic.core except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "This requires the 'traffic' module, which can be installed using " - "'pip install traffic'. See the installation documentation at " - "https://traffic-viz.github.io/installation.html for more information." - ) from e + dependencies.raise_module_not_found_error( + name="Flight.to_traffic method", + package_name="traffic", + module_not_found_error=e, + ) return traffic.core.Flight( self.to_dataframe(copy=True).rename(columns={"time": "timestamp"}) @@ -1767,11 +1772,14 @@ def fit_altitude( """ try: import pwlf - except ModuleNotFoundError: - raise ModuleNotFoundError( - "The 'fit_altitude' function requires the 'pwlf' package." - "This can be installed with 'pip install pwlf'." + except ModuleNotFoundError as e: + dependencies.raise_module_not_found_error( + name="fit_altitude function", + package_name="pwlf", + module_not_found_error=e, + pycontrails_optional_package="pwlf", ) + for i in range(1, max_segments): m2 = pwlf.PiecewiseLinFit(elapsed_time, altitude_ft) r = m2.fitfast(i, pop) diff --git a/pycontrails/core/met.py b/pycontrails/core/met.py index 7ee0d9ef8..180861548 100644 --- a/pycontrails/core/met.py +++ b/pycontrails/core/met.py @@ -36,6 +36,7 @@ from pycontrails.core.cache import CacheStore, DiskCacheStore from pycontrails.core.met_var import AirPressure, Altitude, MetVariable from pycontrails.physics import units +from pycontrails.utils import dependencies from pycontrails.utils import temp as temp_module logger = logging.getLogger(__name__) @@ -1942,18 +1943,22 @@ def to_polyhedra( try: from skimage import measure except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "This method requires the `skimage` module from scikit-learn, which can be " - "installed using `pip install pycontrails[vis]`" - ) from e + dependencies.raise_module_not_found_error( + name="MetDataArray.to_polyhedra method", + package_name="scikit-image", + pycontrails_optional_package="vis", + module_not_found_error=e, + ) try: import open3d as o3d except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "This method requires the `open3d` module, which can be installed " - "with `pip install pycontrails[open3d]` or `pip install open3d`." - ) from e + dependencies.raise_module_not_found_error( + name="MetDataArray.to_polyhedra method", + package_name="open3d", + pycontrails_optional_package="open3d", + module_not_found_error=e, + ) if len(self.data["level"]) == 1: raise ValueError( diff --git a/pycontrails/core/polygon.py b/pycontrails/core/polygon.py index c507fa7f8..d0da7c6ba 100644 --- a/pycontrails/core/polygon.py +++ b/pycontrails/core/polygon.py @@ -14,16 +14,29 @@ import numpy as np import numpy.typing as npt +from pycontrails.utils import dependencies + try: import cv2 +except ModuleNotFoundError as exc: + dependencies.raise_module_not_found_error( + name="polygon module", + package_name="opencv-python", + module_not_found_error=exc, + pycontrails_optional_package="vis", + ) + +try: import shapely import shapely.geometry import shapely.validation except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "This module requires the 'opencv-python' and 'shapely' packages. " - "These can be installed with 'pip install pycontrails[vis]'." - ) from exc + dependencies.raise_module_not_found_error( + name="polygon module", + package_name="shapely", + module_not_found_error=exc, + pycontrails_optional_package="vis", + ) def buffer_and_clean( diff --git a/pycontrails/core/vector.py b/pycontrails/core/vector.py index 5bf7aed80..72f3dec15 100644 --- a/pycontrails/core/vector.py +++ b/pycontrails/core/vector.py @@ -17,6 +17,7 @@ from pycontrails.core import coordinates, interpolation from pycontrails.core import met as met_module from pycontrails.physics import units +from pycontrails.utils import dependencies from pycontrails.utils import json as json_module logger = logging.getLogger(__name__) @@ -1386,9 +1387,12 @@ def transform_crs( try: import pyproj except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "Transforming CRS requires the 'pyproj' module. Install with 'pip install pyproj'." - ) from exc + dependencies.raise_module_not_found_error( + name="GeoVectorDataset.transform_crs method", + package_name="pyproj", + module_not_found_error=exc, + pycontrails_optional_package="pyproj", + ) transformer = pyproj.Transformer.from_crs(self.attrs["crs"], crs, always_xy=True) lon, lat = transformer.transform(self["longitude"], self["latitude"]) diff --git a/pycontrails/datalib/ecmwf/era5.py b/pycontrails/datalib/ecmwf/era5.py index 04efe41fa..3d03f8507 100644 --- a/pycontrails/datalib/ecmwf/era5.py +++ b/pycontrails/datalib/ecmwf/era5.py @@ -28,6 +28,7 @@ TopNetSolarRadiation, TopNetThermalRadiation, ) +from pycontrails.utils import dependencies from pycontrails.utils.temp import temp_file if TYPE_CHECKING: @@ -497,10 +498,12 @@ def _download_file(self, times: list[datetime]) -> None: try: import cdsapi except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Some `ecmwf` module dependencies are missing. " - "Please install all required dependencies using `pip install -e .[ecmwf]`" - ) from e + dependencies.raise_module_not_found_error( + name="ERA5._download_file method", + package_name="cdsapi", + module_not_found_error=e, + pycontrails_optional_package="ecmwf", + ) try: self.cds = cdsapi.Client(url=self.url, key=self.key) diff --git a/pycontrails/datalib/ecmwf/hres.py b/pycontrails/datalib/ecmwf/hres.py index 86eb64697..6245f770e 100644 --- a/pycontrails/datalib/ecmwf/hres.py +++ b/pycontrails/datalib/ecmwf/hres.py @@ -27,7 +27,7 @@ TopNetSolarRadiation, TopNetThermalRadiation, ) -from pycontrails.utils.iteration import chunk_list +from pycontrails.utils import dependencies, iteration from pycontrails.utils.temp import temp_file from pycontrails.utils.types import DatetimeLike @@ -271,10 +271,12 @@ def __init__( try: from ecmwfapi import ECMWFService except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Some `ecmwf` module dependencies are missing. " - "Please install all required dependencies using `pip install -e .[ecmwf]`" - ) from e + dependencies.raise_module_not_found_error( + name="HRES class", + package_name="ecmwf-api-client", + module_not_found_error=e, + pycontrails_optional_package="ecmwf", + ) # constants # ERA5 now delays creating the server attribute until it is needed to download @@ -596,7 +598,7 @@ def download_dataset(self, times: list[datetime]) -> None: # download in sets of 24 if len(steps) > 24: - for _steps in chunk_list(steps, 24): + for _steps in iteration.chunk_list(steps, 24): self._download_file(_steps) elif len(steps) > 0: self._download_file(steps) diff --git a/pycontrails/datalib/gfs/gfs.py b/pycontrails/datalib/gfs/gfs.py index 3c7fe8daf..b35caf531 100644 --- a/pycontrails/datalib/gfs/gfs.py +++ b/pycontrails/datalib/gfs/gfs.py @@ -30,6 +30,7 @@ TOAUpwardShortwaveRadiation, Visibility, ) +from pycontrails.utils import dependencies from pycontrails.utils.temp import temp_file from pycontrails.utils.types import DatetimeLike @@ -135,12 +136,23 @@ def __init__( ): try: import boto3 + except ModuleNotFoundError as e: + dependencies.raise_module_not_found_error( + name="GFSForecast class", + package_name="boto3", + module_not_found_error=e, + pycontrails_optional_package="gfs", + ) + + try: import botocore except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "`gfs` module dependencies are missing. " - "Please install all required dependencies using `pip install -e .[gfs]`" - ) from e + dependencies.raise_module_not_found_error( + name="GFSForecast class", + package_name="botocore", + module_not_found_error=e, + pycontrails_optional_package="gfs", + ) # inputs self.paths = paths @@ -650,11 +662,12 @@ def _download_with_progress( try: from tqdm import tqdm except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Download with progress requires the `tqdm` module, " - "which can be installed using `pip install pycontrails[gfs]`. " - "Alternatively, set instance attribute `show_progress=False`." - ) from e + dependencies.raise_module_not_found_error( + name="_download_with_progress function", + package_name="tqdm", + module_not_found_error=e, + pycontrails_optional_package="gfs", + ) meta = client.head_object(Bucket=bucket, Key=key) filesize = meta["ContentLength"] diff --git a/pycontrails/ext/bada.py b/pycontrails/ext/bada.py index 4248fcb45..6f0426da4 100644 --- a/pycontrails/ext/bada.py +++ b/pycontrails/ext/bada.py @@ -22,8 +22,8 @@ except ImportError as e: raise ImportError( - 'Failed to import `pycontrails-bada` extension. Install with `pip install "pycontrails-bada' - ' @ git+ssh://git@github.com/contrailcirrus/pycontrails-bada.git"`' + "Failed to import the 'pycontrails-bada' package. Install with 'pip install" + ' "pycontrails-bada @ git+ssh://git@github.com/contrailcirrus/pycontrails-bada.git"\'.' ) from e else: __all__ = [ diff --git a/pycontrails/ext/cirium.py b/pycontrails/ext/cirium.py index 3342684c4..448f5f4e7 100644 --- a/pycontrails/ext/cirium.py +++ b/pycontrails/ext/cirium.py @@ -7,8 +7,8 @@ except ImportError as e: raise ImportError( - "Failed to import `pycontrails-cirium` extension. Install with `pip install" - ' "pycontrails-cirium @ git+ssh://git@github.com/contrailcirrus/pycontrails-cirium.git"`' + "Failed to import the 'pycontrails-cirium' package. Install with 'pip install" + ' "pycontrails-cirium @ git+ssh://git@github.com/contrailcirrus/pycontrails-cirium.git"\'.' ) from e else: __all__ = ["Cirium"] diff --git a/pycontrails/models/accf.py b/pycontrails/models/accf.py index 392a1cc0e..041e3dffd 100644 --- a/pycontrails/models/accf.py +++ b/pycontrails/models/accf.py @@ -21,6 +21,7 @@ from pycontrails.core.models import Model, ModelParams from pycontrails.core.vector import GeoVectorDataset from pycontrails.datalib import ecmwf +from pycontrails.utils import dependencies WideBodyJets = { "A332", @@ -143,14 +144,6 @@ def __init__( params: dict[str, Any] = {}, **params_kwargs: Any, ) -> None: - try: - import climaccf # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Requires the climaccf package which can be installed" - "using pip install pycontrails[accf]" - ) from e - # Normalize ECMWF variables met = standardize_variables(met, self.met_variables) @@ -213,8 +206,15 @@ def eval( NotImplementedError Raises if input ``source`` is not supported. """ - - from climaccf.accf import GeTaCCFs + try: + from climaccf.accf import GeTaCCFs + except ModuleNotFoundError as e: + dependencies.raise_module_not_found_error( + name="ACCF.eval method", + package_name="climaccf", + module_not_found_error=e, + pycontrails_optional_package="accf", + ) self.update_params(params) self.set_source(source) diff --git a/pycontrails/models/cocip/output_formats.py b/pycontrails/models/cocip/output_formats.py index 119d088e9..039dd26f2 100644 --- a/pycontrails/models/cocip/output_formats.py +++ b/pycontrails/models/cocip/output_formats.py @@ -31,8 +31,8 @@ from pycontrails.models.cocip.radiative_forcing import albedo from pycontrails.models.humidity_scaling import HumidityScaling from pycontrails.models.tau_cirrus import tau_cirrus -from pycontrails.physics import geo, units -from pycontrails.physics.thermo import rho_d +from pycontrails.physics import geo, thermo, units +from pycontrails.utils import dependencies # ----------------------- # Flight waypoint outputs @@ -920,7 +920,7 @@ def time_slice_statistics( * flight_waypoints["segment_length"] ) contrails["pressure"] = units.m_to_pl(contrails["altitude"]) - contrails["rho_air"] = rho_d(contrails["air_temperature"], contrails["pressure"]) + contrails["rho_air"] = thermo.rho_d(contrails["air_temperature"], contrails["pressure"]) contrails["plume_mass_per_m"] = plume_mass_per_distance( contrails["area_eff"], contrails["rho_air"] ) @@ -1586,7 +1586,11 @@ def contrails_to_hi_res_grid( try: from tqdm.auto import tqdm except ModuleNotFoundError as exc: - raise ModuleNotFoundError("Install the 'tqdm' package") from exc + dependencies.raise_module_not_found_error( + name="contrails_to_hi_res_grid function", + package_name="tqdm", + module_not_found_error=exc, + ) for i in tqdm(heads_t.index[:2000]): contrail_segment = GeoVectorDataset( diff --git a/pycontrails/models/cocipgrid/cocip_time_handling.py b/pycontrails/models/cocipgrid/cocip_time_handling.py index b4b385156..ad3c14c99 100644 --- a/pycontrails/models/cocipgrid/cocip_time_handling.py +++ b/pycontrails/models/cocipgrid/cocip_time_handling.py @@ -15,6 +15,7 @@ from pycontrails.core.met import MetDataset from pycontrails.core.vector import GeoVectorDataset from pycontrails.models.cocip import cocip +from pycontrails.utils import dependencies if TYPE_CHECKING: import tqdm @@ -159,13 +160,13 @@ def init_pbar(self) -> "tqdm.tqdm" | None: try: from tqdm.auto import tqdm - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Running model {type(self).__name__} with parameter " - "show_progress=True requires the 'tqdm' module, which can be " - "installed with 'pip install tqdm'. " - "Alternatively, set model parameter 'show_progress=False'." - ) from e + except ModuleNotFoundError as exc: + dependencies.raise_module_not_found_error( + name="CocipGrid.init_pbar method", + package_name="tqdm", + module_not_found_error=exc, + extra="Alternatively, set model parameter 'show_progress=False'.", + ) estimate = self._estimate_runtime() diff --git a/pycontrails/utils/dependencies.py b/pycontrails/utils/dependencies.py new file mode 100644 index 000000000..8c918194a --- /dev/null +++ b/pycontrails/utils/dependencies.py @@ -0,0 +1,62 @@ +"""Raise ``ModuleNotFoundError`` when dependencies are not met.""" + +from typing import NoReturn + + +def raise_module_not_found_error( + name: str, + package_name: str, + module_not_found_error: ModuleNotFoundError, + pycontrails_optional_package: str | None = None, + extra: str | None = None, +) -> NoReturn: + """Raise ``ModuleNotFoundError`` with a helpful message. + + Parameters + ---------- + name : str + The name describing the context of the ``ModuleNotFoundError``. For example, + if the module is required for a specific function, the name could be + "my_function function". If the module is required for a specific method, + the name could be "MyClass.my_method method". If the module is required + for an entire pycontrails module, the name could be "my_module module". + package_name : str + The name of the package that is required. This should be the full name of + the python package, which may be different from the name of the module + that is actually imported. For example, if ``import sklearn`` triggers + the ``ModuleNotFoundError``, the ``package_name`` should be "scikit-learn". + module_not_found_error : ModuleNotFoundError + The ``ModuleNotFoundError`` that was raised. This is simply passed to the + ``from`` clause of the ``raise`` statement below. + pycontrails_optional_package : str, optional + The name of the optional pycontrails package that can be used to + install the required package. See the ``pyproject.toml`` file. + extra : str, optional + Any extra information that should be included in the error message. + This is appended to the end of the error message. + """ + # Put the function or method or module name in quotes if the full name + # contains a space. + try: + n1, n2 = name.split(" ") + except ValueError: + if "'" not in name: + name = f"'{name}'" + else: + if "'" not in n1: + n1 = f"'{n1}'" + name = f"{n1} {n2}" + + msg = ( + f"The {name} requires the '{package_name}' package. " + f"This can be installed with 'pip install {package_name}'" + ) + if pycontrails_optional_package: + msg = f"{msg} or 'pip install pycontrails[{pycontrails_optional_package}]'." + else: + msg = f"{msg}." + + if extra: + msg = f"{msg} {extra}" + + raise ModuleNotFoundError(msg) from module_not_found_error