diff --git a/src/swmmanywhere/prepare_data.py b/src/swmmanywhere/prepare_data.py index 31d33887..e98359a0 100644 --- a/src/swmmanywhere/prepare_data.py +++ b/src/swmmanywhere/prepare_data.py @@ -19,6 +19,11 @@ import rioxarray.merge as rxr_merge import xarray as xr from geopy.geocoders import Nominatim +from pyarrow import RecordBatchReader +from pyarrow.compute import field +from pyarrow.dataset import dataset +from pyarrow.fs import S3FileSystem +from pyarrow.parquet import ParquetWriter from swmmanywhere.logging import logger from swmmanywhere.utilities import yaml_load @@ -59,6 +64,50 @@ def get_country(x: float, y: float) -> dict[int, str]: return {2: iso_country_code, 3: data.get(iso_country_code, "")} +def _record_batch_reader(bbox: tuple[float, float, float, float]) -> RecordBatchReader: + """Get a pyarrow batch reader this for bounding box and s3 path.""" + s3_region = "us-west-2" + version = "2024-07-22.0" + path = f"overturemaps-{s3_region}/release/{version}/theme=buildings/type=building/" + xmin, ymin, xmax, ymax = bbox + ds_filter = ( + (field("bbox", "xmin") < xmax) + & (field("bbox", "xmax") > xmin) + & (field("bbox", "ymin") < ymax) + & (field("bbox", "ymax") > ymin) + ) + + ds = dataset(path, filesystem=S3FileSystem(anonymous=True, region=s3_region)) + batches = ds.to_batches(filter=ds_filter) + non_empty_batches = (b for b in batches if b.num_rows > 0) + + geoarrow_schema = ds.schema.set( + ds.schema.get_field_index("geometry"), + ds.schema.field("geometry").with_metadata( + {b"ARROW:extension:name": b"geoarrow.wkb"} + ), + ) + return RecordBatchReader.from_batches(geoarrow_schema, non_empty_batches) + + +def download_buildings_bbox( + file_address: Path, bbox: tuple[float, float, float, float] +) -> None: + """Retrieve building data in bbox from Overture Maps to file. + + This function is based on + `overturemaps-py `__. + + Args: + bbox (tuple): Bounding box coordinates (xmin, ymin, xmax, ymax) + file_address (Path): File address to save the downloaded data. + """ + reader = _record_batch_reader(bbox) + with ParquetWriter(file_address, reader.schema) as writer: + for batch in reader: + writer.write_batch(batch) + + def download_buildings(file_address: Path, x: float, y: float) -> int: """Download buildings data based on coordinates and save to a file. diff --git a/src/swmmanywhere/preprocessing.py b/src/swmmanywhere/preprocessing.py index 97af0a64..72766569 100644 --- a/src/swmmanywhere/preprocessing.py +++ b/src/swmmanywhere/preprocessing.py @@ -78,23 +78,14 @@ def prepare_elevation( def prepare_building( bbox: tuple[float, float, float, float], addresses: FilePaths, target_crs: str ): - """Download, trim and reproject building data.""" + """Download and reproject building data.""" if addresses.bbox_paths.building.exists(): return - if not addresses.project_paths.national_building.exists(): - logger.info( - f"""downloading buildings to - {addresses.project_paths.national_building}""" - ) - prepare_data.download_buildings( - addresses.project_paths.national_building, bbox[0], bbox[1] - ) - - logger.info(f"trimming buildings to {addresses.bbox_paths.building}") - national_buildings = gpd.read_parquet(addresses.project_paths.national_building) - buildings = national_buildings.cx[bbox[0] : bbox[2], bbox[1] : bbox[3]] # type: ignore + logger.info(f"downloading buildings to {addresses.bbox_paths.building}") + prepare_data.download_buildings_bbox(addresses.bbox_paths.building, bbox) + buildings = gpd.read_parquet(addresses.bbox_paths.building) buildings = buildings.to_crs(target_crs) write_df(buildings, addresses.bbox_paths.building) diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py index d51ab15b..08429e56 100644 --- a/tests/test_prepare_data.py +++ b/tests/test_prepare_data.py @@ -73,6 +73,26 @@ def test_building_downloader_download(): assert gdf.shape[0] > 0 +@pytest.mark.downloads +def test_building_bbox_downloader_download(): + """Check buildings are downloaded.""" + # Coordinates for small country (VAT) + bbox = (-0.17929, 51.49638, -0.17383, 51.49846) + with tempfile.TemporaryDirectory() as temp_dir: + temp_fid = Path(temp_dir) / "temp.parquet" + # Download + downloaders.download_buildings_bbox(temp_fid, bbox) + + # Check file exists + assert temp_fid.exists(), "Buildings data file not found after download." + + # Load data + gdf = gpd.read_parquet(temp_fid) + + # Make sure has some rows + assert gdf.shape[0] > 0 + + @pytest.mark.downloads def test_street_downloader_download(): """Check streets are downloaded and a specific point in the graph."""