From e3e7e89ae3b454ccedadb0151c6119369723b0eb Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 14:41:56 +0000 Subject: [PATCH 01/15] Start of geospatial analysis -Add raster interpolation and tests --- swmmanywhere/geospatial_operations.py | 81 +++++++++++++++++++++++++++ tests/test_geospatial.py | 65 +++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 swmmanywhere/geospatial_operations.py create mode 100644 tests/test_geospatial.py diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py new file mode 100644 index 00000000..bde8de0e --- /dev/null +++ b/swmmanywhere/geospatial_operations.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +"""Created 2024-01-20. + +@author: Barnaby Dobson +""" +import numpy as np +import rasterio as rst +from scipy.interpolate import RegularGridInterpolator + + +def interp_wrap(xy: tuple[float,float], + interp: RegularGridInterpolator, + grid: np.ndarray, + values: list[float]) -> float: + """Wrap the interpolation function to handle NaNs. + + Picks the nearest non NaN grid point if the interpolated value is NaN, + otherwise returns the interpolated value. + + Args: + xy (tuple): Coordinate of interest + interp (RegularGridInterpolator): The interpolator object. + grid (np.ndarray): List of xy coordinates of the grid points. + values (list): The list of values at each point in the grid. + + Returns: + float: The interpolated value. + """ + # Call the interpolator + val = float(interp(xy)) + # If the value is NaN, we need to pick nearest non nan grid point + if np.isnan(val): + # Get the distances to all grid points + distances = np.linalg.norm(grid - xy, axis=1) + # Get the indices of the grid points sorted by distance + indices = np.argsort(distances) + # Iterate over the grid points in order of increasing distance + for index in indices: + # If the value at this grid point is not NaN, return it + if not np.isnan(values[index]): + return values[index] + else: + return val + + raise ValueError("No non NaN values found in grid.") + +def interpolate_points_on_raster(x: list[float], + y: list[float], + elevation_fid: str) -> list[float ]: + """Interpolate points on a raster. + + Args: + x (list): X coordinates. + y (list): Y coordinates. + elevation_fid (str): Filepath to elevation raster. + + Returns: + elevation (float): Elevation at point. + """ + with rst.open(elevation_fid) as src: + # Read the raster data + data = src.read(1).astype(float) # Assuming it's a single-band raster + data[data == src.nodata] = None + + # Get the raster's coordinates + x = np.linspace(src.bounds.left, src.bounds.right, src.width) + y = np.linspace(src.bounds.bottom, src.bounds.top, src.height) + + # Define grid + xx, yy = np.meshgrid(x, y) + grid = np.vstack([xx.ravel(), yy.ravel()]).T + values = data.ravel() + + # Define interpolator + interp = RegularGridInterpolator((y,x), + np.flipud(data), + method='linear', + bounds_error=False, + fill_value=None) + # Interpolate for x,y + return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py new file mode 100644 index 00000000..1dade0ce --- /dev/null +++ b/tests/test_geospatial.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""Created on Tue Oct 18 10:35:51 2022. + +@author: Barney +""" + +# import pytest +from unittest.mock import MagicMock, patch + +import numpy as np +from scipy.interpolate import RegularGridInterpolator + +from swmmanywhere import geospatial_operations as go + + +def test_interp_wrap(): + """Test the interp_wrap function.""" + # Define a simple grid and values + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + xx, yy = np.meshgrid(x, y) + grid = np.vstack([xx.ravel(), yy.ravel()]).T + values = np.linspace(0, 1, 25) + values_grid = values.reshape(5, 5) + + # Define an interpolator + interp = RegularGridInterpolator((x,y), + values_grid) + + # Test the function at a point inside the grid + yx = (0.875, 0.875) + result = go.interp_wrap(yx, interp, grid, values) + assert result == 0.875 + + # Test the function on a nan point + values_grid[1][1] = np.nan + yx = (0.251, 0.25) + result = go.interp_wrap(yx, interp, grid, values) + assert result == values_grid[1][2] + +@patch('rasterio.open') +def test_interpolate_points_on_raster(mock_rst_open): + """Test the interpolate_points_on_raster function.""" + # Mock the raster file + mock_src = MagicMock() + mock_src.read.return_value = np.array([[1, 2], [3, 4]]) + mock_src.bounds = MagicMock() + mock_src.bounds.left = 0 + mock_src.bounds.right = 1 + mock_src.bounds.bottom = 0 + mock_src.bounds.top = 1 + mock_src.width = 2 + mock_src.height = 2 + mock_src.nodata = None + mock_rst_open.return_value.__enter__.return_value = mock_src + + # Define the x and y coordinates + x = [0.25, 0.75] + y = [0.25, 0.75] + + # Call the function + result = go.interpolate_points_on_raster(x, y, 'fake_path') + + # [3,2] feels unintuitive but it's because rasters measure from the top + assert result == [3.0, 2.0] \ No newline at end of file From fd1d8804111db0c9d27cb4a23cd12d70751fae16 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 15:42:18 +0000 Subject: [PATCH 02/15] Update geospatial -Add reproject raster and tests -Add get_utm_epsg and tests --- swmmanywhere/geospatial_operations.py | 64 ++++++++++++++++++++++++++- tests/test_geospatial.py | 50 ++++++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index bde8de0e..9f095a75 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -3,11 +3,34 @@ @author: Barnaby Dobson """ +from typing import Optional + import numpy as np import rasterio as rst +from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator +def get_utm_epsg(lon: float, lat: float) -> str: + """Get the formatted UTM EPSG code for a given longitude and latitude. + + Args: + lon (float): Longitude in EPSG:4326 (x) + lat (float): Latitude in EPSG:4326 (y) + + Returns: + str: Formatted EPSG code for the UTM zone. + + Example: + >>> get_utm_epsg(-0.1276, 51.5074) + 'EPSG:32630' + """ + # Determine the UTM zone number + zone_number = int((lon + 180) / 6) + 1 + # Determine the UTM EPSG code based on the hemisphere + utm_epsg = 32600 + zone_number if lat >= 0 else 32700 + zone_number + return 'EPSG:{0}'.format(utm_epsg) + def interp_wrap(xy: tuple[float,float], interp: RegularGridInterpolator, grid: np.ndarray, @@ -78,4 +101,43 @@ def interpolate_points_on_raster(x: list[float], bounds_error=False, fill_value=None) # Interpolate for x,y - return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] \ No newline at end of file + return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] + +def reproject_raster(target_crs: str, + fid: str, + new_fid: Optional[str] = None): + """Reproject a raster to a new CRS. + + Args: + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + fid (str): Filepath to the raster to reproject. + new_fid (str, optional): Filepath to save the reprojected raster. + Defaults to None, which will just use fid with '_reprojected'. + """ + with rst.open(fid) as src: + # Define the transformation parameters for reprojection + transform, width, height = calculate_default_transform( + src.crs, target_crs, src.width, src.height, *src.bounds) + + # Create the output raster file + kwargs = src.meta.copy() + kwargs.update({ + 'crs': target_crs, + 'transform': transform, + 'width': width, + 'height': height + }) + if new_fid is None: + new_fid = fid.replace('.tif','_reprojected.tif') + + with rst.open(new_fid, 'w', **kwargs) as dst: + # Reproject the data + reproject( + source=rst.band(src, 1), + destination=rst.band(dst, 1), + src_transform=src.transform, + src_crs=src.crs, + dst_transform=transform, + dst_crs=target_crs, + resampling=Resampling.bilinear + ) \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index 1dade0ce..76859a7d 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -5,9 +5,11 @@ """ # import pytest +import os from unittest.mock import MagicMock, patch import numpy as np +import rasterio as rst from scipy.interpolate import RegularGridInterpolator from swmmanywhere import geospatial_operations as go @@ -62,4 +64,50 @@ def test_interpolate_points_on_raster(mock_rst_open): result = go.interpolate_points_on_raster(x, y, 'fake_path') # [3,2] feels unintuitive but it's because rasters measure from the top - assert result == [3.0, 2.0] \ No newline at end of file + assert result == [3.0, 2.0] + +def test_get_utm(): + """Test the get_utm_epsg function.""" + # Test a northern hemisphere point + crs = go.get_utm_epsg(-1, 51) + assert crs == 'EPSG:32630' + + # Test a southern hemisphere point + crs = go.get_utm_epsg(-1, -51) + assert crs == 'EPSG:32730' + + +def test_reproject_raster(): + """Test the reproject_raster function.""" + # Create a mock raster file + fid = 'test.tif' + data = np.random.randint(0, 255, (100, 100)).astype('uint8') + transform = rst.transform.from_origin(0, 0, 0.1, 0.1) + with rst.open(fid, + 'w', + driver='GTiff', + height=100, + width=100, + count=1, + dtype='uint8', + crs='EPSG:4326', + transform=transform) as src: + src.write(data, 1) + + # Define the input parameters + target_crs = 'EPSG:32630' + new_fid = 'test_reprojected.tif' + + # Call the function + go.reproject_raster(target_crs, fid) + + # Check if the reprojected file exists + assert os.path.exists(new_fid) + + # Check if the reprojected file has the correct CRS + with rst.open(new_fid) as src: + assert src.crs.to_string() == target_crs + + # Clean up the created files + os.remove(fid) + os.remove(new_fid) \ No newline at end of file From a2c411d7e5e197bb0f8df672e1bdf346d0e12fc9 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 15:51:43 +0000 Subject: [PATCH 03/15] Add pyproj get_transformer --- swmmanywhere/geospatial_operations.py | 23 ++++++++++++++++++++++- tests/test_geospatial.py | 13 ++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 9f095a75..032baa59 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -6,6 +6,7 @@ from typing import Optional import numpy as np +import pyproj import rasterio as rst from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator @@ -140,4 +141,24 @@ def reproject_raster(target_crs: str, dst_transform=transform, dst_crs=target_crs, resampling=Resampling.bilinear - ) \ No newline at end of file + ) + +def get_transformer(source_crs: str, + target_crs: str) -> pyproj.Transformer: + """Get a transformer object for reprojection. + + Args: + source_crs (str): Source CRS in EPSG format (e.g., EPSG:32630). + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + + Returns: + pyproj.Transformer: Transformer object for reprojection. + + Example: + >>> transformer = get_transformer('EPSG:4326', 'EPSG:32630') + >>> transformer.transform(-0.1276, 51.5074) + (699330.1106898375, 5710164.30300683) + """ + return pyproj.Transformer.from_crs(source_crs, + target_crs, + always_xy=True) \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index 76859a7d..d7489c6e 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -110,4 +110,15 @@ def test_reproject_raster(): # Clean up the created files os.remove(fid) - os.remove(new_fid) \ No newline at end of file + os.remove(new_fid) + +def test_get_transformer(): + """Test the get_transformer function.""" + # Test a northern hemisphere point + transformer = go.get_transformer('EPSG:4326', 'EPSG:32630') + + initial_point = (-0.1276, 51.5074) + expected_point = (699330.1106898375, 5710164.30300683) + assert transformer.transform(*initial_point) == expected_point + + \ No newline at end of file From 48b2eb0ae0b8e25d4235b5b40706c490059e48a8 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 16:02:50 +0000 Subject: [PATCH 04/15] Add pyproj get_transformer --- swmmanywhere/geospatial_operations.py | 26 +++++++++++++++++++++++++- tests/test_geospatial.py | 20 +++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 032baa59..d8c1bb52 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -6,6 +6,7 @@ from typing import Optional import numpy as np +import pandas as pd import pyproj import rasterio as rst from rasterio.warp import Resampling, calculate_default_transform, reproject @@ -161,4 +162,27 @@ def get_transformer(source_crs: str, """ return pyproj.Transformer.from_crs(source_crs, target_crs, - always_xy=True) \ No newline at end of file + always_xy=True) + +def reproject_df(df: pd.DataFrame, + source_crs: str, + target_crs: str) -> pd.DataFrame: + """Reproject the coordinates in a DataFrame. + + Args: + df (pd.DataFrame): DataFrame with columns 'longitude' and 'latitude'. + source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326). + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + """ + # Function to transform coordinates + df = df.copy() + transformer = get_transformer(source_crs, target_crs) + + # Reproject the coordinates in the DataFrame + def f(row): + return transformer.transform(row['longitude'], + row['latitude']) + + df['x'], df['y'] = zip(*df.apply(f,axis=1)) + + return df \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index d7489c6e..d30e2999 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch import numpy as np +import pandas as pd import rasterio as rst from scipy.interpolate import RegularGridInterpolator @@ -121,4 +122,21 @@ def test_get_transformer(): expected_point = (699330.1106898375, 5710164.30300683) assert transformer.transform(*initial_point) == expected_point - \ No newline at end of file +def test_reproject_df(): + """Test the reproject_df function.""" + # Create a mock DataFrame + df = pd.DataFrame({ + 'longitude': [-0.1276], + 'latitude': [51.5074] + }) + + # Define the input parameters + source_crs = 'EPSG:4326' + target_crs = 'EPSG:32630' + + # Call the function + transformed_df = go.transform_df(df, source_crs, target_crs) + + # Check the output + assert transformed_df['x'].values[0] == 699330.1106898375 + assert transformed_df['y'].values[0] == 5710164.30300683 From 3e9c5b30919960199bb2cc3102cd27a94fb45670 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 16:21:36 +0000 Subject: [PATCH 05/15] Add/test reproject df/graph --- swmmanywhere/geospatial_operations.py | 43 ++++++++++++++++++++++++++- tests/test_geospatial.py | 33 +++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index d8c1bb52..0068afea 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -5,12 +5,14 @@ """ from typing import Optional +import networkx as nx import numpy as np import pandas as pd import pyproj import rasterio as rst from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator +from shapely.geometry import LineString def get_utm_epsg(lon: float, lat: float) -> str: @@ -185,4 +187,43 @@ def f(row): df['x'], df['y'] = zip(*df.apply(f,axis=1)) - return df \ No newline at end of file + return df + +def reproject_graph(G: nx.Graph, + source_crs: str, + target_crs: str) -> nx.Graph: + """Reproject the coordinates in a graph. + + Args: + G (nx.Graph): Graph with nodes containing 'x' and 'y' properties. + source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326). + target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). + + Returns: + nx.Graph: Graph with nodes containing 'x' and 'y' properties. + """ + # Create a PyProj transformer for CRS conversion + transformer = get_transformer(source_crs, target_crs) + + # Create a new graph with the converted nodes and edges + G_new = G.copy() + + # Convert and add nodes with 'x', 'y' properties + for node, data in G_new.nodes(data=True): + x, y = transformer.transform(data['x'], data['y']) + data['x'] = x + data['y'] = y + + # Convert and add edges with 'geometry' property + for u, v, data in G_new.edges(data=True): + if 'geometry' in data.keys(): + geometry = data['geometry'] + new_geometry = LineString(transformer.transform(x, y) + for x, y in geometry.coords) + else: + new_geometry = LineString([[G_new.nodes[u]['x'], + G_new.nodes[u]['y']], + [G_new.nodes[v]['x'], + G_new.nodes[v]['y']]]) + data['geometry'] = new_geometry + return G_new \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index d30e2999..c2072057 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -8,10 +8,12 @@ import os from unittest.mock import MagicMock, patch +import networkx as nx import numpy as np import pandas as pd import rasterio as rst from scipy.interpolate import RegularGridInterpolator +from shapely.geometry import LineString from swmmanywhere import geospatial_operations as go @@ -135,8 +137,37 @@ def test_reproject_df(): target_crs = 'EPSG:32630' # Call the function - transformed_df = go.transform_df(df, source_crs, target_crs) + transformed_df = go.reproject_df(df, source_crs, target_crs) # Check the output assert transformed_df['x'].values[0] == 699330.1106898375 assert transformed_df['y'].values[0] == 5710164.30300683 + +def test_reproject_graph(): + """Test the reproject_graph function.""" + # Create a mock graph + G = nx.Graph() + G.add_node(1, x=0, y=0) + G.add_node(2, x=1, y=1) + G.add_edge(1, 2) + G.add_node(3, x=1, y=2) + G.add_edge(2, 3, geometry=LineString([(1, 1), (1, 2)])) + + # Define the input parameters + source_crs = 'EPSG:4326' + target_crs = 'EPSG:32630' + + # Call the function + G_new = go.reproject_graph(G, source_crs, target_crs) + + # Test node coordinates + assert G_new.nodes[1]['x'] == 833978.5569194595 + assert G_new.nodes[1]['y'] == 0 + assert G_new.nodes[2]['x'] == 945396.6839773951 + assert G_new.nodes[2]['y'] == 110801.83254625657 + assert G_new.nodes[3]['x'] == 945193.8596723974 + assert G_new.nodes[3]['y'] == 221604.0105092727 + + # Test edge geometry + assert list(G_new[1][2]['geometry'].coords)[0][0] == 833978.5569194595 + assert list(G_new[2][3]['geometry'].coords)[0][0] == 945396.6839773951 \ No newline at end of file From ba1686e3a553867a5071e560e96b9397709cc7d8 Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 16:30:21 +0000 Subject: [PATCH 06/15] Update reprojection test accuracy --- tests/test_geospatial.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index c2072057..305896d6 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -115,6 +115,10 @@ def test_reproject_raster(): os.remove(fid) os.remove(new_fid) +def almost_equal(a, b, tol=1e-6): + """Check if two numbers are almost equal.""" + return abs(a-b) < tol + def test_get_transformer(): """Test the get_transformer function.""" # Test a northern hemisphere point @@ -122,7 +126,11 @@ def test_get_transformer(): initial_point = (-0.1276, 51.5074) expected_point = (699330.1106898375, 5710164.30300683) - assert transformer.transform(*initial_point) == expected_point + new_point = transformer.transform(*initial_point) + assert almost_equal(new_point[0], + expected_point[0]) + assert almost_equal(new_point[1], + expected_point[1]) def test_reproject_df(): """Test the reproject_df function.""" @@ -140,8 +148,8 @@ def test_reproject_df(): transformed_df = go.reproject_df(df, source_crs, target_crs) # Check the output - assert transformed_df['x'].values[0] == 699330.1106898375 - assert transformed_df['y'].values[0] == 5710164.30300683 + assert almost_equal(transformed_df['x'].values[0], 699330.1106898375) + assert almost_equal(transformed_df['y'].values[0], 5710164.30300683) def test_reproject_graph(): """Test the reproject_graph function.""" @@ -161,13 +169,15 @@ def test_reproject_graph(): G_new = go.reproject_graph(G, source_crs, target_crs) # Test node coordinates - assert G_new.nodes[1]['x'] == 833978.5569194595 - assert G_new.nodes[1]['y'] == 0 - assert G_new.nodes[2]['x'] == 945396.6839773951 - assert G_new.nodes[2]['y'] == 110801.83254625657 - assert G_new.nodes[3]['x'] == 945193.8596723974 - assert G_new.nodes[3]['y'] == 221604.0105092727 + assert almost_equal(G_new.nodes[1]['x'], 833978.5569194595) + assert almost_equal(G_new.nodes[1]['y'], 0) + assert almost_equal(G_new.nodes[2]['x'], 945396.6839773951) + assert almost_equal(G_new.nodes[2]['y'], 110801.83254625657) + assert almost_equal(G_new.nodes[3]['x'], 945193.8596723974) + assert almost_equal(G_new.nodes[3]['y'], 221604.0105092727) # Test edge geometry - assert list(G_new[1][2]['geometry'].coords)[0][0] == 833978.5569194595 - assert list(G_new[2][3]['geometry'].coords)[0][0] == 945396.6839773951 \ No newline at end of file + assert almost_equal(list(G_new[1][2]['geometry'].coords)[0][0], + 833978.5569194595) + assert almost_equal(list(G_new[2][3]['geometry'].coords)[0][0], + 945396.6839773951) \ No newline at end of file From d8432ee1bf38743ddd329b95ecabeadcb094e86b Mon Sep 17 00:00:00 2001 From: Dobson Date: Fri, 19 Jan 2024 16:49:10 +0000 Subject: [PATCH 07/15] Add nearest_node_buffer --- swmmanywhere/geospatial_operations.py | 52 ++++++++++++++++++++++++--- tests/test_geospatial.py | 22 ++++++++++-- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 0068afea..3ef3c07d 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -8,11 +8,12 @@ import networkx as nx import numpy as np import pandas as pd +import pygeos import pyproj import rasterio as rst from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator -from shapely.geometry import LineString +from shapely import geometry as sgeom def get_utm_epsg(lon: float, lat: float) -> str: @@ -218,12 +219,55 @@ def reproject_graph(G: nx.Graph, for u, v, data in G_new.edges(data=True): if 'geometry' in data.keys(): geometry = data['geometry'] - new_geometry = LineString(transformer.transform(x, y) + new_geometry = sgeom.LineString(transformer.transform(x, y) for x, y in geometry.coords) else: - new_geometry = LineString([[G_new.nodes[u]['x'], + new_geometry = sgeom.LineString([[G_new.nodes[u]['x'], G_new.nodes[u]['y']], [G_new.nodes[v]['x'], G_new.nodes[v]['y']]]) data['geometry'] = new_geometry - return G_new \ No newline at end of file + return G_new + +def nearest_node_buffer(points1: dict[str, sgeom.Point], + points2: dict[str, sgeom.Point], + threshold: float) -> dict: + """Find the nearest node within a given buffer threshold. + + Args: + points1 (dict): A dictionary where keys are labels and values are + Shapely points geometries. + points2 (dict): A dictionary where keys are labels and values are + Shapely points geometries. + threshold (float): The maximum distance for a node to be considered + 'nearest'. If no nodes are within this distance, the node is not + included in the output. + + Returns: + dict: A dictionary where keys are labels from points1 and values are + labels from points2 of the nearest nodes within the threshold. + """ + # Convert the keys of points2 to a list + labels2 = list(points2.keys()) + + # Convert the values of points2 to PyGEOS geometries + # and create a spatial index + pygeos_nodes = pygeos.from_shapely(list(points2.values())) + tree = pygeos.STRtree(pygeos_nodes) + + # Initialize an empty dictionary to store the matching nodes + matching = {} + + # Iterate over points1 + for key, geom in points1.items(): + # Find the nearest node in the spatial index to the current geometry + nearest = tree.nearest(pygeos.from_shapely(geom))[1][0] + nearest_geom = points2[labels2[nearest]] + + # If the nearest node is within the threshold, add it to the + # matching dictionary + if geom.buffer(threshold).intersection(nearest_geom): + matching[key] = labels2[nearest] + + # Return the matching dictionary + return matching \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index 305896d6..b63b3720 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -13,7 +13,7 @@ import pandas as pd import rasterio as rst from scipy.interpolate import RegularGridInterpolator -from shapely.geometry import LineString +from shapely import geometry as sgeom from swmmanywhere import geospatial_operations as go @@ -159,7 +159,7 @@ def test_reproject_graph(): G.add_node(2, x=1, y=1) G.add_edge(1, 2) G.add_node(3, x=1, y=2) - G.add_edge(2, 3, geometry=LineString([(1, 1), (1, 2)])) + G.add_edge(2, 3, geometry=sgeom.LineString([(1, 1), (1, 2)])) # Define the input parameters source_crs = 'EPSG:4326' @@ -180,4 +180,20 @@ def test_reproject_graph(): assert almost_equal(list(G_new[1][2]['geometry'].coords)[0][0], 833978.5569194595) assert almost_equal(list(G_new[2][3]['geometry'].coords)[0][0], - 945396.6839773951) \ No newline at end of file + 945396.6839773951) + +def test_nearest_node_buffer(): + """Test the nearest_node_buffer function.""" + # Create mock dictionaries of points + points1 = {'a': sgeom.Point(0, 0), 'b': sgeom.Point(1, 1)} + points2 = {'c': sgeom.Point(0.5, 0.5), 'd': sgeom.Point(2, 2)} + + # Define the input threshold + threshold = 1.0 + + # Call the function + matching = go.nearest_node_buffer(points1, points2, threshold) + + # Check if the function returns the correct matching nodes + assert matching == {'a': 'c', 'b': 'c'} + From 2ab3e37db6c11effa60d7cafd796dab49f140b73 Mon Sep 17 00:00:00 2001 From: Dobson Date: Mon, 22 Jan 2024 09:59:50 +0000 Subject: [PATCH 08/15] Add carve and test carve --- swmmanywhere/geospatial_operations.py | 44 +++++++++++++++- tests/test_geospatial.py | 74 +++++++++++++++++++-------- 2 files changed, 97 insertions(+), 21 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 3ef3c07d..e07e5ff0 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -11,6 +11,7 @@ import pygeos import pyproj import rasterio as rst +from rasterio import features from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator from shapely import geometry as sgeom @@ -270,4 +271,45 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point], matching[key] = labels2[nearest] # Return the matching dictionary - return matching \ No newline at end of file + return matching + +def carve(geoms: list[sgeom.LineString], + depth: float, + raster_fid: str, + new_raster_fid: str): + """Carve a raster along a list of shapely geometries. + + Args: + geoms (list): List of Shapely geometries. + depth (float): Depth to carve. + raster_fid (str): Filepath to input raster. + new_raster_fid (str): Filepath to save the carved raster. + """ + with rst.open(raster_fid) as src: + # read data + data = src.read(1) + data = data.astype(float) + data_mask = data != src.nodata + bool_mask = np.zeros(data.shape, dtype=bool) + for geom in geoms: + # Create a mask for the line + mask = features.geometry_mask([sgeom.mapping(geom)], + out_shape=src.shape, + transform=src.transform, + invert=True) + # modify masked data + bool_mask[mask] = True # Adjust this multiplier as needed + #modify data + data[bool_mask & data_mask] -= depth + # Create a new GeoTIFF with modified values + with rst.open(new_raster_fid, + 'w', + driver='GTiff', + height=src.height, + width=src.width, + count=1, + dtype=data.dtype, + crs=src.crs, + transform=src.transform, + nodata = src.nodata) as dest: + dest.write(data, 1) \ No newline at end of file diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index b63b3720..eb374651 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -4,7 +4,6 @@ @author: Barney """ -# import pytest import os from unittest.mock import MagicMock, patch @@ -79,12 +78,9 @@ def test_get_utm(): crs = go.get_utm_epsg(-1, -51) assert crs == 'EPSG:32730' - -def test_reproject_raster(): - """Test the reproject_raster function.""" - # Create a mock raster file - fid = 'test.tif' - data = np.random.randint(0, 255, (100, 100)).astype('uint8') +def create_raster(fid): + """Define a function to create a mock raster file.""" + data = np.ones((100, 100)) transform = rst.transform.from_origin(0, 0, 0.1, 0.1) with rst.open(fid, 'w', @@ -96,24 +92,33 @@ def test_reproject_raster(): crs='EPSG:4326', transform=transform) as src: src.write(data, 1) +def test_reproject_raster(): + """Test the reproject_raster function.""" + # Create a mock raster file + fid = 'test.tif' + try: + create_raster(fid) - # Define the input parameters - target_crs = 'EPSG:32630' - new_fid = 'test_reprojected.tif' + # Define the input parameters + target_crs = 'EPSG:32630' + new_fid = 'test_reprojected.tif' - # Call the function - go.reproject_raster(target_crs, fid) + # Call the function + go.reproject_raster(target_crs, fid) - # Check if the reprojected file exists - assert os.path.exists(new_fid) + # Check if the reprojected file exists + assert os.path.exists(new_fid) - # Check if the reprojected file has the correct CRS - with rst.open(new_fid) as src: - assert src.crs.to_string() == target_crs + # Check if the reprojected file has the correct CRS + with rst.open(new_fid) as src: + assert src.crs.to_string() == target_crs + finally: + # Regardless of test outcome, delete the temp file + if os.path.exists(fid): + os.remove(fid) + if os.path.exists(new_fid): + os.remove(new_fid) - # Clean up the created files - os.remove(fid) - os.remove(new_fid) def almost_equal(a, b, tol=1e-6): """Check if two numbers are almost equal.""" @@ -197,3 +202,32 @@ def test_nearest_node_buffer(): # Check if the function returns the correct matching nodes assert matching == {'a': 'c', 'b': 'c'} +def test_carve_line(): + """Test the carve_line function.""" + # Create a mock geometry + geoms = [sgeom.LineString([(0, 0), (1, 1)]), + sgeom.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])] + + # Define the input parameters + depth = 1.0 + raster_fid = 'input.tif' + new_raster_fid = 'output.tif' + try: + create_raster(raster_fid) + + # Call the function + go.carve(geoms, depth, raster_fid, new_raster_fid) + + with rst.open(raster_fid) as src: + data_ = src.read(1) + + # Open the new GeoTIFF file and check if it has been correctly modified + with rst.open(new_raster_fid) as src: + data = src.read(1) + assert (data != data_).any() + finally: + # Regardless of test outcome, delete the temp file + if os.path.exists(raster_fid): + os.remove(raster_fid) + if os.path.exists(new_raster_fid): + os.remove(new_raster_fid) \ No newline at end of file From 8140fd23de34af63a0bc896dbb563d2cae00cb0f Mon Sep 17 00:00:00 2001 From: Dobson Date: Tue, 23 Jan 2024 16:25:45 +0000 Subject: [PATCH 09/15] Update geospatial_analysis based on review --- dev-requirements.txt | 13 +++-- pyproject.toml | 2 +- requirements.txt | 13 +++-- swmmanywhere/geospatial_operations.py | 72 +++++++++++---------------- tests/test_geospatial.py | 26 ++-------- 5 files changed, 51 insertions(+), 75 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 7556ccb0..218d771f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -143,9 +143,9 @@ numpy==1.26.3 # osmnx # pandas # pyarrow - # pygeos # pysheds # rasterio + # rioxarray # salib # scikit-image # scipy @@ -163,6 +163,7 @@ packaging==23.2 # geopandas # matplotlib # pytest + # rioxarray # scikit-image # xarray pandas==2.1.4 @@ -189,8 +190,6 @@ pre-commit==3.6.0 # via swmmanywhere (pyproject.toml) pyarrow==14.0.2 # via swmmanywhere (pyproject.toml) -pygeos==0.14 - # via swmmanywhere (pyproject.toml) pyparsing==3.1.1 # via # matplotlib @@ -199,6 +198,7 @@ pyproj==3.6.1 # via # geopandas # pysheds + # rioxarray pyproject-hooks==1.0.0 # via build pysheds==0.3.5 @@ -228,11 +228,14 @@ pyyaml==6.0.1 rasterio==1.3.9 # via # pysheds + # rioxarray # swmmanywhere (pyproject.toml) requests==2.31.0 # via # cdsapi # osmnx +rioxarray==0.15.1 + # via swmmanywhere (pyproject.toml) ruff==0.1.11 # via swmmanywhere (pyproject.toml) salib==1.4.7 @@ -278,7 +281,9 @@ virtualenv==20.24.5 wheel==0.41.3 # via pip-tools xarray==2023.12.0 - # via swmmanywhere (pyproject.toml) + # via + # rioxarray + # swmmanywhere (pyproject.toml) # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/pyproject.toml b/pyproject.toml index 19ce1bdc..e317ac9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,10 +27,10 @@ dependencies = [ # TODO definitely don't need all of these "osmnx", "pandas", "pyarrow", - "pygeos", "pysheds", "PyYAML", "rasterio", + "rioxarray", "SALib", "SciPy", "shapely", diff --git a/requirements.txt b/requirements.txt index 8d88cf6d..120ed9d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -113,9 +113,9 @@ numpy==1.26.3 # osmnx # pandas # pyarrow - # pygeos # pysheds # rasterio + # rioxarray # salib # scikit-image # scipy @@ -131,6 +131,7 @@ packaging==23.2 # fastparquet # geopandas # matplotlib + # rioxarray # scikit-image # xarray pandas==2.1.4 @@ -149,8 +150,6 @@ pillow==10.2.0 # scikit-image pyarrow==14.0.2 # via swmmanywhere (pyproject.toml) -pygeos==0.14 - # via swmmanywhere (pyproject.toml) pyparsing==3.1.1 # via # matplotlib @@ -159,6 +158,7 @@ pyproj==3.6.1 # via # geopandas # pysheds + # rioxarray pysheds==0.3.5 # via swmmanywhere (pyproject.toml) python-dateutil==2.8.2 @@ -172,11 +172,14 @@ pyyaml==6.0.1 rasterio==1.3.9 # via # pysheds + # rioxarray # swmmanywhere (pyproject.toml) requests==2.31.0 # via # cdsapi # osmnx +rioxarray==0.15.1 + # via swmmanywhere (pyproject.toml) salib==1.4.7 # via swmmanywhere (pyproject.toml) scikit-image==0.22.0 @@ -214,7 +217,9 @@ tzdata==2023.4 urllib3==2.1.0 # via requests xarray==2023.12.0 - # via swmmanywhere (pyproject.toml) + # via + # rioxarray + # swmmanywhere (pyproject.toml) # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index e07e5ff0..7001f048 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -5,16 +5,17 @@ """ from typing import Optional +import geopandas as gpd import networkx as nx import numpy as np import pandas as pd -import pygeos import pyproj import rasterio as rst +import rioxarray from rasterio import features -from rasterio.warp import Resampling, calculate_default_transform, reproject from scipy.interpolate import RegularGridInterpolator from shapely import geometry as sgeom +from shapely.strtree import STRtree def get_utm_epsg(lon: float, lat: float) -> str: @@ -120,33 +121,18 @@ def reproject_raster(target_crs: str, new_fid (str, optional): Filepath to save the reprojected raster. Defaults to None, which will just use fid with '_reprojected'. """ - with rst.open(fid) as src: - # Define the transformation parameters for reprojection - transform, width, height = calculate_default_transform( - src.crs, target_crs, src.width, src.height, *src.bounds) - - # Create the output raster file - kwargs = src.meta.copy() - kwargs.update({ - 'crs': target_crs, - 'transform': transform, - 'width': width, - 'height': height - }) + # Open the raster + with rioxarray.open_rasterio(fid) as raster: + + # Reproject the raster + reprojected = raster.rio.reproject(target_crs) + + # Define the output filepath if new_fid is None: new_fid = fid.replace('.tif','_reprojected.tif') - with rst.open(new_fid, 'w', **kwargs) as dst: - # Reproject the data - reproject( - source=rst.band(src, 1), - destination=rst.band(dst, 1), - src_transform=src.transform, - src_crs=src.crs, - dst_transform=transform, - dst_crs=target_crs, - resampling=Resampling.bilinear - ) + # Save the reprojected raster + reprojected.rio.to_raster(new_fid) def get_transformer(source_crs: str, target_crs: str) -> pyproj.Transformer: @@ -179,16 +165,12 @@ def reproject_df(df: pd.DataFrame, target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630). """ # Function to transform coordinates + pts = gpd.points_from_xy(df["longitude"], + df["latitude"], + crs=source_crs).to_crs(target_crs) df = df.copy() - transformer = get_transformer(source_crs, target_crs) - - # Reproject the coordinates in the DataFrame - def f(row): - return transformer.transform(row['longitude'], - row['latitude']) - - df['x'], df['y'] = zip(*df.apply(f,axis=1)) - + df['x'] = pts.x + df['y'] = pts.y return df def reproject_graph(G: nx.Graph, @@ -196,6 +178,10 @@ def reproject_graph(G: nx.Graph, target_crs: str) -> nx.Graph: """Reproject the coordinates in a graph. + osmnx.projection.project_graph might be suitable if some other behaviour + needs to be captured, but it currently fails the tests so I will ignore for + now. + Args: G (nx.Graph): Graph with nodes containing 'x' and 'y' properties. source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326). @@ -228,8 +214,10 @@ def reproject_graph(G: nx.Graph, [G_new.nodes[v]['x'], G_new.nodes[v]['y']]]) data['geometry'] = new_geometry + return G_new + def nearest_node_buffer(points1: dict[str, sgeom.Point], points2: dict[str, sgeom.Point], threshold: float) -> dict: @@ -251,10 +239,8 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point], # Convert the keys of points2 to a list labels2 = list(points2.keys()) - # Convert the values of points2 to PyGEOS geometries - # and create a spatial index - pygeos_nodes = pygeos.from_shapely(list(points2.values())) - tree = pygeos.STRtree(pygeos_nodes) + # Create a spatial index + tree = STRtree(list(points2.values())) # Initialize an empty dictionary to store the matching nodes matching = {} @@ -262,22 +248,22 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point], # Iterate over points1 for key, geom in points1.items(): # Find the nearest node in the spatial index to the current geometry - nearest = tree.nearest(pygeos.from_shapely(geom))[1][0] + nearest = tree.nearest(geom) nearest_geom = points2[labels2[nearest]] # If the nearest node is within the threshold, add it to the # matching dictionary - if geom.buffer(threshold).intersection(nearest_geom): + if geom.buffer(threshold).intersects(nearest_geom): matching[key] = labels2[nearest] # Return the matching dictionary return matching -def carve(geoms: list[sgeom.LineString], +def burn_shape_in_raster(geoms: list[sgeom.LineString], depth: float, raster_fid: str, new_raster_fid: str): - """Carve a raster along a list of shapely geometries. + """Burn a depth into a raster along a list of shapely geometries. Args: geoms (list): List of Shapely geometries. diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index eb374651..cf320c62 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -9,7 +9,6 @@ import networkx as nx import numpy as np -import pandas as pd import rasterio as rst from scipy.interpolate import RegularGridInterpolator from shapely import geometry as sgeom @@ -137,25 +136,6 @@ def test_get_transformer(): assert almost_equal(new_point[1], expected_point[1]) -def test_reproject_df(): - """Test the reproject_df function.""" - # Create a mock DataFrame - df = pd.DataFrame({ - 'longitude': [-0.1276], - 'latitude': [51.5074] - }) - - # Define the input parameters - source_crs = 'EPSG:4326' - target_crs = 'EPSG:32630' - - # Call the function - transformed_df = go.reproject_df(df, source_crs, target_crs) - - # Check the output - assert almost_equal(transformed_df['x'].values[0], 699330.1106898375) - assert almost_equal(transformed_df['y'].values[0], 5710164.30300683) - def test_reproject_graph(): """Test the reproject_graph function.""" # Create a mock graph @@ -202,8 +182,8 @@ def test_nearest_node_buffer(): # Check if the function returns the correct matching nodes assert matching == {'a': 'c', 'b': 'c'} -def test_carve_line(): - """Test the carve_line function.""" +def test_burn_shape_in_raster(): + """Test the burn_shape_in_raster function.""" # Create a mock geometry geoms = [sgeom.LineString([(0, 0), (1, 1)]), sgeom.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])] @@ -216,7 +196,7 @@ def test_carve_line(): create_raster(raster_fid) # Call the function - go.carve(geoms, depth, raster_fid, new_raster_fid) + go.burn_shape_in_raster(geoms, depth, raster_fid, new_raster_fid) with rst.open(raster_fid) as src: data_ = src.read(1) From 00c740b4417e151b91774a3899404cb738de72b7 Mon Sep 17 00:00:00 2001 From: Dobson Date: Tue, 23 Jan 2024 16:49:59 +0000 Subject: [PATCH 10/15] Update to use more general get_utm --- swmmanywhere/geospatial_operations.py | 52 +++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 7001f048..ec8c1dc9 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -3,6 +3,7 @@ @author: Barnaby Dobson """ +from functools import lru_cache from typing import Optional import geopandas as gpd @@ -17,6 +18,57 @@ from shapely import geometry as sgeom from shapely.strtree import STRtree +TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) + + +def get_utm_crs(x: float, + y: float, + crs: str | int | pyproj.CRS = 'EPSG:4326', + datum_name: str = "WGS 84"): + """Get the UTM CRS code for a given coordinate. + + Note, this function is taken from GeoPandas and modified to use + for getting the UTM CRS code for a given coordinate. + + Args: + x (float): Longitude in crs + y (float): Latitude in crs + crs (str | int | pyproj.CRS, optional): The CRS of the input + coordinates. Defaults to 'EPSG:4326'. + datum_name (str, optional): The datum name to use for the UTM CRS + + Returns: + str: Formatted EPSG code for the UTM zone. + + Example: + >>> get_utm_epsg(-0.1276, 51.5074) + 'EPSG:32630' + """ + if not isinstance(x, float) or not isinstance(y, float): + raise TypeError("x and y must be floats") + + try: + crs = pyproj.CRS(crs) + except pyproj.exceptions.CRSError: + raise ValueError("Invalid CRS") + + # ensure using geographic coordinates + if pyproj.CRS(crs).is_geographic: + lon = x + lat = y + else: + transformer = TransformerFromCRS(crs, "EPSG:4326", always_xy=True) + lon, lat = transformer.transform(x, y) + utm_crs_list = pyproj.database.query_utm_crs_info( + datum_name=datum_name, + area_of_interest=pyproj.aoi.AreaOfInterest( + west_lon_degree=lon, + south_lat_degree=lat, + east_lon_degree=lon, + north_lat_degree=lat, + ), + ) + return f"{utm_crs_list[0].auth_name}:{utm_crs_list[0].code}" def get_utm_epsg(lon: float, lat: float) -> str: """Get the formatted UTM EPSG code for a given longitude and latitude. From 4798c019c6645badb1a8bb4414dfb6835ac1c253 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Thu, 25 Jan 2024 10:00:00 +0000 Subject: [PATCH 11/15] Update swmmanywhere/geospatial_operations.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Diego Alonso Álvarez <6095790+dalonsoa@users.noreply.github.com> --- swmmanywhere/geospatial_operations.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index ec8c1dc9..4e588b7b 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -70,25 +70,6 @@ def get_utm_crs(x: float, ) return f"{utm_crs_list[0].auth_name}:{utm_crs_list[0].code}" -def get_utm_epsg(lon: float, lat: float) -> str: - """Get the formatted UTM EPSG code for a given longitude and latitude. - - Args: - lon (float): Longitude in EPSG:4326 (x) - lat (float): Latitude in EPSG:4326 (y) - - Returns: - str: Formatted EPSG code for the UTM zone. - - Example: - >>> get_utm_epsg(-0.1276, 51.5074) - 'EPSG:32630' - """ - # Determine the UTM zone number - zone_number = int((lon + 180) / 6) + 1 - # Determine the UTM EPSG code based on the hemisphere - utm_epsg = 32600 + zone_number if lat >= 0 else 32700 + zone_number - return 'EPSG:{0}'.format(utm_epsg) def interp_wrap(xy: tuple[float,float], interp: RegularGridInterpolator, From 561c1650442ffde9bdadb0c374c251460bca43a0 Mon Sep 17 00:00:00 2001 From: barneydobson Date: Thu, 25 Jan 2024 10:00:28 +0000 Subject: [PATCH 12/15] Update swmmanywhere/geospatial_operations.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Diego Alonso Álvarez <6095790+dalonsoa@users.noreply.github.com> --- swmmanywhere/geospatial_operations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 4e588b7b..010d822c 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -238,15 +238,13 @@ def reproject_graph(G: nx.Graph, # Convert and add edges with 'geometry' property for u, v, data in G_new.edges(data=True): if 'geometry' in data.keys(): - geometry = data['geometry'] - new_geometry = sgeom.LineString(transformer.transform(x, y) - for x, y in geometry.coords) + data['geometry'] = sgeom.LineString(transformer.transform(x, y) + for x, y in data['geometry'].coords) else: - new_geometry = sgeom.LineString([[G_new.nodes[u]['x'], + data['geometry'] = sgeom.LineString([[G_new.nodes[u]['x'], G_new.nodes[u]['y']], [G_new.nodes[v]['x'], G_new.nodes[v]['y']]]) - data['geometry'] = new_geometry return G_new From ce195867e35abf1efce2e1d0199aafe4bdc787ea Mon Sep 17 00:00:00 2001 From: barneydobson Date: Thu, 25 Jan 2024 10:00:57 +0000 Subject: [PATCH 13/15] Update swmmanywhere/geospatial_operations.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Diego Alonso Álvarez <6095790+dalonsoa@users.noreply.github.com> --- swmmanywhere/geospatial_operations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 010d822c..8576e41b 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -71,7 +71,7 @@ def get_utm_crs(x: float, return f"{utm_crs_list[0].auth_name}:{utm_crs_list[0].code}" -def interp_wrap(xy: tuple[float,float], +def interp_with_nans(xy: tuple[float,float], interp: RegularGridInterpolator, grid: np.ndarray, values: list[float]) -> float: From 264dce3ad36566cde036fc0c3836a665221b1ee3 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 25 Jan 2024 10:07:39 +0000 Subject: [PATCH 14/15] "Fix tests to accommodate Diego's suggestions" --- swmmanywhere/geospatial_operations.py | 7 +++++-- tests/test_geospatial.py | 12 ++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/swmmanywhere/geospatial_operations.py b/swmmanywhere/geospatial_operations.py index 8576e41b..6749f4c7 100644 --- a/swmmanywhere/geospatial_operations.py +++ b/swmmanywhere/geospatial_operations.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- """Created 2024-01-20. +A module containing functions to perform a variety of geospatial operations, +such as reprojecting coordinates and handling raster data. + @author: Barnaby Dobson """ from functools import lru_cache @@ -21,7 +24,7 @@ TransformerFromCRS = lru_cache(pyproj.transformer.Transformer.from_crs) -def get_utm_crs(x: float, +def get_utm_epsg(x: float, y: float, crs: str | int | pyproj.CRS = 'EPSG:4326', datum_name: str = "WGS 84"): @@ -141,7 +144,7 @@ def interpolate_points_on_raster(x: list[float], bounds_error=False, fill_value=None) # Interpolate for x,y - return [interp_wrap((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] + return [interp_with_nans((y_, x_), interp, grid, values) for x_, y_ in zip(x,y)] def reproject_raster(target_crs: str, fid: str, diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index cf320c62..fd70f724 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -16,8 +16,8 @@ from swmmanywhere import geospatial_operations as go -def test_interp_wrap(): - """Test the interp_wrap function.""" +def test_interp_with_nans(): + """Test the interp_interp_with_nans function.""" # Define a simple grid and values x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) @@ -32,13 +32,13 @@ def test_interp_wrap(): # Test the function at a point inside the grid yx = (0.875, 0.875) - result = go.interp_wrap(yx, interp, grid, values) + result = go.interp_with_nans(yx, interp, grid, values) assert result == 0.875 # Test the function on a nan point values_grid[1][1] = np.nan yx = (0.251, 0.25) - result = go.interp_wrap(yx, interp, grid, values) + result = go.interp_with_nans(yx, interp, grid, values) assert result == values_grid[1][2] @patch('rasterio.open') @@ -70,11 +70,11 @@ def test_interpolate_points_on_raster(mock_rst_open): def test_get_utm(): """Test the get_utm_epsg function.""" # Test a northern hemisphere point - crs = go.get_utm_epsg(-1, 51) + crs = go.get_utm_epsg(-1.0, 51.0) assert crs == 'EPSG:32630' # Test a southern hemisphere point - crs = go.get_utm_epsg(-1, -51) + crs = go.get_utm_epsg(-1.0, -51.0) assert crs == 'EPSG:32730' def create_raster(fid): From 557b60dc283856a8d41f203a7a83f02d72a52b66 Mon Sep 17 00:00:00 2001 From: Dobson Date: Thu, 25 Jan 2024 10:14:31 +0000 Subject: [PATCH 15/15] Disable link checking in readme.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0bc89d95..1e707092 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # SWMManywhere - + [![Test and build](https://github.com/ImperialCollegeLondon/SWMManywhere/actions/workflows/ci.yml/badge.svg)](https://github.com/ImperialCollegeLondon/SWMManywhere/actions/workflows/ci.yml) + ## High level workflow overview