diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index 315b19cb..0c9c5deb 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -7,6 +7,8 @@ from typing import Callable import networkx as nx +import osmnx as ox +from shapely import geometry as sgeom from swmmanywhere import parameters @@ -121,3 +123,126 @@ def double_directed(G: nx.Graph, **kwargs): reverse_data['id'] = '{0}.reversed'.format(data['id']) G_new.add_edge(v, u, **reverse_data) return G_new + +@register_graphfcn +def split_long_edges(graph: nx.Graph, + subcatchment_derivation: parameters.SubcatchmentDerivation, + **kwargs): + """Split long edges into shorter edges. + + This function splits long edges into shorter edges. The edges are split + into segments of length 'max_street_length'. The first and last segment + are connected to the original nodes. Intermediate segments are connected + to newly created nodes. + + Requires a graph with edges that have: + - 'geometry' (shapely LineString) + - 'length' (float) + - 'id' (str) + + Args: + graph (nx.Graph): A graph + subcatchment_derivation (parameters.SubcatchmentDerivation): A + SubcatchmentDerivation parameter object + **kwargs: Additional keyword arguments are ignored. + + Returns: + graph (nx.Graph): A graph + """ + max_length = subcatchment_derivation.max_street_length + graph = graph.copy() + edges_to_remove = [] + edges_to_add = [] + nodes_to_add = [] + maxlabel = max(graph.nodes) + 1 + ll = 0 + + def create_new_edge_data(line, data, id_): + new_line = sgeom.LineString(line) + new_data = data.copy() + new_data['id'] = id_ + new_data['length'] = new_line.length + new_data['geometry'] = sgeom.LineString([(x[0], x[1]) + for x in new_line.coords]) + return new_data + + for u, v, data in graph.edges(data=True): + line = data['geometry'] + length = data['length'] + if ((u, v) not in edges_to_remove) & ((v, u) not in edges_to_remove): + if length > max_length: + new_points = [sgeom.Point(x) + for x in ox.utils_geo.interpolate_points(line, + max_length)] + if len(new_points) > 2: + for ix, (start, end) in enumerate(zip(new_points[:-1], + new_points[1:])): + new_data = create_new_edge_data([start, + end], + data, + '{0}.{1}'.format( + data['id'],ix)) + if (v,u) in graph.edges: + # Create reversed data + data_r = graph.get_edge_data(v, u).copy()[0] + id_ = '{0}.{1}'.format(data_r['id'],ix) + new_data_r = create_new_edge_data([end, start], + data_r.copy(), + id_) + if ix == 0: + # Create start to first intermediate + edges_to_add.append((u, maxlabel + ll, new_data.copy())) + nodes_to_add.append((maxlabel + ll, + {'x': + new_data['geometry'].coords[-1][0], + 'y': + new_data['geometry'].coords[-1][1]})) + + if (v, u) in graph.edges: + # Create first intermediate to start + edges_to_add.append((maxlabel + ll, + u, + new_data_r.copy())) + + ll += 1 + elif ix == len(new_points) - 2: + # Create last intermediate to end + edges_to_add.append((maxlabel + ll - 1, + v, + new_data.copy())) + if (v, u) in graph.edges: + # Create end to last intermediate + edges_to_add.append((v, + maxlabel + ll - 1, + new_data_r.copy())) + else: + nodes_to_add.append((maxlabel + ll, + {'x': + new_data['geometry'].coords[-1][0], + 'y': + new_data['geometry'].coords[-1][1]})) + # Create N-1 intermediate to N intermediate + edges_to_add.append((maxlabel + ll - 1, + maxlabel + ll, + new_data.copy())) + if (v, u) in graph.edges: + # Create N intermediate to N-1 intermediate + edges_to_add.append((maxlabel + ll, + maxlabel + ll - 1, + new_data_r.copy())) + ll += 1 + edges_to_remove.append((u, v)) + if (v, u) in graph.edges: + edges_to_remove.append((v, u)) + + for u, v in edges_to_remove: + if (u, v) in graph.edges: + graph.remove_edge(u, v) + + for node in nodes_to_add: + graph.add_node(node[0], **node[1]) + + for edge in edges_to_add: + graph.add_edge(edge[0], edge[1], **edge[2]) + + return graph diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index b3a83f5a..9dcaf987 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -5,6 +5,7 @@ """ +from swmmanywhere import geospatial_utilities as go from swmmanywhere import graph_utilities as gu from swmmanywhere import parameters from swmmanywhere.prepare_data import download_street @@ -41,4 +42,20 @@ def test_format_osmnx_lanes(): assert 'lanes' in data.keys() assert isinstance(data['lanes'], float) assert 'width' in data.keys() - assert isinstance(data['width'], float) \ No newline at end of file + assert isinstance(data['width'], float) + +def test_split_long_edges(): + """Test the split_long_edges function.""" + G = generate_street_graph() + G = gu.assign_id(G) + id_ = list(G.nodes)[0] + G = go.reproject_graph(G, + 'EPSG:4326', + go.get_utm_epsg(G.nodes[id_]['x'], + G.nodes[id_]['y']) + ) + max_length = 20 + params = parameters.SubcatchmentDerivation(max_street_length = max_length) + G = gu.split_long_edges(G, params) + for u, v, data in G.edges(data=True): + assert data['length'] <= (max_length * 2) \ No newline at end of file