generated from ImperialCollegeLondon/pip-tools-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial attempt at graph function examples
-Create graph_utilities.py -Create tests -Create parameters
- Loading branch information
Dobson
committed
Jan 26, 2024
1 parent
6b6b629
commit 3ac4d52
Showing
3 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Created on 2024-01-26. | ||
@author: Barney | ||
""" | ||
|
||
from typing import Callable | ||
|
||
import networkx as nx | ||
|
||
from swmmanywhere import parameters | ||
|
||
graphfcns = {} | ||
|
||
def register_graphfcn(func: Callable[..., | ||
nx.Graph]) -> Callable[..., | ||
nx.Graph]: | ||
"""Register a graph function. | ||
Args: | ||
func (Callable): A function that takes a graph and other parameters | ||
Returns: | ||
func (Callable): Returns the same function | ||
""" | ||
# Add the function to the registry | ||
graphfcns[func.__name__] = func | ||
return func | ||
|
||
def get_osmid_id(data): | ||
"""Get the ID of an edge.""" | ||
id_ = data.get('osmid', data.get('id')) | ||
if isinstance(id_, list): | ||
id_ = id_[0] | ||
return id_ | ||
|
||
@register_graphfcn | ||
def assign_id(G: nx.Graph, | ||
**kwargs): | ||
"""Assign an ID to each edge. | ||
This function takes a graph and assigns an ID to each edge. The ID is | ||
assigned to the 'id' attribute of each edge. Needed because some edges | ||
have 'osmid', some have a list of 'osmid', others have 'id'. | ||
Requires a graph with edges that have: | ||
- 'osmid' or 'id' | ||
Adds the attributes: | ||
- 'id' | ||
Args: | ||
G (nx.Graph): A graph | ||
**kwargs: Additional keyword arguments are ignored. | ||
Returns: | ||
G (nx.Graph): The same graph with an ID assigned to each edge | ||
""" | ||
for u, v, data in G.edges(data=True): | ||
data['id'] = get_osmid_id(data) | ||
return G | ||
|
||
@register_graphfcn | ||
def format_osmnx_lanes(G: nx.Graph, | ||
subcatchment_derivation: parameters.SubcatchmentDerivation, | ||
**kwargs): | ||
"""Format the lanes attribute of each edge and calculates width. | ||
Requires a graph with edges that have: | ||
- 'lanes' (in osmnx format, i.e., empty for single lane, an int for a | ||
number of lanes or a list if the edge has multiple carriageways) | ||
Adds the attributes: | ||
- 'lanes' (float) | ||
- 'width' (float) | ||
Args: | ||
G (nx.Graph): A graph | ||
subcatchment_derivation (parameters.SubcatchmentDerivation): A | ||
SubcatchmentDerivation parameter object | ||
**kwargs: Additional keyword arguments are ignored. | ||
Returns: | ||
G (nx.Graph): A graph | ||
""" | ||
G = G.copy() | ||
for u, v, data in G.edges(data=True): | ||
lanes = data.get('lanes',1) | ||
if isinstance(lanes, list): | ||
lanes = sum([float(x) for x in lanes]) | ||
else: | ||
lanes = float(lanes) | ||
data['lanes'] = lanes | ||
data['width'] = lanes * subcatchment_derivation.lane_width | ||
return G | ||
|
||
@register_graphfcn | ||
def double_directed(G: nx.Graph, **kwargs): | ||
"""Create a 'double directed graph'. | ||
This function duplicates a graph and adds reverse edges to the new graph. | ||
These new edges share the same data as the 'forward' edges but have a new | ||
'id'. An undirected graph is not suitable because it removes data of one of | ||
the edges if there are edges in both directions between two nodes | ||
(necessary to preserve, e.g., consistent 'width'). | ||
Requires a graph with edges that have: | ||
- 'id' | ||
Args: | ||
G (nx.Graph): A graph | ||
**kwargs: Additional keyword arguments are ignored. | ||
Returns: | ||
G (nx.Graph): A graph | ||
""" | ||
G_new = G.copy() | ||
for u, v, data in G.edges(data=True): | ||
if (v, u) not in G.edges: | ||
reverse_data = data.copy() | ||
reverse_data['id'] = '{0}.reversed'.format(data['id']) | ||
G_new.add_edge(v, u, **reverse_data) | ||
return G_new |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Created on 2024-01-26. | ||
@author: Barney | ||
""" | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class SubcatchmentDerivation(BaseModel): | ||
"""Parameters for subcatchment derivation.""" | ||
lane_width: float = Field(default = 3.5, | ||
ge = 2.0, | ||
le = 5.0, | ||
unit = "m", | ||
description = "Width of a road lane.") | ||
|
||
carve_depth: float = Field(default = 2.0, | ||
ge = 1.0, | ||
le = 3.0, | ||
unit = "m", | ||
description = "Depth of road/river carve for flow accumulation.") | ||
|
||
max_street_length: float = Field(default = 60.0, | ||
ge = 20.0, | ||
le = 100.0, | ||
unit = "m", | ||
description = "Distance to split streets into segments.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Created on 2024-01-26. | ||
@author: Barney | ||
""" | ||
|
||
|
||
from swmmanywhere import graph_utilities as gu | ||
from swmmanywhere import parameters | ||
from swmmanywhere.prepare_data import download_street | ||
|
||
|
||
def generate_street_graph(): | ||
"""Generate a street graph.""" | ||
bbox = (-0.11643,51.50309,-0.11169,51.50549) | ||
G = download_street(bbox) | ||
return G | ||
|
||
def test_assign_id(): | ||
"""Test the assign_id function.""" | ||
G = generate_street_graph() | ||
G = gu.assign_id(G) | ||
for u, v, data in G.edges(data=True): | ||
assert 'id' in data.keys() | ||
assert isinstance(data['id'], int) | ||
|
||
def test_double_directed(): | ||
"""Test the double_directed function.""" | ||
G = generate_street_graph() | ||
G = gu.assign_id(G) | ||
G = gu.double_directed(G) | ||
for u, v in G.edges(): | ||
assert (v,u) in G.edges | ||
|
||
def test_format_osmnx_lanes(): | ||
"""Test the format_osmnx_lanes function.""" | ||
G = generate_street_graph() | ||
params = parameters.SubcatchmentDerivation() | ||
G = gu.format_osmnx_lanes(G, params) | ||
for u, v, data in G.edges(data=True): | ||
assert 'lanes' in data.keys() | ||
assert isinstance(data['lanes'], float) | ||
assert 'width' in data.keys() | ||
assert isinstance(data['width'], float) |