Skip to content

Commit

Permalink
Merge pull request #110 from ImperialCollegeLondon/7-downloader-mock
Browse files Browse the repository at this point in the history
Mock downloads with option to turn on
  • Loading branch information
barneydobson authored Mar 25, 2024
2 parents f58d64c + 5613991 commit 3c9d0fa
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ module = "tests.*"
disallow_untyped_defs = false

[tool.pytest.ini_options]
# addopts = "-v --mypy -p no:warnings --cov=swmmanywhere --cov-report=html --doctest-modules --ignore=swmmanywhere/__main__.py"
addopts = "-v -p no:warnings --cov=swmmanywhere --cov-report=html --doctest-modules --ignore=swmmanywhere/__main__.py"

[tool.ruff]
select = ["D", "E", "F", "I"] # pydocstyle, pycodestyle, Pyflakes, isort
Expand Down
5 changes: 3 additions & 2 deletions swmmanywhere/geospatial_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def get_transformer(source_crs: str,
Example:
>>> transformer = get_transformer('EPSG:4326', 'EPSG:32630')
>>> transformer.transform(-0.1276, 51.5074)
(699330.1106898375, 5710164.30300683)
>>> x, y = transformer.transform(-0.1276, 51.5074)
>>> print(f"{x:.6f}, {y:.6f}")
699330.110690, 5710164.303007
"""
return pyproj.Transformer.from_crs(source_crs,
target_crs,
Expand Down
39 changes: 21 additions & 18 deletions swmmanywhere/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import shutil
from pathlib import Path
from typing import Literal
from typing import Any, Literal

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -256,7 +256,7 @@ def data_dict_to_inp(data_dict: dict[str, np.ndarray],
# Set the flow routing
change_flow_routing(routing, new_input_file)

def explode_polygon(row):
def explode_polygon(row: pd.Series):
"""Explode a polygon into a DataFrame of coordinates.
Args:
Expand All @@ -269,12 +269,12 @@ def explode_polygon(row):
... 'geometry' : Polygon([(0,0), (1,0),
... (1,1), (0,1)])})
>>> explode_polygon(df)
x y subcatchment
0 0 0 1
1 1 0 1
2 1 1 1
3 0 1 1
4 0 0 1
x y subcatchment
0 0.0 0.0 1
1 1.0 0.0 1
2 1.0 1.0 1
3 0.0 1.0 1
4 0.0 0.0 1
"""
# Get the vertices of the polygon
vertices = list(row['geometry'].exterior.coords)
Expand All @@ -285,12 +285,12 @@ def explode_polygon(row):
df['subcatchment'] = row['subcatchment']
return df

def format_to_swmm_dict(nodes,
outfalls,
conduits,
subs,
event,
symbol):
def format_to_swmm_dict(nodes: pd.DataFrame,
outfalls: pd.DataFrame,
conduits: pd.DataFrame,
subs: gpd.GeoDataFrame,
event: dict[str, Any],
symbol: dict[str, Any]) -> dict[str, np.ndarray]:
"""Format data to a dictionary of data arrays with columns matching SWMM.
These data are the parameters of all assets that are written to the SWMM
Expand All @@ -315,8 +315,9 @@ def format_to_swmm_dict(nodes,
'x', 'y', 'name'.
Example:
>>> import os
>>> import geopandas as gpd
>>> from shapely.geometry import Point
>>> from shapely.geometry import Point, Polygon
>>> nodes = gpd.GeoDataFrame({'id' : ['node1', 'node2'],
... 'x' : [0, 1],
... 'y' : [0, 1],
Expand Down Expand Up @@ -344,8 +345,10 @@ def format_to_swmm_dict(nodes,
... 'rc' : [1],
... 'width' : [1],
... 'slope' : [0.001],
... 'geometry' : [sgeom.Polygon([(0,0), (1,0),
... (1,1), (0,1)])]})
... 'geometry' : [Polygon([(0.0,0.0),
... (1.0,0.0),
... (1.0,1.0),
... (0.0,1.0)])]})
>>> rain_fid = os.path.join(os.path.dirname(os.path.abspath(__file__)),
... '..',
... 'swmmanywhere',
Expand All @@ -358,7 +361,7 @@ def format_to_swmm_dict(nodes,
>>> symbol = {'x' : 0,
... 'y' : 0,
... 'name' : 'name'}
>>> data_dict = stt.format_to_swmm_dict(nodes,
>>> data_dict = format_to_swmm_dict(nodes,
... outfalls,
... conduits,
... subs,
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def pytest_collection_modifyitems(config, items):
"""Skip tests marked with downloads."""
if not config.getoption('markexpr', 'False'):
config.option.markexpr = "not downloads"
126 changes: 116 additions & 10 deletions tests/test_prepare_data.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
# -*- coding: utf-8 -*-
"""Created on Tue Oct 18 10:35:51 2022.
"""Test the prepare_data module.
@author: Barney
By default downloads themselves are mocked, but these can be enabled with the
following test command:
pytest -m downloads
"""

# import pytest
import io
import tempfile
from pathlib import Path
from unittest import mock

import geopandas as gpd
import networkx as nx
import osmnx as ox
import pytest
import rasterio
import yaml
from geopy.geocoders import Nominatim

from swmmanywhere import prepare_data as downloaders


# Test get_country
def test_get_uk():
@pytest.mark.downloads
def test_get_uk_download():
"""Check a UK point."""
# Coordinates for London, UK
x = -0.1276
Expand All @@ -26,7 +36,8 @@ def test_get_uk():
assert result[2] == 'GB'
assert result[3] == 'GBR'

def test_get_us():
@pytest.mark.downloads
def test_get_us_download():
"""Check a US point."""
x = -113.43318
y = 33.81869
Expand All @@ -36,7 +47,8 @@ def test_get_us():
assert result[2] == 'US'
assert result[3] == 'USA'

def test_building_downloader():
@pytest.mark.downloads
def test_building_downloader_download():
"""Check buildings are downloaded."""
# Coordinates for small country (VAT)
x = 7.41839
Expand All @@ -57,23 +69,26 @@ def test_building_downloader():
# Make sure has some rows
assert gdf.shape[0] > 0

def test_street_downloader():
@pytest.mark.downloads
def test_street_downloader_download():
"""Check streets are downloaded and a specific point in the graph."""
bbox = (-0.17929,51.49638, -0.17383,51.49846)
G = downloaders.download_street(bbox)

# Not sure if they they are likely to change the osmid
assert 26389449 in G.nodes

def test_river_downloader():
@pytest.mark.downloads
def test_river_downloader_download():
"""Check rivers are downloaded and a specific point in the graph."""
bbox = (0.0402, 51.55759, 0.09825591114207548, 51.6205)
G = downloaders.download_river(bbox)

# Not sure if they they are likely to change the osmid
assert 21473922 in G.nodes

def test_elevation_downloader():
@pytest.mark.downloads
def test_elevation_downloader_download():
"""Check elevation downloads, writes, contains data, and a known elevation."""
# Please do not reuse api_key
test_api_key = 'b206e65629ac0e53d599e43438560d28'
Expand Down Expand Up @@ -101,4 +116,95 @@ def test_elevation_downloader():

# Test some property of data (not sure if they may change this
# data)
assert data.max().max() > 25, "Elevation data should be higher."
assert data.max().max() > 25, "Elevation data should be higher."

def test_get_uk():
"""Check a UK point."""
# Coordinates for London, UK
x = -0.1276
y = 51.5074

# Create a mock response for geolocator.reverse
mock_location = mock.Mock()
mock_location.raw = {'address': {'country_code': 'gb'}}

# Mock Nominatim
with mock.patch.object(Nominatim, 'reverse', return_value=mock_location):
# Mock yaml.safe_load
with mock.patch.object(yaml, 'safe_load', return_value={'GB': 'GBR'}):
# Call get_country
result = downloaders.get_country(x, y)

assert result[2] == 'GB'
assert result[3] == 'GBR'

def test_building_downloader():
"""Check buildings are downloaded."""
# Coordinates for small country (VAT)
x = 7.41839
y = 43.73205
with tempfile.TemporaryDirectory() as temp_dir:
temp_fid = Path(temp_dir) / 'temp.parquet'
mock_response = mock.Mock()
mock_response.status_code = 200
mock_response.content = b'{"features": []}'
with mock.patch('requests.get',
return_value=mock_response) as mock_get:
# Call your function that uses requests.get
response = downloaders.download_buildings(temp_fid, x, y)

# Assert that requests.get was called with the right arguments
mock_get.assert_called_once_with('https://data.source.coop/vida/google-microsoft-open-buildings/geoparquet/by_country/country_iso=MCO/MCO.parquet')

# Check response
assert response == 200

def test_street_downloader():
"""Check streets are downloaded and a specific point in the graph."""
bbox = (-0.17929,51.49638, -0.17383,51.49846)

mock_graph = nx.MultiDiGraph()
# Mock ox.graph_from_bbox
with mock.patch.object(ox, 'graph_from_bbox', return_value=mock_graph):
# Call download_street
G = downloaders.download_street(bbox)
assert G == mock_graph

def test_river_downloader():
"""Check rivers are downloaded and a specific point in the graph."""
bbox = (0.0402, 51.55759, 0.09825591114207548, 51.6205)

mock_graph = nx.MultiDiGraph()
# Mock ox.graph_from_bbox
with mock.patch.object(ox, 'graph_from_bbox', return_value=mock_graph):
# Call download_street
G = downloaders.download_river(bbox)
assert G == mock_graph

def test_elevation_downloader():
"""Check elevation downloads, writes, contains data, and a known elevation."""
# Please do not reuse api_key
test_api_key = 'b206e65629ac0e53d599e43438560d28'

bbox = (-0.17929,51.49638, -0.17383,51.49846)

with tempfile.TemporaryDirectory() as temp_dir:
temp_fid = Path(temp_dir) / 'temp.tif'

mock_response = mock.Mock()
mock_response.status_code = 200
mock_response.raw = io.BytesIO(b'25')
with mock.patch('requests.get',
return_value=mock_response) as mock_get:
# Call your function that uses requests.get
response = downloaders.download_elevation(temp_fid,
bbox,
test_api_key)
# Assert that requests.get was called with the right arguments
assert 'https://portal.opentopography.org/API/globaldem?demtype=NASADEM&south=51.49638&north=51.49846&west=-0.17929&east=-0.17383&outputFormat=GTiff&API_Key' in mock_get.call_args[0][0] # noqa: E501

# Check response
assert response == 200

# Check response
assert temp_fid.exists(), "Elevation data file not found after download."

0 comments on commit 3c9d0fa

Please sign in to comment.