diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index d8c1bb52..0068afea 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -5,12 +5,14 @@ """ from typing import Optional +import networkx as nx import numpy as np import pandas as pd import pyproj import rasterio as rst from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator +from shapely.geometry import LineString def get_utm_epsg(lon: float, lat: float) -> str: @@ -185,4 +187,43 @@ def f(row): df['x'], df['y'] = zip(*df.apply(f,axis=1)) - return df \ No newline at end of file + return df + +def reproject_graph(G: nx.Graph, + source_crs: str, + target_crs: str) -> nx.Graph: + """Reproject the coordinates in a graph. + + Args: + G (nx.Graph): Graph with nodes containing 'x' and 'y' properties. + source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326). + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + + Returns: + nx.Graph: Graph with nodes containing 'x' and 'y' properties. + """ + # Create a PyProj transformer for CRS conversion + transformer = get_transformer(source_crs, target_crs) + + # Create a new graph with the converted nodes and edges + G_new = G.copy() + + # Convert and add nodes with 'x', 'y' properties + for node, data in G_new.nodes(data=True): + x, y = transformer.transform(data['x'], data['y']) + data['x'] = x + data['y'] = y + + # Convert and add edges with 'geometry' property + for u, v, data in G_new.edges(data=True): + if 'geometry' in data.keys(): + geometry = data['geometry'] + new_geometry = LineString(transformer.transform(x, y) + for x, y in geometry.coords) + else: + new_geometry = LineString([[G_new.nodes[u]['x'], + G_new.nodes[u]['y']], + [G_new.nodes[v]['x'], + G_new.nodes[v]['y']]]) + data['geometry'] = new_geometry + return G_new \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index d30e2999..c2072057 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -8,10 +8,12 @@ import os from unittest.mock import MagicMock, patch +import networkx as nx import numpy as np import pandas as pd import rasterio as rst from scipy.interpolate import RegularGridInterpolator +from shapely.geometry import LineString from swmmanywhere import geospatial_operations as go @@ -135,8 +137,37 @@ def test_reproject_df(): target_crs = 'EPSG:32630' # Call the function - transformed_df = go.transform_df(df, source_crs, target_crs) + transformed_df = go.reproject_df(df, source_crs, target_crs) # Check the output assert transformed_df['x'].values[0] == 699330.1106898375 assert transformed_df['y'].values[0] == 5710164.30300683 + +def test_reproject_graph(): + """Test the reproject_graph function.""" + # Create a mock graph + G = nx.Graph() + G.add_node(1, x=0, y=0) + G.add_node(2, x=1, y=1) + G.add_edge(1, 2) + G.add_node(3, x=1, y=2) + G.add_edge(2, 3, geometry=LineString([(1, 1), (1, 2)])) + + # Define the input parameters + source_crs = 'EPSG:4326' + target_crs = 'EPSG:32630' + + # Call the function + G_new = go.reproject_graph(G, source_crs, target_crs) + + # Test node coordinates + assert G_new.nodes[1]['x'] == 833978.5569194595 + assert G_new.nodes[1]['y'] == 0 + assert G_new.nodes[2]['x'] == 945396.6839773951 + assert G_new.nodes[2]['y'] == 110801.83254625657 + assert G_new.nodes[3]['x'] == 945193.8596723974 + assert G_new.nodes[3]['y'] == 221604.0105092727 + + # Test edge geometry + assert list(G_new[1][2]['geometry'].coords)[0][0] == 833978.5569194595 + assert list(G_new[2][3]['geometry'].coords)[0][0] == 945396.6839773951 \ No newline at end of file