Skip to content

Commit

Permalink
Update Raster IO (#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Jul 23, 2024
1 parent a46b4d4 commit 01b92a8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
59 changes: 49 additions & 10 deletions api/valor_api/backend/core/geometry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
import struct
from base64 import b64encode

from geoalchemy2 import Geometry
from geoalchemy2.functions import ST_AsPNG
import numpy as np
from geoalchemy2 import Geometry, RasterElement
from geoalchemy2.types import CompositeType
from PIL import Image
from sqlalchemy import (
Expand Down Expand Up @@ -313,7 +314,7 @@ def convert_geometry(

def _raster_to_png_b64(
db: Session,
raster: Image.Image,
raster: RasterElement,
) -> str:
"""
Convert a raster to a png.
Expand All @@ -330,16 +331,54 @@ def _raster_to_png_b64(
str
The encoded raster.
"""
raster = Image.open(io.BytesIO(db.scalar(ST_AsPNG((raster))).tobytes()))
if raster.mode != "L":
raise RuntimeError
# Ensure raster_wkb is a bytes-like object
raster_wkb = bytes.fromhex(raster.data)

# Unpack header to get width and height
# reference: https://postgis.net/docs/manual-dev/RT_reference.html
header_format = "<BHHddddddiHH"
header_size = struct.calcsize(header_format)
(
ndr,
version,
num_bands,
scale_x,
scale_y,
ip_x,
ip_y,
skew_x,
skew_y,
srid,
width,
height,
) = struct.unpack(header_format, raster_wkb[:header_size])

# Check if the raster has a single band
if num_bands != 1:
raise ValueError("This function only supports single-band rasters.")

# Calculate the number of bytes needed for the pixel data
# Each byte represents 1 pixel
num_pixels = width * height
num_bytes = num_pixels

# Convert the byte data to a binary array
pixel_format = "B"
pixel_data = struct.unpack(
f"{width * height}{pixel_format}",
raster_wkb[header_size + 2 : header_size + 2 + num_bytes],
)

# Convert pixel data to numpy array
raster_numpy = np.array(pixel_data, dtype=bool)
raster_numpy = raster_numpy.reshape((height, width))

# mask is greyscale with values 0 and 1. to convert to binary
# we first need to map 1 to 255
raster = raster.point(lambda x: 255 if x == 1 else 0).convert("1")
# Convert to Pillow Image
raster_image = Image.fromarray(raster_numpy)

# b64 encode PNG to mask str
f = io.BytesIO()
raster.save(f, format="PNG")
raster_image.save(f, format="PNG")
f.seek(0)
mask_bytes = f.read()
return b64encode(mask_bytes).decode()
8 changes: 4 additions & 4 deletions api/valor_api/schemas/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def from_numpy(cls, mask: np.ndarray) -> "Raster":
f"Expecting a binary mask (i.e. of dtype bool) but got dtype {mask.dtype}"
)
f = io.BytesIO()
PIL.Image.fromarray(mask).save(f, format="PNG")
PIL.Image.fromarray(mask).save(f, format="PNG", mode="1")
f.seek(0)
mask_bytes = f.read()
f.close()
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def to_psql(self) -> ScalarSelect | bytes:
0, # skewy
0, # srid
),
"8BUI",
"1BB",
)
geom_raster = ST_AsRaster(
ST_SnapToGrid(
Expand All @@ -1047,7 +1047,7 @@ def to_psql(self) -> ScalarSelect | bytes:
),
1.0, # scalex
1.0, # scaley
"8BUI", # pixeltype
"1BB", # pixeltype
1, # value
0, # nodataval
)
Expand All @@ -1056,7 +1056,7 @@ def to_psql(self) -> ScalarSelect | bytes:
empty_raster,
geom_raster,
"[rast2]",
"8BUI",
"1BB",
"UNION",
)
).scalar_subquery()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def run_detailed_pr_curve_evaluation(dset: Dataset, model: Model):


def run_benchmarking_analysis(
limits_to_test: list[int] = [3, 3],
limits_to_test: list[int] = [6, 6],
results_file: str = "results.json",
data_file: str = "data.json",
):
Expand Down

0 comments on commit 01b92a8

Please sign in to comment.