diff --git a/pyproject.toml b/pyproject.toml index be65c610..051b95fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ module = "tests.*" disallow_untyped_defs = false [tool.pytest.ini_options] -# addopts = "-v --mypy -p no:warnings --cov=swmmanywhere --cov-report=html --doctest-modules --ignore=swmmanywhere/__main__.py" +addopts = "-v -p no:warnings --cov=swmmanywhere --cov-report=html --doctest-modules --ignore=swmmanywhere/__main__.py" [tool.ruff] select = ["D", "E", "F", "I"] # pydocstyle, pycodestyle, Pyflakes, isort diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index f5df9859..47af36d1 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -1,10 +1,7 @@ -# -*- coding: utf-8 -*- -"""Created 2024-01-20. +"""Geospatial utilities module for SWMManywhere. A module containing functions to perform a variety of geospatial operations, such as reprojecting coordinates and handling raster data. - -@author: Barnaby Dobson """ import itertools import json @@ -192,8 +189,9 @@ def get_transformer(source_crs: str, Example: >>> transformer = get_transformer('EPSG:4326', 'EPSG:32630') - >>> transformer.transform(-0.1276, 51.5074) - (699330.1106898375, 5710164.30300683) + >>> x, y = transformer.transform(-0.1276, 51.5074) + >>> print(f"{x:.6f}, {y:.6f}") + 699330.110690, 5710164.303007 """ return pyproj.Transformer.from_crs(source_crs, target_crs, @@ -419,7 +417,7 @@ def delineate_catchment(grid: pysheds.sgrid.sGrid, # Snap the node to the nearest grid cell x, y = data['x'], data['y'] grid_ = deepcopy(grid) - x_snap, y_snap = grid_.snap_to_mask(flow_acc > 5, (x, y)) + x_snap, y_snap = grid_.snap_to_mask(flow_acc >= 0, (x, y)) # Delineate the catchment catch = grid_.catchment(x=x_snap, diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index e3de2af2..a2262d71 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -1,7 +1,7 @@ -# -*- coding: utf-8 -*- -"""Created on 2024-01-26. +"""Graph utilities module for SWMManywhere. -@author: Barney +A module to contain graphfcns, the graphfcn registry object, and other graph +utilities (such as save/load functions). """ import json import os @@ -208,11 +208,15 @@ def __call__(self, G (nx.Graph): The same graph with an ID assigned to each edge """ edge_ids: set[str] = set() - for u, v, data in G.edges(data=True): + edges_to_remove = [] + for u, v, key, data in G.edges(data=True, keys = True): data['id'] = f'{u}-{v}' if data['id'] in edge_ids: logger.warning(f"Duplicate edge ID: {data['id']}") + edges_to_remove.append((u, v, key)) edge_ids.add(data['id']) + for u, v, key in edges_to_remove: + G.remove_edge(u, v, key) return G @register_graphfcn @@ -409,6 +413,39 @@ def create_new_edge_data(line, data, id_): return graph +@register_graphfcn +class fix_geometries(BaseGraphFunction, + required_edge_attributes = ['geometry'], + required_node_attributes = ['x', 'y']): + """fix_geometries class.""" + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: + """Fix the geometries of the edges. + + This function fixes the geometries of the edges. The geometries are + recalculated from the node coordinates. + + Args: + G (nx.Graph): A graph + **kwargs: Additional keyword arguments are ignored. 
+ + Returns: + G (nx.Graph): A graph + """ + G = G.copy() + for u, v, data in G.edges(data=True): + start_point_node = (G.nodes[u]['x'], G.nodes[u]['y']) + start_point_edge = data['geometry'].coords[0] + end_point_node = (G.nodes[v]['x'], G.nodes[v]['y']) + end_point_edge = data['geometry'].coords[-1] + if (start_point_edge == end_point_node) & \ + (end_point_edge == start_point_node): + data['geometry'] = data['geometry'].reverse() + elif (start_point_edge != start_point_node) | \ + (end_point_edge != end_point_node): + data['geometry'] = shapely.LineString([start_point_node, + end_point_node]) + return G + @register_graphfcn class calculate_contributing_area(BaseGraphFunction, required_edge_attributes = ['id', 'geometry', 'width'], diff --git a/swmmanywhere/logging.py b/swmmanywhere/logging.py index 4eff38e0..0ad1a655 100644 --- a/swmmanywhere/logging.py +++ b/swmmanywhere/logging.py @@ -1,7 +1,14 @@ -# -*- coding: utf-8 -*- -"""Created on 2024-03-04. - -@author: Barney +"""Logging module for SWMManywhere. + +Example: +>>> import os +>>> os.environ["SWMMANYWHERE_VERBOSE"] = "true" +>>> # logging is now enabled in any swmmanywhere module +>>> from swmmanywhere.logging import logger # You can now log yourself +>>> logger.info("This is an info message.") # Write to stdout # doctest: +SKIP +This is an info message. +>>> logger.add("file.log") # Add a log file # doctest: +SKIP +>>> os.environ["SWMMANYWHERE_VERBOSE"] = "false" # Disable logging """ import os import sys diff --git a/swmmanywhere/metric_utilities.py b/swmmanywhere/metric_utilities.py index 6a3d61b2..02ecde5d 100644 --- a/swmmanywhere/metric_utilities.py +++ b/swmmanywhere/metric_utilities.py @@ -1,10 +1,11 @@ -# -*- coding: utf-8 -*- -"""Created 2023-12-20. +"""Metric utilities module for SWMManywhere. -@author: Barnaby Dobson +A module for metrics, the metrics registry object and utilities for calculating +metrics (such as NSE or timeseries data alignment) used in SWMManywhere. """ from collections import defaultdict from inspect import signature +from itertools import product from typing import Callable, Optional import cytoolz.curried as tlz @@ -14,8 +15,11 @@ import networkx as nx import numpy as np import pandas as pd +import shapely from scipy import stats +from swmmanywhere.parameters import MetricEvaluation + class MetricRegistry(dict): """Registry object.""" @@ -30,7 +34,8 @@ def register(self, func: Callable) -> Callable: "synthetic_subs": gpd.GeoDataFrame, "real_subs": gpd.GeoDataFrame, "synthetic_G": nx.Graph, - "real_G": nx.Graph} + "real_G": nx.Graph, + "metric_evaluation": MetricEvaluation} sig = signature(func) for param, obj in sig.parameters.items(): @@ -61,7 +66,8 @@ def iterate_metrics(synthetic_results: pd.DataFrame, real_results: pd.DataFrame, real_subs: gpd.GeoDataFrame, real_G: nx.Graph, - metric_list: list[str]) -> dict[str, float]: + metric_list: list[str], + metric_evaluation: MetricEvaluation) -> dict[str, float]: """Iterate a list of metrics over a graph. Args: @@ -72,6 +78,7 @@ def iterate_metrics(synthetic_results: pd.DataFrame, real_subs (gpd.GeoDataFrame): The real subcatchments. real_G (nx.Graph): The real graph. metric_list (list[str]): A list of metrics to iterate. + metric_evaluation (MetricEvaluation): The metric evaluation parameters. Returns: dict[str, float]: The results of the metrics. 
@@ -87,6 +94,7 @@ def iterate_metrics(synthetic_results: pd.DataFrame, "real_results": real_results, "real_subs": real_subs, "real_G": real_G, + "metric_evaluation": metric_evaluation } return {m : metrics[m](**kwargs) for m in metric_list} @@ -313,18 +321,26 @@ def edge_betweenness_centrality(G: nx.Graph, bt_c[n] += v return bt_c -def align_by_subcatchment(var, +def align_by_shape(var, synthetic_results: pd.DataFrame, real_results: pd.DataFrame, - real_subs: gpd.GeoDataFrame, + shapes: gpd.GeoDataFrame, synthetic_G: nx.Graph, real_G: nx.Graph) -> pd.DataFrame: """Align by subcatchment. - Align synthetic and real results by subcatchment and return the results. + Align synthetic and real results by shape and return the results. + + Args: + var (str): The variable to align. + synthetic_results (pd.DataFrame): The synthetic results. + real_results (pd.DataFrame): The real results. + shapes (gpd.GeoDataFrame): The shapes to align by (e.g., grid or real_subs). + synthetic_G (nx.Graph): The synthetic graph. + real_G (nx.Graph): The real graph. """ - synthetic_joined = nodes_to_subs(synthetic_G, real_subs) - real_joined = nodes_to_subs(real_G, real_subs) + synthetic_joined = nodes_to_subs(synthetic_G, shapes) + real_joined = nodes_to_subs(real_G, shapes) # Extract data real_results = extract_var(real_results, var) @@ -347,6 +363,34 @@ def align_by_subcatchment(var, ) return results +def create_grid(bbox: tuple, + scale: float | tuple[float,float]) -> gpd.GeoDataFrame: + """Create a grid of polygons. + + Create a grid of polygons based on the bounding box and scale. + + Args: + bbox (tuple): The bounding box coordinates in the format (minx, miny, + maxx, maxy). + scale (float | tuple): The scale of the grid. If a tuple, the scale is + (dx, dy). Otherwise, the scale is dx = dy = scale. + + Returns: + gpd.GeoDataFrame: A geodataframe of the grid. + """ + minx, miny, maxx, maxy = bbox + + if isinstance(scale, tuple): + dx, dy = scale + else: + dx = dy = scale + xmins = np.arange(minx, maxx, dx) + ymins = np.arange(minx, maxy, dy) + grid = [{'geometry' : shapely.box(x, y, x + dx, y + dy), + 'sub_id' : i} for i, (x, y) in enumerate(product(xmins, ymins))] + + return gpd.GeoDataFrame(grid) + @metrics.register def nc_deltacon0(synthetic_G: nx.Graph, real_G: nx.Graph, @@ -517,6 +561,121 @@ def outlet_nse_flooding(synthetic_G: nx.Graph, list(sg_syn.nodes), list(sg_real.nodes)) +@metrics.register +def outlet_kstest_diameters(real_G: nx.Graph, + synthetic_G: nx.Graph, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + **kwargs) -> float: + """Outlet KStest diameters. + + Calculate the Kolmogorov-Smirnov statistic of the diameters in the subgraph + that drains to the dominant outlet node. The dominant outlet node of the + 'real' network is calculated by dominant_outlet, while the dominant outlet + node of the 'synthetic' network is calculated by best_outlet_match. + """ + # Identify synthetic and real outlet arcs + sg_syn, _ = best_outlet_match(synthetic_G, real_subs) + sg_real, _ = dominant_outlet(real_G, real_results) + + # Extract the diameters + syn_diameters = nx.get_edge_attributes(sg_syn, 'diameter') + real_diameters = nx.get_edge_attributes(sg_real, 'diameter') + return stats.ks_2samp(list(syn_diameters.values()), + list(real_diameters.values())).statistic + +@metrics.register +def outlet_pbias_length(real_G: nx.Graph, + synthetic_G: nx.Graph, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + **kwargs) -> float: + r"""Outlet PBIAS length. 
+ + Calculate the percent bias of the total edge length in the subgraph that + drains to the dominant outlet node. The dominant outlet node of the 'real' + network is calculated by dominant_outlet, while the dominant outlet node of + the 'synthetic' network is calculated by best_outlet_match. + + The percentage bias is calculated as: + + .. math:: + + pbias = \\frac{{syn\_length - real\_length}}{{real\_length}} + + where: + - :math:`syn\_length` is the synthetic length, + - :math:`real\_length` is the real length. + """ + # Identify synthetic and real outlet arcs + sg_syn, _ = best_outlet_match(synthetic_G, real_subs) + sg_real, _ = dominant_outlet(real_G, real_results) + + # Calculate the percent bias + syn_length = sum([d['length'] for u,v,d in sg_syn.edges(data=True)]) + real_length = sum([d['length'] for u,v,d in sg_real.edges(data=True)]) + return (syn_length - real_length) / real_length + +@metrics.register +def outlet_pbias_nmanholes(real_G: nx.Graph, + synthetic_G: nx.Graph, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + **kwargs) -> float: + r"""Outlet PBIAS number of manholes (nodes). + + Calculate the percent bias of the total number of nodes in the subgraph + that drains to the dominant outlet node. The dominant outlet node of the + 'real' network is calculated by dominant_outlet, while the dominant outlet + node of the 'synthetic' network is calculated by best_outlet_match. + + The percentage bias is calculated as: + + .. math:: + + pbias = \\frac{{syn\_nnodes - real\_nnodes}}{{real\_nnodes}} + + where: + - :math:`syn\_nnodes` is the number of synthetic nodes, + - :math:`real\_nnodes` is the real number of nodes. + """ + # Identify synthetic and real outlet arcs + sg_syn, _ = best_outlet_match(synthetic_G, real_subs) + sg_real, _ = dominant_outlet(real_G, real_results) + + return (sg_syn.number_of_nodes() - sg_real.number_of_nodes()) \ + / sg_real.number_of_nodes() + +@metrics.register +def outlet_pbias_npipes(real_G: nx.Graph, + synthetic_G: nx.Graph, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + **kwargs) -> float: + r"""Outlet PBIAS number of pipes (edges). + + Calculate the percent bias of the total number of edges in the subgraph + that drains to the dominant outlet node. The dominant outlet node of the + 'real' network is calculated by dominant_outlet, while the dominant outlet + node of the 'synthetic' network is calculated by best_outlet_match. + + + The percentage bias is calculated as: + + .. math:: + + pbias = \\frac{{syn\_nedges - real\_nedges}}{{real\_nedges}} + + where: + - :math:`syn\_nedges` is the number of synthetic edges, + - :math:`real\_nedges` is the real number of edges. + """ + # Identify synthetic and real outlet arcs + sg_syn, _ = best_outlet_match(synthetic_G, real_subs) + sg_real, _ = dominant_outlet(real_G, real_results) + + return (sg_syn.number_of_edges() - sg_real.number_of_edges()) \ + / sg_real.number_of_edges() @metrics.register @@ -532,11 +691,39 @@ def subcatchment_nse_flooding(synthetic_G: nx.Graph, flooding over time for each subcatchment. The metric produced is the median NSE across all subcatchments. 
""" - results = align_by_subcatchment('flooding', + results = align_by_shape('flooding', + synthetic_results = synthetic_results, + real_results = real_results, + shapes = real_subs, + synthetic_G = synthetic_G, + real_G = real_G) + + return median_nse_by_group(results, 'sub_id') + +@metrics.register +def grid_nse_flooding(synthetic_G: nx.Graph, + real_G: nx.Graph, + synthetic_results: pd.DataFrame, + real_results: pd.DataFrame, + real_subs: gpd.GeoDataFrame, + metric_evaluation: MetricEvaluation, + **kwargs) -> float: + """Grid NSE flooding. + + Classify synthetic nodes to a grid and calculate the NSE of + flooding over time for each grid cell. The metric produced is the median + NSE across all grid cells. + """ + scale = metric_evaluation.grid_scale + grid = create_grid(real_subs.total_bounds, + scale) + grid.crs = real_subs.crs + + results = align_by_shape('flooding', synthetic_results = synthetic_results, real_results = real_results, - real_subs = real_subs, + shapes = grid, synthetic_G = synthetic_G, real_G = real_G) - return median_nse_by_group(results, 'sub_id') \ No newline at end of file + return median_nse_by_group(results, 'sub_id') diff --git a/swmmanywhere/parameters.py b/swmmanywhere/parameters.py index b0b90f57..b2d70c8d 100644 --- a/swmmanywhere/parameters.py +++ b/swmmanywhere/parameters.py @@ -1,8 +1,4 @@ -# -*- coding: utf-8 -*- -"""Created on 2024-01-26. - -@author: Barney -""" +"""Parameters and file paths module for SWMManywhere.""" from pathlib import Path @@ -16,7 +12,8 @@ def get_full_parameters(): "subcatchment_derivation": SubcatchmentDerivation(), "outlet_derivation": OutletDerivation(), "topology_derivation": TopologyDerivation(), - "hydraulic_design": HydraulicDesign() + "hydraulic_design": HydraulicDesign(), + "metric_evaluation": MetricEvaluation() } class SubcatchmentDerivation(BaseModel): @@ -172,12 +169,18 @@ class HydraulicDesign(BaseModel): description = "Depth of design storm in pipe by pipe method", unit = "m") -class FilePaths: - """Parameters for file path lookup. +class MetricEvaluation(BaseModel): + """Parameters for metric evaluation.""" + grid_scale: float = Field(default = 100, + le = 10, + ge = 5000, + unit = "m", + description = "Scale of the grid for metric evaluation") + - TODO: this doesn't validate file paths to allow for un-initialised data - (e.g., subcatchments are created by a graph and so cannot be validated). - """ + +class FilePaths: + """Parameters for file path lookup.""" def __init__(self, base_dir: Path, diff --git a/swmmanywhere/post_processing.py b/swmmanywhere/post_processing.py index bf75b89a..3f1ab9c5 100644 --- a/swmmanywhere/post_processing.py +++ b/swmmanywhere/post_processing.py @@ -1,15 +1,12 @@ -# -*- coding: utf-8 -*- -"""Created 2024-01-22. +"""Post processing module for SWMManywhere. -A module containing functions to format and write processed data into SWMM .inp +A module containing functions to format and write processed data into SWMM .inp files. - -@author: Barnaby Dobson """ import re import shutil from pathlib import Path -from typing import Literal +from typing import Any, Literal import geopandas as gpd import numpy as np @@ -259,7 +256,7 @@ def data_dict_to_inp(data_dict: dict[str, np.ndarray], # Set the flow routing change_flow_routing(routing, new_input_file) -def explode_polygon(row): +def explode_polygon(row: pd.Series): """Explode a polygon into a DataFrame of coordinates. Args: @@ -272,12 +269,12 @@ def explode_polygon(row): ... 'geometry' : Polygon([(0,0), (1,0), ... 
(1,1), (0,1)])}) >>> explode_polygon(df) - x y subcatchment - 0 0 0 1 - 1 1 0 1 - 2 1 1 1 - 3 0 1 1 - 4 0 0 1 + x y subcatchment + 0 0.0 0.0 1 + 1 1.0 0.0 1 + 2 1.0 1.0 1 + 3 0.0 1.0 1 + 4 0.0 0.0 1 """ # Get the vertices of the polygon vertices = list(row['geometry'].exterior.coords) @@ -288,12 +285,12 @@ def explode_polygon(row): df['subcatchment'] = row['subcatchment'] return df -def format_to_swmm_dict(nodes, - outfalls, - conduits, - subs, - event, - symbol): +def format_to_swmm_dict(nodes: pd.DataFrame, + outfalls: pd.DataFrame, + conduits: pd.DataFrame, + subs: gpd.GeoDataFrame, + event: dict[str, Any], + symbol: dict[str, Any]) -> dict[str, np.ndarray]: """Format data to a dictionary of data arrays with columns matching SWMM. These data are the parameters of all assets that are written to the SWMM @@ -318,8 +315,9 @@ def format_to_swmm_dict(nodes, 'x', 'y', 'name'. Example: + >>> import os >>> import geopandas as gpd - >>> from shapely.geometry import Point + >>> from shapely.geometry import Point, Polygon >>> nodes = gpd.GeoDataFrame({'id' : ['node1', 'node2'], ... 'x' : [0, 1], ... 'y' : [0, 1], @@ -347,8 +345,10 @@ def format_to_swmm_dict(nodes, ... 'rc' : [1], ... 'width' : [1], ... 'slope' : [0.001], - ... 'geometry' : [sgeom.Polygon([(0,0), (1,0), - ... (1,1), (0,1)])]}) + ... 'geometry' : [Polygon([(0.0,0.0), + ... (1.0,0.0), + ... (1.0,1.0), + ... (0.0,1.0)])]}) >>> rain_fid = os.path.join(os.path.dirname(os.path.abspath(__file__)), ... '..', ... 'swmmanywhere', @@ -361,7 +361,7 @@ def format_to_swmm_dict(nodes, >>> symbol = {'x' : 0, ... 'y' : 0, ... 'name' : 'name'} - >>> data_dict = stt.format_to_swmm_dict(nodes, + >>> data_dict = format_to_swmm_dict(nodes, ... outfalls, ... conduits, ... subs, diff --git a/swmmanywhere/prepare_data.py b/swmmanywhere/prepare_data.py index ab2cef1d..79439fc1 100644 --- a/swmmanywhere/prepare_data.py +++ b/swmmanywhere/prepare_data.py @@ -1,7 +1,6 @@ -# -*- coding: utf-8 -*- -"""Created 2023-12-20. +"""Prepare data module for SWMManywhere. -@author: Barnaby Dobson +A module to download data needed for SWMManywhere. """ import shutil diff --git a/swmmanywhere/preprocessing.py b/swmmanywhere/preprocessing.py index 995653dd..95de336d 100644 --- a/swmmanywhere/preprocessing.py +++ b/swmmanywhere/preprocessing.py @@ -1,7 +1,8 @@ -# -*- coding: utf-8 -*- -"""Created on 2024-01-26. +"""Preprocessing module for SWMManywhere. -@author: Barney +A module to call downloads, preprocess these downloads into formats suitable +for graphfcns, and some other utilities (such as creating a project folder +structure or create the starting graph from rivers/streets). """ import json diff --git a/swmmanywhere/swmmanywhere.py b/swmmanywhere/swmmanywhere.py index c0e089bd..5d9decb1 100644 --- a/swmmanywhere/swmmanywhere.py +++ b/swmmanywhere/swmmanywhere.py @@ -1,8 +1,4 @@ -# -*- coding: utf-8 -*- -"""Created on 2024-01-26. 
- -@author: Barney -""" +"""The main SWMManywhere module to generate and run a synthetic network.""" import os from pathlib import Path @@ -106,7 +102,8 @@ def swmmanywhere(config: dict): real_results, gpd.read_file(config['real']['subcatchments']), load_graph(config['real']['graph']), - config['metric_list']) + config['metric_list'], + parameters['metric_evaluation']) return addresses.inp, metrics diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..c523d703 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +def pytest_collection_modifyitems(config, items): + """Skip tests marked with downloads.""" + if not config.getoption('markexpr', 'False'): + config.option.markexpr = "not downloads" \ No newline at end of file diff --git a/tests/test_data/demo_config.yml b/tests/test_data/demo_config.yml index 16b7e5f4..10a5b698 100644 --- a/tests/test_data/demo_config.yml +++ b/tests/test_data/demo_config.yml @@ -16,6 +16,7 @@ graphfcn_list: - assign_id - format_osmnx_lanes - double_directed + - fix_geometries - split_long_edges - calculate_contributing_area - set_elevation diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index 7601fdba..002a30f6 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -4,6 +4,7 @@ @author: Barney """ import math +import os import tempfile from pathlib import Path @@ -250,6 +251,7 @@ def test_iterate_graphfcns(): project_name = None, bbox_number = None, model_number = None) + os.environ['SWMMANYWHERE_VERBOSE'] = "false" G = iterate_graphfcns(G, ['assign_id', 'format_osmnx_lanes'], @@ -258,3 +260,19 @@ def test_iterate_graphfcns(): for u, v, d in G.edges(data=True): assert 'id' in d.keys() assert 'width' in d.keys() + +def test_fix_geometries(): + """Test the fix_geometries function.""" + # Create a graph with edge geometry not matching node coordinates + G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json') + + # Test doesn't work if this isn't true + assert G.get_edge_data(107733, 25472373,0)['geometry'].coords[0] != \ + (G.nodes[107733]['x'], G.nodes[107733]['y']) + + # Run the function + G_fixed = gu.fix_geometries(G) + + # Check that the edge geometry now matches the node coordinates + assert G_fixed.get_edge_data(107733, 25472373,0)['geometry'].coords[0] == \ + (G_fixed.nodes[107733]['x'], G_fixed.nodes[107733]['y']) \ No newline at end of file diff --git a/tests/test_logging.py b/tests/test_logging.py index f228ebfc..69695db9 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -29,6 +29,7 @@ def test_logger(): assert temp_file.read() != b"" logger.remove() fid.unlink() + os.environ["SWMMANYWHERE_VERBOSE"] = "false" def test_logger_disable(): """Test the disable function.""" @@ -67,4 +68,5 @@ def test_logger_again(): logger.test_logger() assert temp_file.read() != b"" logger.remove() - fid.unlink() \ No newline at end of file + fid.unlink() + os.environ["SWMMANYWHERE_VERBOSE"] = "false" \ No newline at end of file diff --git a/tests/test_metric_utilities.py b/tests/test_metric_utilities.py index eabd19f2..180fc7b3 100644 --- a/tests/test_metric_utilities.py +++ b/tests/test_metric_utilities.py @@ -8,6 +8,7 @@ from swmmanywhere import metric_utilities as mu from swmmanywhere.graph_utilities import load_graph +from swmmanywhere.parameters import MetricEvaluation def assert_close(a: float, b: float, rtol: float = 1e-3) -> None: @@ -279,6 +280,60 @@ def test_outlet_nse_flooding(): real_subs = subs) assert val == 0.0 +def 
test_design_params(): + """Test the design param related metrics.""" + G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json') + nx.set_edge_attributes(G, 0.15, 'diameter') + subs = get_subs() + + # Mock results (only needed for dominant outlet) + results = pd.DataFrame([{'id' : 4253560, + 'variable' : 'flow', + 'value' : 10, + 'date' : pd.to_datetime('2021-01-01 00:00:00')}, + {'id' : 4253560, + 'variable' : 'flow', + 'value' : 5, + 'date' : pd.to_datetime('2021-01-01 00:00:05')}, + ]) + + # Target results + design_results = {'outlet_kstest_diameters' : 0.0625, + 'outlet_pbias_length' : -0.15088965, + 'outlet_pbias_nmanholes' : -0.05, + 'outlet_pbias_npipes' : -0.15789473} + + # Iterate for G = G, i.e., perfect results + metrics = mu.iterate_metrics(synthetic_G = G, + synthetic_subs = None, + synthetic_results = None, + real_G = G, + real_subs = subs, + real_results = results, + metric_list = design_results.keys(), + metric_evaluation = MetricEvaluation()) + for metric, val in metrics.items(): + assert metric in design_results + assert np.isclose(val, 0) + + # edit the graph for target results + G_ = G.copy() + G_.remove_node(list(G.nodes)[0]) + G_.edges[list(G_.edges)[0]]['diameter'] = 0.3 + + metrics = mu.iterate_metrics(synthetic_G = G_, + synthetic_subs = None, + synthetic_results = None, + real_G = G, + real_subs = subs, + real_results = results, + metric_list = design_results.keys(), + metric_evaluation = MetricEvaluation()) + + for metric, val in metrics.items(): + assert metric in design_results + assert np.isclose(val, design_results[metric]), metric + def test_netcomp_iterate(): """Test the netcomp metrics and iterate_metrics.""" netcomp_results = {'nc_deltacon0' : 0.00129408, @@ -294,7 +349,8 @@ def test_netcomp_iterate(): real_G = G, real_subs = None, real_results = None, - metric_list = netcomp_results.keys()) + metric_list = netcomp_results.keys(), + metric_evaluation = MetricEvaluation()) for metric, val in metrics.items(): assert metric in netcomp_results assert np.isclose(val, 0) @@ -306,7 +362,8 @@ def test_netcomp_iterate(): real_G = G, real_subs = None, real_results = None, - metric_list = netcomp_results.keys()) + metric_list = netcomp_results.keys(), + metric_evaluation = MetricEvaluation()) for metric, val in metrics.items(): assert metric in netcomp_results assert np.isclose(val, netcomp_results[metric]) @@ -394,3 +451,18 @@ def test_subcatchment_nse_flooding(): real_results = results, real_subs = subs) assert val == 1.0 + + # Test gridded + val = mu.metrics.grid_nse_flooding(synthetic_G = G_, + synthetic_results = results_, + real_G = G, + real_results = results, + real_subs = subs, + metric_evaluation = MetricEvaluation()) + assert val == 1.0 + +def test_create_grid(): + """Test the create_grid function.""" + grid = mu.create_grid((0,0,1,1), 1/3 - 0.001) + assert grid.shape[0] == 16 + assert set(grid.columns) == {'sub_id','geometry'} \ No newline at end of file diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index 3e4ab1e3..293999be 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -1,21 +1,31 @@ # -*- coding: utf-8 -*- -"""Created on Tue Oct 18 10:35:51 2022. +"""Test the prepare_data module. 
-@author: Barney +By default downloads themselves are mocked, but these can be enabled with the +following test command: + +pytest -m downloads """ -# import pytest +import io import tempfile from pathlib import Path +from unittest import mock import geopandas as gpd +import networkx as nx +import osmnx as ox +import pytest import rasterio +import yaml +from geopy.geocoders import Nominatim from swmmanywhere import prepare_data as downloaders # Test get_country -def test_get_uk(): +@pytest.mark.downloads +def test_get_uk_download(): """Check a UK point.""" # Coordinates for London, UK x = -0.1276 @@ -26,7 +36,8 @@ def test_get_uk(): assert result[2] == 'GB' assert result[3] == 'GBR' -def test_get_us(): +@pytest.mark.downloads +def test_get_us_download(): """Check a US point.""" x = -113.43318 y = 33.81869 @@ -36,7 +47,8 @@ def test_get_us(): assert result[2] == 'US' assert result[3] == 'USA' -def test_building_downloader(): +@pytest.mark.downloads +def test_building_downloader_download(): """Check buildings are downloaded.""" # Coordinates for small country (VAT) x = 7.41839 @@ -57,7 +69,8 @@ def test_building_downloader(): # Make sure has some rows assert gdf.shape[0] > 0 -def test_street_downloader(): +@pytest.mark.downloads +def test_street_downloader_download(): """Check streets are downloaded and a specific point in the graph.""" bbox = (-0.17929,51.49638, -0.17383,51.49846) G = downloaders.download_street(bbox) @@ -65,7 +78,8 @@ def test_street_downloader(): # Not sure if they they are likely to change the osmid assert 26389449 in G.nodes -def test_river_downloader(): +@pytest.mark.downloads +def test_river_downloader_download(): """Check rivers are downloaded and a specific point in the graph.""" bbox = (0.0402, 51.55759, 0.09825591114207548, 51.6205) G = downloaders.download_river(bbox) @@ -73,7 +87,8 @@ def test_river_downloader(): # Not sure if they they are likely to change the osmid assert 21473922 in G.nodes -def test_elevation_downloader(): +@pytest.mark.downloads +def test_elevation_downloader_download(): """Check elevation downloads, writes, contains data, and a known elevation.""" # Please do not reuse api_key test_api_key = 'b206e65629ac0e53d599e43438560d28' @@ -101,4 +116,105 @@ def test_elevation_downloader(): # Test some property of data (not sure if they may change this # data) - assert data.max().max() > 25, "Elevation data should be higher." \ No newline at end of file + assert data.max().max() > 25, "Elevation data should be higher." 
+ +@pytest.fixture +def setup_mocks(): + """Set up get_country mock for the tests.""" + # Mock for geolocator.reverse + mock_location = mock.Mock() + mock_location.raw = {'address': {'country_code': 'gb'}} + + # Mock Nominatim + nominatim_patch = mock.patch.object(Nominatim, + 'reverse', + return_value=mock_location) + # Mock yaml.safe_load + yaml_patch = mock.patch.object(yaml, 'safe_load', return_value={'GB': 'GBR'}) + + with nominatim_patch, yaml_patch: + yield + +def test_get_uk(setup_mocks): + """Check a UK point.""" + # Coordinates for London, UK + x = -0.1276 + y = 51.5074 + + # Call get_country + result = downloaders.get_country(x, y) + + assert result[2] == 'GB' + assert result[3] == 'GBR' + +def test_building_downloader(setup_mocks): + """Check buildings are downloaded.""" + # Coordinates + x = -0.1276 + y = 51.5074 + + with tempfile.TemporaryDirectory() as temp_dir: + temp_fid = Path(temp_dir) / 'temp.parquet' + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.content = b'{"features": []}' + with mock.patch('requests.get', + return_value=mock_response) as mock_get: + # Call your function that uses requests.get + response = downloaders.download_buildings(temp_fid, x, y) + + # Assert that requests.get was called with the right arguments + mock_get.assert_called_once_with('https://data.source.coop/vida/google-microsoft-open-buildings/geoparquet/by_country/country_iso=GBR/GBR.parquet') + + # Check response + assert response == 200 + +def test_street_downloader(): + """Check streets are downloaded and a specific point in the graph.""" + bbox = (-0.17929,51.49638, -0.17383,51.49846) + + mock_graph = nx.MultiDiGraph() + # Mock ox.graph_from_bbox + with mock.patch.object(ox, 'graph_from_bbox', return_value=mock_graph): + # Call download_street + G = downloaders.download_street(bbox) + assert G == mock_graph + +def test_river_downloader(): + """Check rivers are downloaded and a specific point in the graph.""" + bbox = (0.0402, 51.55759, 0.09825591114207548, 51.6205) + + mock_graph = nx.MultiDiGraph() + # Mock ox.graph_from_bbox + with mock.patch.object(ox, 'graph_from_bbox', return_value=mock_graph): + # Call download_street + G = downloaders.download_river(bbox) + assert G == mock_graph + +def test_elevation_downloader(): + """Check elevation downloads, writes, contains data, and a known elevation.""" + # Please do not reuse api_key + test_api_key = 'b206e65629ac0e53d599e43438560d28' + + bbox = (-0.17929,51.49638, -0.17383,51.49846) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_fid = Path(temp_dir) / 'temp.tif' + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.raw = io.BytesIO(b'25') + with mock.patch('requests.get', + return_value=mock_response) as mock_get: + # Call your function that uses requests.get + response = downloaders.download_elevation(temp_fid, + bbox, + test_api_key) + # Assert that requests.get was called with the right arguments + assert 'https://portal.opentopography.org/API/globaldem?demtype=NASADEM&south=51.49638&north=51.49846&west=-0.17929&east=-0.17383&outputFormat=GTiff&API_Key' in mock_get.call_args[0][0] # noqa: E501 + + # Check response + assert response == 200 + + # Check response + assert temp_fid.exists(), "Elevation data file not found after download." \ No newline at end of file
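
Below is a minimal usage sketch (not part of the patch) of the new grid-based evaluation helper introduced in swmmanywhere/metric_utilities.py; it mirrors the assertions in tests/test_metric_utilities.py::test_create_grid and assumes the swmmanywhere package is importable.

    # Sketch: tile a unit bounding box with the new create_grid helper.
    # A scale of ~1/3 gives 4 cells per axis, i.e. a 4 x 4 grid of boxes,
    # each carrying a 'sub_id' used later by grid_nse_flooding.
    from swmmanywhere import metric_utilities as mu

    grid = mu.create_grid((0, 0, 1, 1), 1 / 3 - 0.001)
    assert grid.shape[0] == 16
    assert set(grid.columns) == {"sub_id", "geometry"}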