Skip to content

Commit

Permalink
Merge branch 'graphfcns' into preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
barneydobson committed Feb 6, 2024
2 parents f261a49 + 9207466 commit c8ce4c3
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 38 deletions.
6 changes: 3 additions & 3 deletions swmmanywhere/geospatial_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
def get_utm_epsg(x: float,
y: float,
crs: str | int | pyproj.CRS = 'EPSG:4326',
datum_name: str = "WGS 84"):
datum_name: str = "WGS 84") -> str:
"""Get the UTM CRS code for a given coordinate.
Note, this function is taken from GeoPandas and modified to use
Expand Down Expand Up @@ -154,7 +154,7 @@ def interpolate_points_on_raster(x: list[float],

def reproject_raster(target_crs: str,
fid: Path,
new_fid: Optional[Path] = None):
new_fid: Optional[Path] = None) -> None:
"""Reproject a raster to a new CRS.
Args:
Expand Down Expand Up @@ -302,7 +302,7 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point],
def burn_shape_in_raster(geoms: list[sgeom.LineString],
depth: float,
raster_fid: Path,
new_raster_fid: Path):
new_raster_fid: Path) -> None:
"""Burn a depth into a raster along a list of shapely geometries.
Args:
Expand Down
40 changes: 21 additions & 19 deletions swmmanywhere/graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import osmnx as ox
import pandas as pd
from shapely import geometry as sgeom
from shapely import wkt
from tqdm import tqdm

from swmmanywhere import geospatial_utilities as go
Expand All @@ -40,12 +41,12 @@ def load_graph(fid: Path) -> nx.Graph:
for u, v, data in G.edges(data=True):
if 'geometry' in data:
geometry_coords = data['geometry']
line_string = sgeom.LineString(geometry_coords)
line_string = sgeom.LineString(wkt.loads(geometry_coords))
data['geometry'] = line_string
return G

def save_graph(G: nx.Graph,
fid: Path):
fid: Path) -> None:
"""Save a graph to a file.
Args:
Expand All @@ -55,7 +56,7 @@ def save_graph(G: nx.Graph,
json_data = nx.node_link_data(G)
def serialize_line_string(obj):
if isinstance(obj, sgeom.LineString):
return list(obj.coords)
return obj.wkt
else:
return obj
with open(fid, 'w') as file:
Expand Down Expand Up @@ -93,7 +94,7 @@ def __call__(self,

def validate_requirements(self,
edge_attributes: set,
node_attributes: set):
node_attributes: set) -> None:
"""Validate that the graph has the required attributes."""
for attribute in self.required_edge_attributes:
assert attribute in edge_attributes, "{0} not in attributes".format(
Expand All @@ -106,7 +107,7 @@ def validate_requirements(self,

def add_graphfcn(self,
edge_attributes: set,
node_attributes: set):
node_attributes: set) -> tuple[set, set]:
"""Add the attributes that the graph function adds."""
self.validate_requirements(edge_attributes, node_attributes)
edge_attributes = edge_attributes.union(self.adds_edge_attributes)
Expand All @@ -131,7 +132,7 @@ def register_graphfcn(cls) -> Callable:
setattr(graphfcns, cls.__name__, cls())
return cls

def get_osmid_id(data):
def get_osmid_id(data: dict) -> Hashable:
"""Get the ID of an edge."""
id_ = data.get('osmid', data.get('id'))
if isinstance(id_, list):
Expand All @@ -149,7 +150,7 @@ def __init__(self):

def __call__(self,
G: nx.Graph,
**kwargs):
**kwargs) -> nx.Graph:
"""Assign an ID to each edge.
This function takes a graph and assigns an ID to each edge. The ID is
Expand Down Expand Up @@ -183,7 +184,7 @@ def __init__(self):
def __call__(self,
G: nx.Graph,
subcatchment_derivation: parameters.SubcatchmentDerivation,
**kwargs):
**kwargs) -> nx.Graph:
"""Format the lanes attribute of each edge and calculates width.
Args:
Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(self):
"""Initialize the class."""
super().__init__()
self.required_edge_attributes = ['id']
def __call__(self, G: nx.Graph, **kwargs):
def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph:
"""Create a 'double directed graph'.
This function duplicates a graph and adds reverse edges to the new graph.
Expand Down Expand Up @@ -251,7 +252,7 @@ def __init__(self):
def __call__(self,
G: nx.Graph,
subcatchment_derivation: parameters.SubcatchmentDerivation,
**kwargs):
**kwargs) -> nx.Graph:
"""Split long edges into shorter edges.
This function splits long edges into shorter edges. The edges are split
Expand Down Expand Up @@ -384,8 +385,8 @@ def __init__(self):

def __call__(self, G: nx.Graph,
subcatchment_derivation: parameters.SubcatchmentDerivation,
addresses: parameters.Addresses,
**kwargs):
addresses: parameters.FilePaths,
**kwargs) -> nx.Graph:
"""Calculate the contributing area for each edge.
This function calculates the contributing area for each edge. The
Expand All @@ -398,7 +399,7 @@ def __call__(self, G: nx.Graph,
G (nx.Graph): A graph
subcatchment_derivation (parameters.SubcatchmentDerivation): A
SubcatchmentDerivation parameter object
addresses (parameters.Addresses): An Addresses parameter object
addresses (parameters.FilePaths): An FilePaths parameter object
**kwargs: Additional keyword arguments are ignored.
Returns:
Expand Down Expand Up @@ -450,7 +451,7 @@ def __init__(self):
self.adds_node_attributes = ['elevation']

def __call__(self, G: nx.Graph,
addresses: parameters.Addresses,
addresses: parameters.FilePaths,
**kwargs) -> nx.Graph:
"""Set the elevation for each node.
Expand All @@ -459,7 +460,7 @@ def __call__(self, G: nx.Graph,
Args:
G (nx.Graph): A graph
addresses (parameters.Addresses): An Addresses parameter object
addresses (parameters.FilePaths): An FilePaths parameter object
**kwargs: Additional keyword arguments are ignored.
Returns:
Expand Down Expand Up @@ -799,7 +800,7 @@ def design_pipe(ds_elevation: float,
edge_length: float,
pipe_design: parameters.HydraulicDesign,
Q: float
):
) -> nx.Graph:
"""Design a pipe.
This function designs a pipe by iterating over a range of diameters and
Expand Down Expand Up @@ -897,12 +898,13 @@ def process_successors(G: nx.Graph,
chamber_floor: dict[Hashable, float],
edge_diams: dict[tuple[Hashable,Hashable,int], float],
pipe_design: parameters.HydraulicDesign
):
) -> None:
"""Process the successors of a node.
This function processes the successors of a node. It designs a pipe to the
downstream node and sets the diameter and downstream invert level of the
pipe. It also sets the downstream invert level of the downstream node.
pipe. It also sets the downstream invert level of the downstream node. It
returns None but modifies the edge_diams and chamber_floor dictionaries.
Args:
G (nx.Graph): A graph
Expand Down Expand Up @@ -955,7 +957,7 @@ def __call__(self,
G: nx.Graph,
pipe_design: parameters.HydraulicDesign,
**kwargs
):
)->nx.Graph:
"""Pipe by pipe hydraulic design.
Starting from the most upstream node, design a pipe to the downstream node
Expand Down
6 changes: 3 additions & 3 deletions swmmanywhere/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ class HydraulicDesign(BaseModel):
description = "Depth of design storm in pipe by pipe method",
unit = "m")

class Addresses:
"""Parameters for address lookup.
class FilePaths:
"""Parameters for file path lookup.
TODO: this doesn't validate addresses to allow for un-initialised data
TODO: this doesn't validate file paths to allow for un-initialised data
(e.g., subcatchments are created by a graph and so cannot be validated).
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/graph_topo_derived.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_data/street_graph.json

Large diffs are not rendered by default.

12 changes: 4 additions & 8 deletions tests/test_geospatial_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,8 @@ def test_reproject_raster():
assert src.crs.to_string() == target_crs
finally:
# Regardless of test outcome, delete the temp file
if fid.exists():
fid.unlink()
if new_fid.exists():
new_fid.unlink()
fid.unlink(missing_ok=True)
new_fid.unlink(missing_ok=True)


def almost_equal(a, b, tol=1e-6):
Expand Down Expand Up @@ -216,10 +214,8 @@ def test_burn_shape_in_raster():
assert (data != data_).any()
finally:
# Regardless of test outcome, delete the temp file
if raster_fid.exists():
raster_fid.unlink()
if new_raster_fid.exists():
new_raster_fid.unlink()
raster_fid.unlink(missing_ok=True)
new_raster_fid.unlink(missing_ok=True)

def test_derive_subcatchments():
"""Test the derive_subcatchments function."""
Expand Down
19 changes: 16 additions & 3 deletions tests/test_graph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from pathlib import Path

import geopandas as gpd
import networkx as nx
from shapely import geometry as sgeom

from swmmanywhere import parameters
from swmmanywhere.graph_utilities import graphfcns as gu
from swmmanywhere.graph_utilities import load_graph
from swmmanywhere.graph_utilities import load_graph, save_graph


def load_street_network():
Expand All @@ -21,6 +22,18 @@ def load_street_network():
G = load_graph(Path(__file__).parent / 'test_data' / 'street_graph.json')
return G, bbox

def test_save_load():
"""Test the save_graph and load_graph functions."""
# Load a street network
G,_ = load_street_network()
with tempfile.TemporaryDirectory() as temp_dir:
# Save the graph
save_graph(G, Path(temp_dir) / 'test_graph.json')
# Load the graph
G_new = load_graph(Path(temp_dir) / 'test_graph.json')
# Check if the loaded graph is the same as the original graph
assert nx.is_isomorphic(G, G_new)

def test_assign_id():
"""Test the assign_id function."""
G, _ = load_street_network()
Expand Down Expand Up @@ -60,7 +73,7 @@ def test_derive_subcatchments():
"""Test the derive_subcatchments function."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
addresses = parameters.Addresses(base_dir = temp_path,
addresses = parameters.FilePaths(base_dir = temp_path,
project_name = 'test',
bbox_number = 1,
extension = 'json',
Expand Down Expand Up @@ -95,7 +108,7 @@ def test_set_elevation_and_slope():
G, _ = load_street_network()
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
addresses = parameters.Addresses(base_dir = temp_path,
addresses = parameters.FilePaths(base_dir = temp_path,
project_name = 'test',
bbox_number = 1,
extension = 'json',
Expand Down

0 comments on commit c8ce4c3

Please sign in to comment.