Skip to content

Commit

Permalink
Update geospatial_analysis based on review
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Jan 23, 2024
1 parent 2ab3e37 commit 8140fd2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 75 deletions.
13 changes: 9 additions & 4 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ numpy==1.26.3
# osmnx
# pandas
# pyarrow
# pygeos
# pysheds
# rasterio
# rioxarray
# salib
# scikit-image
# scipy
Expand All @@ -163,6 +163,7 @@ packaging==23.2
# geopandas
# matplotlib
# pytest
# rioxarray
# scikit-image
# xarray
pandas==2.1.4
Expand All @@ -189,8 +190,6 @@ pre-commit==3.6.0
# via swmmanywhere (pyproject.toml)
pyarrow==14.0.2
# via swmmanywhere (pyproject.toml)
pygeos==0.14
# via swmmanywhere (pyproject.toml)
pyparsing==3.1.1
# via
# matplotlib
Expand All @@ -199,6 +198,7 @@ pyproj==3.6.1
# via
# geopandas
# pysheds
# rioxarray
pyproject-hooks==1.0.0
# via build
pysheds==0.3.5
Expand Down Expand Up @@ -228,11 +228,14 @@ pyyaml==6.0.1
rasterio==1.3.9
# via
# pysheds
# rioxarray
# swmmanywhere (pyproject.toml)
requests==2.31.0
# via
# cdsapi
# osmnx
rioxarray==0.15.1
# via swmmanywhere (pyproject.toml)
ruff==0.1.11
# via swmmanywhere (pyproject.toml)
salib==1.4.7
Expand Down Expand Up @@ -278,7 +281,9 @@ virtualenv==20.24.5
wheel==0.41.3
# via pip-tools
xarray==2023.12.0
# via swmmanywhere (pyproject.toml)
# via
# rioxarray
# swmmanywhere (pyproject.toml)

# The following packages are considered to be unsafe in a requirements file:
# pip
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ dependencies = [ # TODO definitely don't need all of these
"osmnx",
"pandas",
"pyarrow",
"pygeos",
"pysheds",
"PyYAML",
"rasterio",
"rioxarray",
"SALib",
"SciPy",
"shapely",
Expand Down
13 changes: 9 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ numpy==1.26.3
# osmnx
# pandas
# pyarrow
# pygeos
# pysheds
# rasterio
# rioxarray
# salib
# scikit-image
# scipy
Expand All @@ -131,6 +131,7 @@ packaging==23.2
# fastparquet
# geopandas
# matplotlib
# rioxarray
# scikit-image
# xarray
pandas==2.1.4
Expand All @@ -149,8 +150,6 @@ pillow==10.2.0
# scikit-image
pyarrow==14.0.2
# via swmmanywhere (pyproject.toml)
pygeos==0.14
# via swmmanywhere (pyproject.toml)
pyparsing==3.1.1
# via
# matplotlib
Expand All @@ -159,6 +158,7 @@ pyproj==3.6.1
# via
# geopandas
# pysheds
# rioxarray
pysheds==0.3.5
# via swmmanywhere (pyproject.toml)
python-dateutil==2.8.2
Expand All @@ -172,11 +172,14 @@ pyyaml==6.0.1
rasterio==1.3.9
# via
# pysheds
# rioxarray
# swmmanywhere (pyproject.toml)
requests==2.31.0
# via
# cdsapi
# osmnx
rioxarray==0.15.1
# via swmmanywhere (pyproject.toml)
salib==1.4.7
# via swmmanywhere (pyproject.toml)
scikit-image==0.22.0
Expand Down Expand Up @@ -214,7 +217,9 @@ tzdata==2023.4
urllib3==2.1.0
# via requests
xarray==2023.12.0
# via swmmanywhere (pyproject.toml)
# via
# rioxarray
# swmmanywhere (pyproject.toml)

# The following packages are considered to be unsafe in a requirements file:
# setuptools
72 changes: 29 additions & 43 deletions swmmanywhere/geospatial_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
"""
from typing import Optional

import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd
import pygeos
import pyproj
import rasterio as rst
import rioxarray
from rasterio import features
from rasterio.warp import Resampling, calculate_default_transform, reproject
from scipy.interpolate import RegularGridInterpolator
from shapely import geometry as sgeom
from shapely.strtree import STRtree


def get_utm_epsg(lon: float, lat: float) -> str:
Expand Down Expand Up @@ -120,33 +121,18 @@ def reproject_raster(target_crs: str,
new_fid (str, optional): Filepath to save the reprojected raster.
Defaults to None, which will just use fid with '_reprojected'.
"""
with rst.open(fid) as src:
# Define the transformation parameters for reprojection
transform, width, height = calculate_default_transform(
src.crs, target_crs, src.width, src.height, *src.bounds)

# Create the output raster file
kwargs = src.meta.copy()
kwargs.update({
'crs': target_crs,
'transform': transform,
'width': width,
'height': height
})
# Open the raster
with rioxarray.open_rasterio(fid) as raster:

# Reproject the raster
reprojected = raster.rio.reproject(target_crs)

# Define the output filepath
if new_fid is None:
new_fid = fid.replace('.tif','_reprojected.tif')

with rst.open(new_fid, 'w', **kwargs) as dst:
# Reproject the data
reproject(
source=rst.band(src, 1),
destination=rst.band(dst, 1),
src_transform=src.transform,
src_crs=src.crs,
dst_transform=transform,
dst_crs=target_crs,
resampling=Resampling.bilinear
)
# Save the reprojected raster
reprojected.rio.to_raster(new_fid)

def get_transformer(source_crs: str,
target_crs: str) -> pyproj.Transformer:
Expand Down Expand Up @@ -179,23 +165,23 @@ def reproject_df(df: pd.DataFrame,
target_crs (str): Target CRS in EPSG format (e.g., EPSG:32630).
"""
# Function to transform coordinates
pts = gpd.points_from_xy(df["longitude"],
df["latitude"],
crs=source_crs).to_crs(target_crs)
df = df.copy()
transformer = get_transformer(source_crs, target_crs)

# Reproject the coordinates in the DataFrame
def f(row):
return transformer.transform(row['longitude'],
row['latitude'])

df['x'], df['y'] = zip(*df.apply(f,axis=1))

df['x'] = pts.x
df['y'] = pts.y
return df

def reproject_graph(G: nx.Graph,
source_crs: str,
target_crs: str) -> nx.Graph:
"""Reproject the coordinates in a graph.
osmnx.projection.project_graph might be suitable if some other behaviour
needs to be captured, but it currently fails the tests so I will ignore for
now.
Args:
G (nx.Graph): Graph with nodes containing 'x' and 'y' properties.
source_crs (str): Source CRS in EPSG format (e.g., EPSG:4326).
Expand Down Expand Up @@ -228,8 +214,10 @@ def reproject_graph(G: nx.Graph,
[G_new.nodes[v]['x'],
G_new.nodes[v]['y']]])
data['geometry'] = new_geometry

return G_new


def nearest_node_buffer(points1: dict[str, sgeom.Point],
points2: dict[str, sgeom.Point],
threshold: float) -> dict:
Expand All @@ -251,33 +239,31 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point],
# Convert the keys of points2 to a list
labels2 = list(points2.keys())

