diff --git a/swmmanywhere/metric_utilities.py b/swmmanywhere/metric_utilities.py index aef68633..11d1e2ad 100644 --- a/swmmanywhere/metric_utilities.py +++ b/swmmanywhere/metric_utilities.py @@ -5,6 +5,7 @@ """ from collections import defaultdict from inspect import signature +from itertools import product from typing import Callable, Optional import cytoolz.curried as tlz @@ -363,19 +364,31 @@ def align_by_shape(var, return results def create_grid(bbox: tuple, - scale: float): - """Create a grid of polygons.""" - dx = scale - dy = scale + scale: float | tuple[float,float]) -> gpd.GeoDataFrame: + """Create a grid of polygons. + + Create a grid of polygons based on the bounding box and scale. + + Args: + bbox (tuple): The bounding box coordinates in the format (minx, miny, + maxx, maxy). + scale (float | tuple): The scale of the grid. If a tuple, the scale is + (dx, dy). Otherwise, the scale is dx = dy = scale. + + Returns: + gpd.GeoDataFrame: A geodataframe of the grid. + """ minx, miny, maxx, maxy = bbox - grid = [{'geometry': shapely.geometry.Polygon([(minx + i * dx, miny + j * dy), - (minx + (i + 1) * dx, miny + j * dy), - (minx + (i + 1) * dx, miny + (j + 1) * dy), - (minx + i * dx, miny + (j + 1) * dy)]), - 'sub_id': f'{i}_{j}'} - for i in range(int((maxx - minx) // dx + 1)) - for j in range(int((maxy - miny) // dy + 1))] + if isinstance(scale, tuple): + dx, dy = scale + else: + dx = dy = scale + xmins = np.arange(minx, maxx, dx) + ymins = np.arange(minx, maxy, dy) + grid = [{'geometry' : shapely.box(x, y, x + dx, y + dy), + 'sub_id' : i} for i, (x, y) in enumerate(product(xmins, ymins))] + return gpd.GeoDataFrame(grid) @metrics.register