Skip to content

Commit

Permalink
chore: update geospatial with main manually to deal with conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
dbhart committed Nov 18, 2024
1 parent a054ad6 commit 6dced54
Showing 1 changed file with 61 additions and 7 deletions.
68 changes: 61 additions & 7 deletions wntr/gis/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import numpy as np


try:
from shapely.geometry import MultiPoint, LineString, Point, shape
has_shapely = True
Expand All @@ -18,6 +19,13 @@
except ModuleNotFoundError:
gpd = None
has_geopandas = False

try:
import rasterio
has_rasterio = True
except ModuleNotFoundError:
rasterio = None
has_rasterio = False


def snap(A, B, tolerance):
Expand Down Expand Up @@ -57,9 +65,9 @@ def snap(A, B, tolerance):
if not has_shapely or not has_geopandas:
raise ModuleNotFoundError('shapley and geopandas are required')

isinstance(A, gpd.GeoDataFrame)
assert isinstance(A, gpd.GeoDataFrame)
assert(A['geometry'].geom_type).isin(['Point']).all()
isinstance(B, gpd.GeoDataFrame)
assert isinstance(B, gpd.GeoDataFrame)
assert (B['geometry'].geom_type).isin(['Point', 'LineString', 'MultiLineString']).all()
assert A.crs == B.crs

Expand Down Expand Up @@ -110,7 +118,6 @@ def snap(A, B, tolerance):
B.index.name = None

# snap to points
snapped_points = None
if B['geometry'].geom_type.isin(['Point']).all():
snapped_points = closest.rename(columns={"indexB":"node"})
snapped_points = snapped_points[["node", "snap_distance", "geometry"]]
Expand Down Expand Up @@ -198,18 +205,18 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):
if not has_shapely or not has_geopandas:
raise ModuleNotFoundError('shapley and geopandas are required')

isinstance(A, gpd.GeoDataFrame)
assert isinstance(A, gpd.GeoDataFrame)
assert (A['geometry'].geom_type).isin(['Point', 'LineString',
'MultiLineString', 'Polygon',
'MultiPolygon']).all()
isinstance(B, gpd.GeoDataFrame)
assert isinstance(B, gpd.GeoDataFrame)
assert (B['geometry'].geom_type).isin(['Point', 'LineString',
'MultiLineString', 'Polygon',
'MultiPolygon']).all()
if isinstance(B_value, str):
assert B_value in B.columns
isinstance(include_background, bool)
isinstance(background_value, (int, float))
assert isinstance(include_background, bool)
assert isinstance(background_value, (int, float))
assert A.crs == B.crs, "A and B must have the same crs."

if include_background:
Expand Down Expand Up @@ -283,6 +290,53 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):

return stats


def sample_raster(A, filepath, bands=1):
"""Sample a raster (e.g., GeoTIFF file) using Points in GeoDataFrame A.
This function can take either a filepath to a raster or a virtual raster
(VRT), which combines multiple raster tiles into a single object. The
function opens the raster(s) and samples it at the coordinates of the point
geometries in A. This function assigns nan to values that match the
raster's `nodata` attribute. These sampled values are returned as a Series
which has an index matching A.
Parameters
----------
A : GeoDataFrame
GeoDataFrame containing Point geometries
filepath : str
Path to raster or alternatively a virtual raster (VRT)
bands : int or list[int] (optional, default = 1)
Index or indices of the bands to sample (using 1-based indexing)
Returns
-------
Series
Pandas Series containing the sampled values for each geometry in gdf
"""
# further functionality could include other geometries (Line, Polygon),
# and use of multiprocessing to speed up querying.
if not has_rasterio:
raise ModuleNotFoundError('rasterio is required')

assert (A['geometry'].geom_type == "Point").all()
assert isinstance(filepath, str)
assert isinstance(bands, (int, list))

with rasterio.open(filepath) as raster:
xys = zip(A.geometry.x, A.geometry.y)

values = np.array(
tuple(raster.sample(xys, bands)), dtype=float # force to float to allow for conversion of nodata to nan
).squeeze()

values[values == raster.nodata] = np.nan
values = pd.Series(values, index=A.index)

return values


def connect_lines(lines, threshold):
"""
Connect lines by identifying start and end nodes that are within a
Expand Down

0 comments on commit 6dced54

Please sign in to comment.