diff --git a/dev-requirements.txt b/dev-requirements.txt index c42a01d3..e73b3a4e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -48,6 +48,7 @@ click==8.1.7 # cligj # fiona # pip-tools + # planetary-computer # rasterio click-plugins==1.1.1 # via @@ -109,7 +110,9 @@ iniconfig==2.0.0 joblib==1.4.2 # via swmmanywhere (pyproject.toml) jsonschema==4.22.0 - # via swmmanywhere (pyproject.toml) + # via + # pystac + # swmmanywhere (pyproject.toml) jsonschema-specifications==2023.12.1 # via jsonschema julian==0.14 @@ -120,8 +123,6 @@ llvmlite==0.42.0 # via numba loguru==0.7.2 # via swmmanywhere (pyproject.toml) -looseversion==1.3.0 - # via pysheds multiurl==0.3.1 # via cads-api-client mypy==1.10.0 @@ -177,6 +178,7 @@ packaging==24.1 # fastparquet # geopandas # lazy-loader + # planetary-computer # pyswmm # pytest # rioxarray @@ -196,6 +198,8 @@ pillow==10.3.0 # scikit-image pip-tools==7.4.1 # via swmmanywhere (pyproject.toml) +planetary-computer==1.0.0 + # via swmmanywhere (pyproject.toml) platformdirs==4.2.2 # via virtualenv pluggy==1.5.0 @@ -205,7 +209,9 @@ pre-commit==3.7.1 pyarrow==16.1.0 # via swmmanywhere (pyproject.toml) pydantic==2.7.3 - # via swmmanywhere (pyproject.toml) + # via + # planetary-computer + # swmmanywhere (pyproject.toml) pydantic-core==2.18.4 # via pydantic pyflwdir==0.5.8 @@ -223,6 +229,14 @@ pyproject-hooks==1.1.0 # pip-tools pysheds==0.3.5 # via swmmanywhere (pyproject.toml) +pystac[validation]==1.10.1 + # via + # planetary-computer + # pystac-client +pystac-client==0.8.2 + # via + # planetary-computer + # swmmanywhere (pyproject.toml) pyswmm==2.0.1 # via swmmanywhere (pyproject.toml) pytest==8.2.2 @@ -241,10 +255,15 @@ python-dateutil==2.9.0.post0 # via # multiurl # pandas + # pystac + # pystac-client +python-dotenv==1.0.1 + # via planetary-computer pytz==2024.1 # via # multiurl # pandas + # planetary-computer pyyaml==6.0.1 # via # pre-commit @@ -264,6 +283,8 @@ requests==2.32.3 # cdsapi # multiurl # osmnx + # planetary-computer + # pystac-client rioxarray==0.15.5 # via swmmanywhere (pyproject.toml) rpds-py==0.18.1 diff --git a/pyproject.toml b/pyproject.toml index 18684dda..3d93acec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,10 +39,12 @@ dependencies = [ "numpy", "osmnx", "pandas", + "planetary_computer", "pyarrow", "pydantic", "pyflwdir", "pysheds", + "pystac_client", "pyswmm", "PyYAML", "rasterio", diff --git a/requirements.txt b/requirements.txt index b8ac3813..8d5374bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,6 +39,7 @@ click==8.1.7 # click-plugins # cligj # fiona + # planetary-computer # rasterio click-plugins==1.1.1 # via @@ -86,7 +87,9 @@ imageio==2.33.1 joblib==1.3.2 # via swmmanywhere (pyproject.toml) jsonschema==4.21.1 - # via swmmanywhere (pyproject.toml) + # via + # pystac + # swmmanywhere (pyproject.toml) jsonschema-specifications==2023.12.1 # via jsonschema julian==0.14 @@ -140,6 +143,7 @@ packaging==23.2 # fastparquet # geopandas # lazy-loader + # planetary-computer # pyswmm # rioxarray # scikit-image @@ -156,10 +160,14 @@ pillow==10.2.0 # via # imageio # scikit-image +planetary-computer==1.0.0 + # via swmmanywhere (pyproject.toml) pyarrow==14.0.2 # via swmmanywhere (pyproject.toml) pydantic==2.5.3 - # via swmmanywhere (pyproject.toml) + # via + # planetary-computer + # swmmanywhere (pyproject.toml) pydantic-core==2.14.6 # via pydantic pyflwdir==0.5.8 @@ -173,12 +181,27 @@ pyproj==3.6.1 # rioxarray pysheds==0.3.5 # via swmmanywhere (pyproject.toml) +pystac[validation]==1.10.1 + # via + # planetary-computer + # pystac-client +pystac-client==0.8.2 + # via + # planetary-computer + # swmmanywhere (pyproject.toml) pyswmm==1.5.1 # via swmmanywhere (pyproject.toml) python-dateutil==2.9.0.post0 - # via pandas + # via + # pandas + # pystac + # pystac-client +python-dotenv==1.0.1 + # via planetary-computer pytz==2024.1 - # via pandas + # via + # pandas + # planetary-computer pyyaml==6.0.1 # via swmmanywhere (pyproject.toml) rasterio==1.3.9 @@ -194,6 +217,8 @@ requests==2.31.0 # via # cdsapi # osmnx + # planetary-computer + # pystac-client rioxarray==0.15.1 # via swmmanywhere (pyproject.toml) rpds-py==0.18.0 diff --git a/swmmanywhere/defs/schema.yml b/swmmanywhere/defs/schema.yml index 0acf51db..abc7b6a0 100644 --- a/swmmanywhere/defs/schema.yml +++ b/swmmanywhere/defs/schema.yml @@ -3,7 +3,6 @@ properties: base_dir: {type: string} project: {type: string} bbox: {type: array, items: {type: number}, minItems: 4, maxItems: 4} - api_keys: {type: string} model_number: {type: integer} run_settings: type: object @@ -31,4 +30,4 @@ properties: metric_list: {type: array, items: {type: string}} address_overrides: {type: ['object', 'null']} parameter_overrides: {type: ['object', 'null']} -required: [base_dir, project, bbox, api_keys, graphfcn_list] \ No newline at end of file +required: [base_dir, project, bbox, graphfcn_list] \ No newline at end of file diff --git a/swmmanywhere/prepare_data.py b/swmmanywhere/prepare_data.py index ced87610..f4b29f51 100644 --- a/swmmanywhere/prepare_data.py +++ b/swmmanywhere/prepare_data.py @@ -4,7 +4,6 @@ """ from __future__ import annotations -import shutil from pathlib import Path from typing import cast @@ -12,7 +11,11 @@ import networkx as nx import osmnx as ox import pandas as pd +import planetary_computer +import pystac_client import requests +import rioxarray +import rioxarray.merge as rxr_merge import xarray as xr from geopy.geocoders import Nominatim @@ -138,56 +141,41 @@ def download_river(bbox: tuple[float, float, float, float]) -> nx.MultiDiGraph: return cast("nx.MultiDiGraph", graph) def download_elevation(fid: Path, - bbox: tuple[float, float, float, float], - api_key: str ='') -> int: - """Download NASADEM elevation data from OpenTopography API. + bbox: tuple[float, float, float, float]) -> None: + """Download NASADEM elevation data from Microsoft Planetary computer. - Downloads elevation data in GeoTIFF format from OpenTopography API based on - the specified bounding box. + Downloads elevation data in GeoTIFF format from Microsoft Planetary computer + based on the specified bounding box. Args: fid (Path): File path to save the downloaded elevation data. - bbox (tuple): Bounding box coordinates in the format - (minx, miny, maxx, maxy). - api_key (str, optional): Your OpenTopography API key. - Defaults to ''. - - Returns: - status_code (int): Response status code - - Raises: - requests.exceptions.RequestException: If there is an error in the API - request. + bbox (tuple[float, float, float, float]): Bounding box as tuple in form + of (west, south, east, north) at EPSG:4326. Example: ``` bbox = (-120, 35, -118, 37) # Example bounding box coordinates download_elevation('elevation_data.tif', - bbox, - api_key='your_actual_api_key') + bbox) ``` - Note: - To obtain an API key, you need to sign up on the OpenTopography - website. - + Author: + cheginit """ - minx, miny, maxx, maxy = bbox - url = f'https://portal.opentopography.org/API/globaldem?demtype=NASADEM&south={miny}&north={maxy}&west={minx}&east={maxx}&outputFormat=GTiff&API_Key={api_key}' - - try: - r = requests.get(url, stream=True) - r.raise_for_status() - - with fid.open('wb') as rast_file: - shutil.copyfileobj(r.raw, rast_file) - - logger.info('Elevation data downloaded successfully.') - - except requests.exceptions.RequestException as e: - logger.error(f'Error downloading elevation data: {e}') - - return r.status_code + catalog = pystac_client.Client.open( + "https://planetarycomputer.microsoft.com/api/stac/v1", + modifier=planetary_computer.sign_inplace, + ) + search = catalog.search( + collections=["nasadem"], + bbox=bbox, + ) + signed_asset = (planetary_computer.sign(item.assets["elevation"]).href + for item in search.items()) + dem = rxr_merge.merge_arrays([rioxarray.open_rasterio(href).squeeze(drop=True) + for href in signed_asset]) + dem = dem.rio.clip_box(*bbox) + dem.rio.to_raster(fid) def download_precipitation(bbox: tuple[float, float, float, float], start_date: str = '2015-01-01', diff --git a/swmmanywhere/preprocessing.py b/swmmanywhere/preprocessing.py index f0dcdc7f..b69b2c83 100644 --- a/swmmanywhere/preprocessing.py +++ b/swmmanywhere/preprocessing.py @@ -182,7 +182,6 @@ def prepare_precipitation(bbox: tuple[float, float, float, float], def prepare_elevation(bbox: tuple[float, float, float, float], addresses: parameters.FilePaths, - api_keys: dict[str, str], target_crs: str): """Download and reproject elevation data.""" if addresses.elevation.exists(): @@ -192,7 +191,6 @@ def prepare_elevation(bbox: tuple[float, float, float, float], fid = Path(temp_dir) / 'elevation.tif' prepare_data.download_elevation(fid, bbox, - api_keys['nasadem_key'] ) go.reproject_raster(target_crs, fid, @@ -280,7 +278,6 @@ def prepare_river(bbox: tuple[float, float, float, float], def run_downloads(bbox: tuple[float, float, float, float], addresses: parameters.FilePaths, - api_keys: dict[str, str], network_types = ['drive']): """Run the data downloads. @@ -292,16 +289,16 @@ def run_downloads(bbox: tuple[float, float, float, float], bbox (tuple[float, float, float, float]): Bounding box coordinates in the format (minx, miny, maxx, maxy) in EPSG:4326. addresses (FilePaths): Class containing the addresses of the directories. - api_keys (dict): Dictionary containing the API keys. network_types (list): List of network types to download. """ target_crs = go.get_utm_epsg(bbox[0], bbox[1]) # Download precipitation data - prepare_precipitation(bbox, addresses, api_keys, target_crs) + # Currently commented because it doesn't work + # prepare_precipitation(bbox, addresses, api_keys, target_crs) # Download elevation data - prepare_elevation(bbox, addresses, api_keys, target_crs) + prepare_elevation(bbox, addresses, target_crs) # Download building data prepare_building(bbox, addresses, target_crs) diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index 32a3d7d0..26333cc6 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -65,10 +65,8 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: # Run downloads logger.info("Running downloads.") - api_keys = yaml_load(config['api_keys'].read_text()) preprocessing.run_downloads(config['bbox'], addresses, - api_keys, network_types = params['topology_derivation'].allowable_networks ) @@ -144,7 +142,7 @@ def swmmanywhere(config: dict) -> tuple[Path, dict | None]: return addresses.inp, metrics def check_top_level_paths(config: dict): - """Check the top level paths in the config. + """Check the top level paths (`base_dir`) in the config. Args: config (dict): The configuration. @@ -152,10 +150,10 @@ def check_top_level_paths(config: dict): Raises: FileNotFoundError: If a top level path does not exist. """ - for key in ['base_dir', 'api_keys']: - if not Path(config[key]).exists(): - raise FileNotFoundError(f"{key} not found at {config[key]}") - config[key] = Path(config[key]) + key = 'base_dir' + if not Path(config[key]).exists(): + raise FileNotFoundError(f"{key} not found at {config[key]}") + config[key] = Path(config[key]) return config def check_address_overrides(config: dict): diff --git a/tests/test_data/demo_config.yml b/tests/test_data/demo_config.yml index f1c2182d..43aebd55 100644 --- a/tests/test_data/demo_config.yml +++ b/tests/test_data/demo_config.yml @@ -1,7 +1,6 @@ base_dir: /path/to/base/directory project: demo bbox: [0.04020, 51.55759, 0.09826, 51.62050] -api_keys: /path/to/api/keys.yml run_settings: reporting_iters: 100 duration: 86400 diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 109d8310..634415b0 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -8,7 +8,6 @@ """ from __future__ import annotations -import io import tempfile from pathlib import Path from unittest import mock @@ -91,20 +90,14 @@ def test_river_downloader_download(): @pytest.mark.downloads def test_elevation_downloader_download(): """Check elevation downloads, writes, contains data, and a known elevation.""" - # Please do not reuse api_key - test_api_key = 'b206e65629ac0e53d599e43438560d28' - bbox = (-0.17929,51.49638, -0.17383,51.49846) with tempfile.TemporaryDirectory() as temp_dir: temp_fid = Path(temp_dir) / 'temp.tif' # Download - response = downloaders.download_elevation(temp_fid, bbox, test_api_key) + downloaders.download_elevation(temp_fid, bbox) - # Check response - assert response == 200 - # Check response assert temp_fid.exists(), "Elevation data file not found after download." @@ -202,30 +195,62 @@ def test_river_downloader(): assert G.size() == 0 -def test_elevation_downloader(): +def test_download_elevation(): """Check elevation downloads, writes, contains data, and a known elevation.""" - # Please do not reuse api_key - test_api_key = 'b206e65629ac0e53d599e43438560d28' - - bbox = (-0.17929,51.49638, -0.17383,51.49846) + bbox = (-0.17929, 51.49638, -0.17383, 51.49846) with tempfile.TemporaryDirectory() as temp_dir: temp_fid = Path(temp_dir) / 'temp.tif' - - mock_response = mock.Mock() - mock_response.status_code = 200 - mock_response.raw = io.BytesIO(b'25') - with mock.patch('requests.get', - return_value=mock_response) as mock_get: - # Call your function that uses requests.get - response = downloaders.download_elevation(temp_fid, - bbox, - test_api_key) - # Assert that requests.get was called with the right arguments - assert 'https://portal.opentopography.org/API/globaldem?demtype=NASADEM&south=51.49638&north=51.49846&west=-0.17929&east=-0.17383&outputFormat=GTiff&API_Key' in mock_get.call_args[0][0] # noqa: E501 - # Check response - assert response == 200 - - # Check response - assert temp_fid.exists(), "Elevation data file not found after download." \ No newline at end of file + # Mock the external dependencies + module_base = 'swmmanywhere.prepare_data.' + with mock.patch(f'{module_base}pystac_client.Client.open') as mock_open, \ + mock.patch(f'{module_base}planetary_computer.sign_inplace'), \ + mock.patch(f'{module_base}planetary_computer.sign') as mock_sign, \ + mock.patch(f'{module_base}rioxarray.open_rasterio') as mock_open_rasterio, \ + mock.patch(f'{module_base}rxr_merge.merge_arrays') as mock_merge_arrays: + + # Mock the behavior of the catalog search and items + mock_catalog = mock.MagicMock() + mock_open.return_value = mock_catalog + mock_search = mock.MagicMock() + mock_catalog.search.return_value = mock_search + mock_items = [mock.MagicMock(), mock.MagicMock()] + for item in mock_items: + item.assets = {"elevation": mock.MagicMock()} + mock_search.items.return_value = mock_items + + # Mock the signed URLs + mock_sign.side_effect = lambda x: mock.MagicMock(href=f"signed_{x}") + + # Mock the raster data + mock_raster = mock.MagicMock() + mock_open_rasterio.return_value = mock_raster + + # Mock the merged array + mock_merged_array = mock.MagicMock() + mock_merge_arrays.return_value = mock_merged_array + + # Mock the `rio` attribute on the merged array + mock_merged_array.rio = mock.MagicMock() + mock_merged_array.rio.clip_box.return_value = mock_merged_array + mock_merged_array.rio.to_raster.return_value = None + + # Call the function + downloaders.download_elevation(temp_fid, bbox) + + # Assertions + mock_open.assert_called_once_with( + "https://planetarycomputer.microsoft.com/api/stac/v1", + modifier=mock.ANY, + ) + mock_catalog.search.assert_called_once_with( + collections=["nasadem"], + bbox=bbox, + ) + assert len(mock_items) == 2 + assert mock_sign.call_count == len(mock_items) + mock_open_rasterio.assert_called() + mock_merge_arrays.assert_called_once() + mock_merged_array.rio.clip_box.assert_called_once_with(*bbox) + mock_merged_array.rio.to_raster.assert_called_once_with(temp_fid) diff --git a/tests/test_swmmanywhere.py b/tests/test_swmmanywhere.py index dec87ae1..fb8899e4 100644 --- a/tests/test_swmmanywhere.py +++ b/tests/test_swmmanywhere.py @@ -62,10 +62,6 @@ def test_swmmanywhere(): config['parameter_overrides'] = {'subcatchment_derivation' : {'subbasin_streamorder' : 5}} config['run_settings']['duration'] = 1000 - api_keys = {'nasadem_key' : 'b206e65629ac0e53d599e43438560d28'} - with open(base_dir / 'api_keys.yml', 'w') as f: - yaml.dump(api_keys, f) - config['api_keys'] = str(base_dir / 'api_keys.yml') # Fill the real dict with unused paths to avoid filevalidation errors config['real']['subcatchments'] = str(defs_dir / 'storm.dat') @@ -124,7 +120,6 @@ def test_load_config_file_validation(): # Fill with unused paths to avoid filevalidation errors config['base_dir'] = str(defs_dir / 'storm.dat') - config['api_keys'] = str(defs_dir / 'storm.dat') with open(base_dir / 'test_config.yml', 'w') as f: yaml.dump(config, f) @@ -168,7 +163,6 @@ def test_save_config(): # Fill with unused paths to avoid filevalidation errors config['base_dir'] = str(defs_dir / 'storm.dat') - config['api_keys'] = str(defs_dir / 'storm.dat') swmmanywhere.save_config(config, temp_dir / 'test.yml')