diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index 8d7f08a8..47af36d1 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -417,7 +417,7 @@ def delineate_catchment(grid: pysheds.sgrid.sGrid, # Snap the node to the nearest grid cell x, y = data['x'], data['y'] grid_ = deepcopy(grid) - x_snap, y_snap = grid_.snap_to_mask(flow_acc > 5, (x, y)) + x_snap, y_snap = grid_.snap_to_mask(flow_acc >= 0, (x, y)) # Delineate the catchment catch = grid_.catchment(x=x_snap, diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index f069100f..a2262d71 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -208,11 +208,15 @@ def __call__(self, G (nx.Graph): The same graph with an ID assigned to each edge """ edge_ids: set[str] = set() - for u, v, data in G.edges(data=True): + edges_to_remove = [] + for u, v, key, data in G.edges(data=True, keys = True): data['id'] = f'{u}-{v}' if data['id'] in edge_ids: logger.warning(f"Duplicate edge ID: {data['id']}") + edges_to_remove.append((u, v, key)) edge_ids.add(data['id']) + for u, v, key in edges_to_remove: + G.remove_edge(u, v, key) return G @register_graphfcn @@ -409,6 +413,39 @@ def create_new_edge_data(line, data, id_): return graph +@register_graphfcn +class fix_geometries(BaseGraphFunction, + required_edge_attributes = ['geometry'], + required_node_attributes = ['x', 'y']): + """fix_geometries class.""" + def __call__(self, G: nx.Graph, **kwargs) -> nx.Graph: + """Fix the geometries of the edges. + + This function fixes the geometries of the edges. The geometries are + recalculated from the node coordinates. + + Args: + G (nx.Graph): A graph + **kwargs: Additional keyword arguments are ignored. + + Returns: + G (nx.Graph): A graph + """ + G = G.copy() + for u, v, data in G.edges(data=True): + start_point_node = (G.nodes[u]['x'], G.nodes[u]['y']) + start_point_edge = data['geometry'].coords[0] + end_point_node = (G.nodes[v]['x'], G.nodes[v]['y']) + end_point_edge = data['geometry'].coords[-1] + if (start_point_edge == end_point_node) & \ + (end_point_edge == start_point_node): + data['geometry'] = data['geometry'].reverse() + elif (start_point_edge != start_point_node) | \ + (end_point_edge != end_point_node): + data['geometry'] = shapely.LineString([start_point_node, + end_point_node]) + return G + @register_graphfcn class calculate_contributing_area(BaseGraphFunction, required_edge_attributes = ['id', 'geometry', 'width'], diff --git a/tests/test_data/demo_config.yml b/tests/test_data/demo_config.yml index 16b7e5f4..10a5b698 100644 --- a/tests/test_data/demo_config.yml +++ b/tests/test_data/demo_config.yml @@ -16,6 +16,7 @@ graphfcn_list: - assign_id - format_osmnx_lanes - double_directed + - fix_geometries - split_long_edges - calculate_contributing_area - set_elevation diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index 45251599..ccaa3a72 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -260,3 +260,19 @@ def test_iterate_graphfcns(): for u, v, d in G.edges(data=True): assert 'id' in d.keys() assert 'width' in d.keys() + +def test_fix_geometries(): + """Test the fix_geometries function.""" + # Create a graph with edge geometry not matching node coordinates + G = load_graph(Path(__file__).parent / 'test_data' / 'graph_topo_derived.json') + + # Test doesn't work if this isn't true + assert G.get_edge_data(107733, 25472373,0)['geometry'].coords[0] != \ + (G.nodes[107733]['x'], G.nodes[107733]['y']) + + # Run the function + G_fixed = gu.fix_geometries(G) + + # Check that the edge geometry now matches the node coordinates + assert G_fixed.get_edge_data(107733, 25472373,0)['geometry'].coords[0] == \ + (G_fixed.nodes[107733]['x'], G_fixed.nodes[107733]['y']) \ No newline at end of file