Skip to content

Commit

Permalink
polygon_query() support for images (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato authored Sep 24, 2023
1 parent 4c6c1e9 commit 572136f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 13 deletions.
53 changes: 46 additions & 7 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,14 @@ def _(


def _polygon_query(
sdata: SpatialData, polygon: Polygon, target_coordinate_system: str, filter_table: bool, shapes: bool, points: bool
sdata: SpatialData,
polygon: Polygon,
target_coordinate_system: str,
filter_table: bool,
shapes: bool,
points: bool,
images: bool,
labels: bool,
) -> SpatialData:
from spatialdata._core.query._utils import circles_to_polygons
from spatialdata._core.query.relational_query import _filter_table_by_elements
Expand Down Expand Up @@ -669,11 +676,32 @@ def _polygon_query(
set_transformation(ddf, transformation, target_coordinate_system)
new_points[points_name] = ddf

if filter_table:
new_images = {}
if images:
for images_name, im in sdata.images.items():
min_x, min_y, max_x, max_y = polygon.bounds
cropped = bounding_box_query(
im,
min_coordinate=[min_x, min_y],
max_coordinate=[max_x, max_y],
axes=("x", "y"),
target_coordinate_system=target_coordinate_system,
)
new_images[images_name] = cropped
if labels:
for labels_name, l in sdata.labels.items():
_ = labels_name
_ = l
raise NotImplementedError(
"labels=True is not implemented yet. If you encounter this error please open an "
"issue and we will prioritize the implementation."
)

if filter_table and sdata.table is not None:
table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points})
else:
table = sdata.table
return SpatialData(shapes=new_shapes, points=new_points, table=table)
return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table)


# this function is currently excluded from the API documentation. TODO: add it after the refactoring
Expand All @@ -684,6 +712,8 @@ def polygon_query(
filter_table: bool = True,
shapes: bool = True,
points: bool = True,
images: bool = True,
labels: bool = True,
) -> SpatialData:
"""
Query a spatial data object by a polygon, filtering shapes and points.
Expand Down Expand Up @@ -725,14 +755,21 @@ def polygon_query(
filter_table=filter_table,
shapes=shapes,
points=points,
images=images,
labels=labels,
)
# TODO: the performance for this case can be greatly improved by using the geopandas queries only once, and not
# in a loop as done preliminarily here
if points:
raise NotImplementedError(
"points=True is not implemented when querying by multiple polygons. If you encounter this error, please"
" open an issue on GitHub and we will prioritize the implementation."
if points or images or labels:
logger.warning(
"Spatial querying of images, points and labels is not implemented when querying by multiple polygons "
'simultaneously. You can silence this warning by setting "points=False, images=False, labels=False". If '
"you need this implementation please open an issue on GitHub."
)
points = False
images = False
labels = False

sdatas = []
for polygon in tqdm(polygons):
try:
Expand All @@ -744,6 +781,8 @@ def polygon_query(
filter_table=False,
shapes=shapes,
points=points,
images=images,
labels=labels,
)
sdatas.append(queried_sdata)
except ValueError as e:
Expand Down
40 changes: 34 additions & 6 deletions tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from anndata import AnnData
from multiscale_spatial_image import MultiscaleSpatialImage
from shapely import Polygon
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata._core.query.spatial_query import (
Expand Down Expand Up @@ -379,7 +380,11 @@ def test_polygon_query_shapes(sdata_query_aggregation):
circle_pol = circle.buffer(sdata["by_circles"].radius.iloc[0])

queried = polygon_query(
values_sdata, polygons=polygon, target_coordinate_system="global", shapes=True, points=False
values_sdata,
polygons=polygon,
target_coordinate_system="global",
shapes=True,
points=False,
)
assert len(queried["values_polygons"]) == 4
assert len(queried["values_circles"]) == 4
Expand Down Expand Up @@ -432,11 +437,34 @@ def test_polygon_query_spatial_data(sdata_query_aggregation):
assert len(queried.table) == 8


@pytest.mark.skip
def test_polygon_query_image2d():
# single image case
# multiscale case
pass
@pytest.mark.parametrize("n_channels", [1, 2, 3])
def test_polygon_query_image2d(n_channels: int):
original_image = np.zeros((n_channels, 10, 10))
# y: [5, 9], x: [0, 4] has value 1
original_image[:, 5::, 0:5] = 1
image_element = Image2DModel.parse(original_image)
image_element_multiscale = Image2DModel.parse(original_image, scale_factors=[2, 2])

polygon = Polygon([(3, 3), (3, 7), (5, 3)])
for image in [image_element, image_element_multiscale]:
# bounding box: y: [5, 10[, x: [0, 5[
image_result = polygon_query(
SpatialData(images={"my_image": image}),
polygons=polygon,
target_coordinate_system="global",
)["my_image"]
expected_image = original_image[:, 3:7, 3:5] # c dimension is preserved
if isinstance(image, SpatialImage):
assert isinstance(image, SpatialImage)
np.testing.assert_allclose(image_result, expected_image)
elif isinstance(image, MultiscaleSpatialImage):
assert isinstance(image_result, MultiscaleSpatialImage)
v = image_result["scale0"].values()
assert len(v) == 1
xdata = v.__iter__().__next__()
np.testing.assert_allclose(xdata, expected_image)
else:
raise ValueError("Unexpected type")


@pytest.mark.skip
Expand Down

0 comments on commit 572136f

Please sign in to comment.