# Convert the values of points2 to PyGEOS geometries
# and create a spatial index
pygeos_nodes = pygeos.from_shapely(list(points2.values()))
tree = pygeos.STRtree(pygeos_nodes)
# Create a spatial index
tree = STRtree(list(points2.values()))

# Initialize an empty dictionary to store the matching nodes
matching = {}

# Iterate over points1
for key, geom in points1.items():
# Find the nearest node in the spatial index to the current geometry
nearest = tree.nearest(pygeos.from_shapely(geom))[1][0]
nearest = tree.nearest(geom)
nearest_geom = points2[labels2[nearest]]

# If the nearest node is within the threshold, add it to the
# matching dictionary
if geom.buffer(threshold).intersection(nearest_geom):
if geom.buffer(threshold).intersects(nearest_geom):
matching[key] = labels2[nearest]

# Return the matching dictionary
return matching

def carve(geoms: list[sgeom.LineString],
def burn_shape_in_raster(geoms: list[sgeom.LineString],
depth: float,
raster_fid: str,
new_raster_fid: str):
"""Carve a raster along a list of shapely geometries.
"""Burn a depth into a raster along a list of shapely geometries.
Args:
geoms (list): List of Shapely geometries.
Expand Down
26 changes: 3 additions & 23 deletions tests/test_geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import networkx as nx
import numpy as np
import pandas as pd
import rasterio as rst
from scipy.interpolate import RegularGridInterpolator
from shapely import geometry as sgeom
Expand Down Expand Up @@ -137,25 +136,6 @@ def test_get_transformer():
assert almost_equal(new_point[1],
expected_point[1])

def test_reproject_df():
"""Test the reproject_df function."""
# Create a mock DataFrame
df = pd.DataFrame({
'longitude': [-0.1276],
'latitude': [51.5074]
})

# Define the input parameters
source_crs = 'EPSG:4326'
target_crs = 'EPSG:32630'

# Call the function
transformed_df = go.reproject_df(df, source_crs, target_crs)

# Check the output
assert almost_equal(transformed_df['x'].values[0], 699330.1106898375)
assert almost_equal(transformed_df['y'].values[0], 5710164.30300683)

def test_reproject_graph():
"""Test the reproject_graph function."""
# Create a mock graph
Expand Down Expand Up @@ -202,8 +182,8 @@ def test_nearest_node_buffer():
# Check if the function returns the correct matching nodes
assert matching == {'a': 'c', 'b': 'c'}

def test_carve_line():
"""Test the carve_line function."""
def test_burn_shape_in_raster():
"""Test the burn_shape_in_raster function."""
# Create a mock geometry
geoms = [sgeom.LineString([(0, 0), (1, 1)]),
sgeom.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]
Expand All @@ -216,7 +196,7 @@ def test_carve_line():
create_raster(raster_fid)

# Call the function
go.carve(geoms, depth, raster_fid, new_raster_fid)
go.burn_shape_in_raster(geoms, depth, raster_fid, new_raster_fid)

with rst.open(raster_fid) as src:
data_ = src.read(1)
Expand Down

0 comments on commit 8140fd2

Please sign in to comment.