diff --git a/swmmanywhere/geospatial_utilities.py b/swmmanywhere/geospatial_utilities.py index ce25a1cb..8e0557f3 100644 --- a/swmmanywhere/geospatial_utilities.py +++ b/swmmanywhere/geospatial_utilities.py @@ -134,16 +134,16 @@ def interpolate_points_on_raster(x: list[float], data[data == src.nodata] = None # Get the raster's coordinates - x = np.linspace(src.bounds.left, src.bounds.right, src.width) - y = np.linspace(src.bounds.bottom, src.bounds.top, src.height) + x_grid = np.linspace(src.bounds.left, src.bounds.right, src.width) + y_grid = np.linspace(src.bounds.bottom, src.bounds.top, src.height) # Define grid - xx, yy = np.meshgrid(x, y) + xx, yy = np.meshgrid(x_grid, y_grid) grid = np.vstack([xx.ravel(), yy.ravel()]).T values = data.ravel() # Define interpolator - interp = RegularGridInterpolator((y,x), + interp = RegularGridInterpolator((y_grid,x_grid), np.flipud(data), method='linear', bounds_error=False, diff --git a/swmmanywhere/graph_utilities.py b/swmmanywhere/graph_utilities.py index b42a1b15..83721cd1 100644 --- a/swmmanywhere/graph_utilities.py +++ b/swmmanywhere/graph_utilities.py @@ -93,7 +93,7 @@ def assign_id(G: nx.Graph, Requires a graph with edges that have: - 'osmid' or 'id' - Adds the attributes: + Adds the edge attributes: - 'id' Args: @@ -117,7 +117,7 @@ def format_osmnx_lanes(G: nx.Graph, - 'lanes' (in osmnx format, i.e., empty for single lane, an int for a number of lanes or a list if the edge has multiple carriageways) - Adds the attributes: + Adds the edge attributes: - 'lanes' (float) - 'width' (float) @@ -311,7 +311,7 @@ def calculate_contributing_area(G: nx.Graph, - 'id' (str) - 'width' (float) - Adds the attributes: + Adds the edge attributes: - 'contributing_area' (float) Args: @@ -356,3 +356,35 @@ def calculate_contributing_area(G: nx.Graph, d['contributing_area'] = 0.0 return G +def set_elevation(G: nx.Graph, + addresses: parameters.Addresses, + **kwargs) -> nx.Graph: + """Set the elevation for each node. + + This function sets the elevation for each node. The elevation is + calculated from the elevation data. + + Requires a graph with nodes that have: + - 'x' (float) + - 'y' (float) + + Adds the node attributes: + - 'elevation' (float) + + Args: + G (nx.Graph): A graph + addresses (parameters.Addresses): An Addresses parameter object + **kwargs: Additional keyword arguments are ignored. + + Returns: + G (nx.Graph): A graph + """ + G = G.copy() + x = [d['x'] for x, d in G.nodes(data=True)] + y = [d['y'] for x, d in G.nodes(data=True)] + elevations = go.interpolate_points_on_raster(x, + y, + addresses.elevation) + elevations_dict = {id_: elev for id_, elev in zip(G.nodes, elevations)} + nx.set_node_attributes(G, elevations_dict, 'elevation') + return G \ No newline at end of file diff --git a/tests/test_geospatial_utilities.py b/tests/test_geospatial_utilities.py index 72956252..082dd2fc 100644 --- a/tests/test_geospatial_utilities.py +++ b/tests/test_geospatial_utilities.py @@ -63,7 +63,7 @@ def test_interpolate_points_on_raster(mock_rst_open): mock_src.height = 2 mock_src.nodata = None mock_rst_open.return_value.__enter__.return_value = mock_src - + # Define the x and y coordinates x = [0.25, 0.75] y = [0.25, 0.75] @@ -73,8 +73,8 @@ def test_interpolate_points_on_raster(mock_rst_open): y, Path('fake_path')) - # [3,2] feels unintuitive but it's because rasters measure from the top - assert result == [3.0, 2.0] + # [2.75, 2.25] feels unintuitive but it's because rasters measure from the top + assert result == [2.75, 2.25] def test_get_utm(): """Test the get_utm_epsg function.""" diff --git a/tests/test_graph_utilities.py b/tests/test_graph_utilities.py index b916c6b9..f9e1167b 100644 --- a/tests/test_graph_utilities.py +++ b/tests/test_graph_utilities.py @@ -84,4 +84,21 @@ def test_derive_subcatchments(): G = gu.calculate_contributing_area(G, params, addresses) for u, v, data in G.edges(data=True): assert 'contributing_area' in data.keys() - assert isinstance(data['contributing_area'], float) \ No newline at end of file + assert isinstance(data['contributing_area'], float) + +def test_set_elevation(): + """Test the set_elevation function.""" + G, _ = load_street_network() + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + addresses = parameters.Addresses(base_dir = temp_path, + project_name = 'test', + bbox_number = 1, + extension = 'json', + model_number = 1) + addresses.elevation = Path(__file__).parent / 'test_data' / 'elevation.tif' + G = gu.set_elevation(G, addresses) + for id_, data in G.nodes(data=True): + assert 'elevation' in data.keys() + assert isinstance(data['elevation'], float) + assert data['elevation'] > 0 \ No newline at end of file