diff --git a/api/valor_api/backend/core/geometry.py b/api/valor_api/backend/core/geometry.py index b8966163f..505e74b94 100644 --- a/api/valor_api/backend/core/geometry.py +++ b/api/valor_api/backend/core/geometry.py @@ -1,8 +1,9 @@ import io +import struct from base64 import b64encode -from geoalchemy2 import Geometry -from geoalchemy2.functions import ST_AsPNG +import numpy as np +from geoalchemy2 import Geometry, RasterElement from geoalchemy2.types import CompositeType from PIL import Image from sqlalchemy import ( @@ -313,7 +314,7 @@ def convert_geometry( def _raster_to_png_b64( db: Session, - raster: Image.Image, + raster: RasterElement, ) -> str: """ Convert a raster to a png. @@ -330,16 +331,54 @@ def _raster_to_png_b64( str The encoded raster. """ - raster = Image.open(io.BytesIO(db.scalar(ST_AsPNG((raster))).tobytes())) - if raster.mode != "L": - raise RuntimeError + # Ensure raster_wkb is a bytes-like object + raster_wkb = bytes.fromhex(raster.data) + + # Unpack header to get width and height + # reference: https://postgis.net/docs/manual-dev/RT_reference.html + header_format = " "Raster": f"Expecting a binary mask (i.e. of dtype bool) but got dtype {mask.dtype}" ) f = io.BytesIO() - PIL.Image.fromarray(mask).save(f, format="PNG") + PIL.Image.fromarray(mask).save(f, format="PNG", mode="1") f.seek(0) mask_bytes = f.read() f.close() @@ -1038,7 +1038,7 @@ def to_psql(self) -> ScalarSelect | bytes: 0, # skewy 0, # srid ), - "8BUI", + "1BB", ) geom_raster = ST_AsRaster( ST_SnapToGrid( @@ -1047,7 +1047,7 @@ def to_psql(self) -> ScalarSelect | bytes: ), 1.0, # scalex 1.0, # scaley - "8BUI", # pixeltype + "1BB", # pixeltype 1, # value 0, # nodataval ) @@ -1056,7 +1056,7 @@ def to_psql(self) -> ScalarSelect | bytes: empty_raster, geom_raster, "[rast2]", - "8BUI", + "1BB", "UNION", ) ).scalar_subquery() diff --git a/integration_tests/benchmarks/object-detection/benchmark_script.py b/integration_tests/benchmarks/object-detection/benchmark_script.py index 52a64682a..5dcc6a4ea 100644 --- a/integration_tests/benchmarks/object-detection/benchmark_script.py +++ b/integration_tests/benchmarks/object-detection/benchmark_script.py @@ -233,7 +233,7 @@ def run_detailed_pr_curve_evaluation(dset: Dataset, model: Model): def run_benchmarking_analysis( - limits_to_test: list[int] = [3, 3], + limits_to_test: list[int] = [6, 6], results_file: str = "results.json", data_file: str = "data.json", ):