Skip to content

Commit

Permalink
Add carve and test carve
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Jan 22, 2024
1 parent d8432ee commit 2ab3e37
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 21 deletions.
44 changes: 43 additions & 1 deletion swmmanywhere/geospatial_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pygeos
import pyproj
import rasterio as rst
from rasterio import features
from rasterio.warp import Resampling, calculate_default_transform, reproject
from scipy.interpolate import RegularGridInterpolator
from shapely import geometry as sgeom
Expand Down Expand Up @@ -270,4 +271,45 @@ def nearest_node_buffer(points1: dict[str, sgeom.Point],
matching[key] = labels2[nearest]

# Return the matching dictionary
return matching
return matching

def carve(geoms: list[sgeom.LineString],
depth: float,
raster_fid: str,
new_raster_fid: str):
"""Carve a raster along a list of shapely geometries.
Args:
geoms (list): List of Shapely geometries.
depth (float): Depth to carve.
raster_fid (str): Filepath to input raster.
new_raster_fid (str): Filepath to save the carved raster.
"""
with rst.open(raster_fid) as src:
# read data
data = src.read(1)
data = data.astype(float)
data_mask = data != src.nodata
bool_mask = np.zeros(data.shape, dtype=bool)
for geom in geoms:
# Create a mask for the line
mask = features.geometry_mask([sgeom.mapping(geom)],
out_shape=src.shape,
transform=src.transform,
invert=True)
# modify masked data
bool_mask[mask] = True # Adjust this multiplier as needed
#modify data
data[bool_mask & data_mask] -= depth
# Create a new GeoTIFF with modified values
with rst.open(new_raster_fid,
'w',
driver='GTiff',
height=src.height,
width=src.width,
count=1,
dtype=data.dtype,
crs=src.crs,
transform=src.transform,
nodata = src.nodata) as dest:
dest.write(data, 1)
74 changes: 54 additions & 20 deletions tests/test_geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
@author: Barney
"""

# import pytest
import os
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -79,12 +78,9 @@ def test_get_utm():
crs = go.get_utm_epsg(-1, -51)
assert crs == 'EPSG:32730'


def test_reproject_raster():
"""Test the reproject_raster function."""
# Create a mock raster file
fid = 'test.tif'
data = np.random.randint(0, 255, (100, 100)).astype('uint8')
def create_raster(fid):
"""Define a function to create a mock raster file."""
data = np.ones((100, 100))
transform = rst.transform.from_origin(0, 0, 0.1, 0.1)
with rst.open(fid,
'w',
Expand All @@ -96,24 +92,33 @@ def test_reproject_raster():
crs='EPSG:4326',
transform=transform) as src:
src.write(data, 1)
def test_reproject_raster():
"""Test the reproject_raster function."""
# Create a mock raster file
fid = 'test.tif'
try:
create_raster(fid)

# Define the input parameters
target_crs = 'EPSG:32630'
new_fid = 'test_reprojected.tif'
# Define the input parameters
target_crs = 'EPSG:32630'
new_fid = 'test_reprojected.tif'

# Call the function
go.reproject_raster(target_crs, fid)
# Call the function
go.reproject_raster(target_crs, fid)

# Check if the reprojected file exists
assert os.path.exists(new_fid)
# Check if the reprojected file exists
assert os.path.exists(new_fid)

# Check if the reprojected file has the correct CRS
with rst.open(new_fid) as src:
assert src.crs.to_string() == target_crs
# Check if the reprojected file has the correct CRS
with rst.open(new_fid) as src:
assert src.crs.to_string() == target_crs
finally:
# Regardless of test outcome, delete the temp file
if os.path.exists(fid):
os.remove(fid)
if os.path.exists(new_fid):
os.remove(new_fid)

# Clean up the created files
os.remove(fid)
os.remove(new_fid)

def almost_equal(a, b, tol=1e-6):
"""Check if two numbers are almost equal."""
Expand Down Expand Up @@ -197,3 +202,32 @@ def test_nearest_node_buffer():
# Check if the function returns the correct matching nodes
assert matching == {'a': 'c', 'b': 'c'}

def test_carve_line():
"""Test the carve_line function."""
# Create a mock geometry
geoms = [sgeom.LineString([(0, 0), (1, 1)]),
sgeom.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]

# Define the input parameters
depth = 1.0
raster_fid = 'input.tif'
new_raster_fid = 'output.tif'
try:
create_raster(raster_fid)

# Call the function
go.carve(geoms, depth, raster_fid, new_raster_fid)

with rst.open(raster_fid) as src:
data_ = src.read(1)

# Open the new GeoTIFF file and check if it has been correctly modified
with rst.open(new_raster_fid) as src:
data = src.read(1)
assert (data != data_).any()
finally:
# Regardless of test outcome, delete the temp file
if os.path.exists(raster_fid):
os.remove(raster_fid)
if os.path.exists(new_raster_fid):
os.remove(new_raster_fid)

0 comments on commit 2ab3e37

Please sign in to comment.