Skip to content

Commit

Permalink
Fixes in get_centroids, spatial_query for multiscale raster, dataload…
Browse files Browse the repository at this point in the history
…er (#495)

fixes in get_centroids, spatial_query for multiscale raster, dataloader
  • Loading branch information
LucaMarconato authored Mar 21, 2024
1 parent 19c011d commit aa78ad6
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 62 deletions.
5 changes: 3 additions & 2 deletions src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
)
xy = e.centroid.get_coordinates().values
points = PointsModel.parse(xy, transformations={coordinate_system: t})
xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy())
points = PointsModel.parse(xy_df, transformations={coordinate_system: t})
return transform(points, to_coordinate_system=coordinate_system)


Expand All @@ -142,7 +143,7 @@ def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
_validate_coordinate_system(e, coordinate_system)
axes = get_axes_names(e)
assert axes in [("x", "y"), ("x", "y", "z")]
coords = e[list(axes)].compute().values
coords = e[list(axes)].compute()
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
centroids = PointsModel.parse(coords, transformations={coordinate_system: t})
Expand Down
15 changes: 15 additions & 0 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,21 @@ def _(
return None
else:
d[k] = xdata
# the list of scales may not be contiguous when the data has a small shape (for instance, with yx = 22 and
# rotations we may end up having scale0 and scale2 but not scale1). In practice this may occur in the torch tiler
# if the tiles are requested to be too small.
# Here we remove all the scales that come after the first missing scale
scales_to_keep = []
for i, scale_name in enumerate(d.keys()):
if scale_name == f"scale{i}":
scales_to_keep.append(scale_name)
else:
break
# case in which scale0 is not present but other scales are
if len(scales_to_keep) == 0:
return None
d = {k: d[k] for k in scales_to_keep}

query_result = MultiscaleSpatialImage.from_dict(d)
# rechunk the data to avoid irregular chunks
for scale in query_result:
Expand Down
25 changes: 7 additions & 18 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,13 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
dtype = element.scale0.ds.dtypes["image"]
else:
dtype = element.index.dtype
if dtype != table.obs[instance_key].dtype:
if dtype == str or table.obs[instance_key].dtype == str:
raise TypeError(
f"Table instance_key column ({instance_key}) has a dtype "
f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
f"the annotated element ({dtype})."
)

warnings.warn(
(
f"Table instance_key column ({instance_key}) has a dtype "
f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
f"the annotated element ({dtype}). Please note in the case of int16 vs int32 or "
"similar cases may be tolerated in downstream methods, but it is recommended to make "
"the dtypes match."
),
UserWarning,
stacklevel=2,
if dtype != table.obs[instance_key].dtype and (
dtype == str or table.obs[instance_key].dtype == str
):
raise TypeError(
f"Table instance_key column ({instance_key}) has a dtype "
f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
f"the annotated element ({dtype})."
)

@staticmethod
Expand Down
66 changes: 47 additions & 19 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections.abc import Mapping
from functools import partial
from itertools import chain
Expand Down Expand Up @@ -27,6 +28,7 @@
Labels2DModel,
Labels3DModel,
PointsModel,
TableModel,
get_axes_names,
get_model,
get_table_keys,
Expand Down Expand Up @@ -128,6 +130,14 @@ def __init__(
self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name)
self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name)

if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False:
warnings.warn(
"rasterize_kwargs are passed to the rasterize function, but rasterize is set to False. The arguments "
"will be ignored. If you want to use the rasterize function, please set rasterize to True.",
UserWarning,
stacklevel=2,
)

self._crop_image: Callable[..., Any] = (
partial(
rasterize_fn,
Expand Down Expand Up @@ -277,39 +287,57 @@ def _ensure_single_scale(data: SpatialImage | MultiscaleSpatialImage) -> Spatial
return SpatialImage(next(iter(data["scale0"].ds.values())))
raise ValueError(f"Expected a SpatialImage or MultiscaleSpatialImage, got {type(data)}.")

def _get_return(
self,
return_annot: str | list[str] | None,
@staticmethod
def _return_function(
idx: int,
tile: Any,
dataset_table: AnnData,
dataset_index: pd.DataFrame,
table_name: str | None,
) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]:
"""Get function to return values from the table of the dataset."""
return_annot: str | list[str] | None,
) -> tuple[Any, Any] | SpatialData:
tile = ImageTilesDataset._ensure_single_scale(tile)
if return_annot is not None:
# table is always returned as array shape (1, len(return_annot))
# where return_table can be a single column or a list of columns
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all([i in self.dataset_table.obs for i in return_annot]):
return lambda x, tile: (
self._ensure_single_scale(tile),
self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1),
)
if np.all([i in self.dataset_table.var_names for i in return_annot]):
if issparse(self.dataset_table.X):
return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X.A)
return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X)
if np.all([i in dataset_table.obs for i in return_annot]):
return tile, dataset_table.obs[return_annot].iloc[idx].values.reshape(1, -1)
if np.all([i in dataset_table.var_names for i in return_annot]):
if issparse(dataset_table.X):
return tile, dataset_table[idx, return_annot].X.A
return tile, dataset_table[idx, return_annot].X
raise ValueError(
f"If `return_annot` is a `str`, it must be a column name in the table or a variable name in the table. "
f"If it is a `list` of `str`, each element should be as above, and they should all be entirely in obs "
f"or entirely in var. Got {return_annot}."
)
# return spatialdata consisting of the image tile and, if available, the associated table
if table_name:
return lambda x, tile: SpatialData(
images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)},
table=self.dataset_table[x],
# let's reset the target annotation metadata to avoid a warning when constructing the SpatialData object
table_row = dataset_table[idx].copy()
del table_row.uns[TableModel.ATTRS_KEY]
# TODO: add the shape used for constructing the tile; in the case of the label consider adding the circles
# or a crop of the label
return SpatialData(
images={dataset_index.iloc[idx][ImageTilesDataset.IMAGE_KEY]: tile},
table=table_row,
)
return lambda x, tile: SpatialData(
images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)}
return SpatialData(images={dataset_index.iloc[idx][ImageTilesDataset.IMAGE_KEY]: tile})

def _get_return(
    self,
    return_annot: str | list[str] | None,
    table_name: str | None,
) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]:
    """Build the callable that produces the value returned for a dataset item.

    The result is ``ImageTilesDataset._return_function`` with ``dataset_table``,
    ``dataset_index``, ``table_name`` and ``return_annot`` pre-bound as keyword
    arguments, so only the two remaining positional arguments (the item index
    and the tile) have to be supplied by the caller.

    Parameters
    ----------
    return_annot
        Name, or list of names, of the annotation(s) forwarded to the return
        function; ``None`` is forwarded unchanged.
    table_name
        Name of the table. When it is falsy, ``None`` is bound as
        ``dataset_table`` instead of ``self.dataset_table``.

    Returns
    -------
    A ``functools.partial`` object wrapping ``ImageTilesDataset._return_function``.
    """
    return_fn = ImageTilesDataset._return_function
    # the table is only forwarded when a table name was given
    bound_table = self.dataset_table if table_name else None
    bound_index = self.dataset_index
    return partial(
        return_fn,
        dataset_table=bound_table,
        dataset_index=bound_index,
        table_name=table_name,
        return_annot=return_annot,
    )

def __len__(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/transformations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def _(e: MultiscaleSpatialImage, transformations: MappingToCoordinateSystem_t) -
old_shape: Optional[ArrayLike] = None
for i, (scale, node) in enumerate(dict(e).items()):
# this is to be sure that the pyramid levels are listed here in the correct order
if scale != f"scale{i}":
pass
assert scale == f"scale{i}"
assert len(dict(node)) == 1
xdata = list(node.values())[0]
Expand Down
10 changes: 0 additions & 10 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,6 @@ def test_validate_table_in_spatialdata(full_sdata):

full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16"))
with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"):
full_sdata.validate_table_in_spatialdata(table)

# region not found
del full_sdata.labels["labels2d"]
with pytest.warns(UserWarning, match="in the SpatialData object"):
Expand All @@ -435,11 +430,6 @@ def test_validate_table_in_spatialdata(full_sdata):

full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16")
with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"):
full_sdata.validate_table_in_spatialdata(table)

# region not found
del full_sdata.points["points_0"]
with pytest.warns(UserWarning, match="in the SpatialData object"):
Expand Down
54 changes: 41 additions & 13 deletions tests/core/test_centroids.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
import math

import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from numpy.random import default_rng
from spatialdata._core.centroids import get_centroids
from spatialdata.models import Labels2DModel, Labels3DModel, TableModel, get_axes_names
from spatialdata.transformations import Identity, get_transformation, set_transformation

from tests.core.operations.test_transform import _get_affine
from spatialdata._core.query.relational_query import _get_unique_label_values_as_index
from spatialdata.models import Labels2DModel, Labels3DModel, PointsModel, TableModel, get_axes_names
from spatialdata.transformations import Affine, Identity, get_transformation, set_transformation

RNG = default_rng(42)


def _get_affine() -> Affine:
    """Return the fixed 2D affine transformation used by the tests in this module.

    The 3x3 homogeneous matrix maps the axes ``("x", "y")`` onto ``("x", "y")``.
    Its linear part is a rotation by ``pi / 18`` radians multiplied by 2, and
    its translation column is ``(-1000 / k, 300 / k)`` with ``k = 10.0``.
    """
    angle: float = math.pi / 18
    k = 10.0
    axes = ("x", "y")
    # rows of the homogeneous matrix: x output, y output, constant row
    row_x = [2 * math.cos(angle), 2 * math.sin(-angle), -1000 / k]
    row_y = [2 * math.sin(angle), 2 * math.cos(angle), 300 / k]
    row_h = [0, 0, 1]
    return Affine(
        [row_x, row_y, row_h],
        input_axes=axes,
        output_axes=axes,
    )


affine = _get_affine()


@pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
@pytest.mark.parametrize("is_3d", [False, True])
def test_get_centroids_points(points, coordinate_system: str, is_3d: bool):
element = points["points_0"]
element = points["points_0"].compute()
element.index = np.arange(len(element)) + 10
element = PointsModel.parse(element)

# by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the
# test arguments
Expand All @@ -32,6 +50,9 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool):
# the axes of the centroids should be the same as the axes of the element
assert centroids.columns.tolist() == list(axes)

# check the index is preserved
assert np.array_equal(centroids.index.values, element.index.values)

# the centroids should not contain extra columns
assert "genes" in element.columns and "genes" not in centroids.columns

Expand All @@ -54,10 +75,14 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool):
@pytest.mark.parametrize("shapes_name", ["circles", "poly", "multipoly"])
def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str):
element = shapes[shapes_name]
element.index = np.arange(len(element)) + 10

if coordinate_system == "aligned":
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system)

assert np.array_equal(centroids.index.values, element.index.values)

if shapes_name == "circles":
xy = element.geometry.get_coordinates().values
else:
Expand All @@ -82,12 +107,12 @@ def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: boo
array = np.array(
[
[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 10, 10],
[0, 0, 10, 10],
],
[
[2, 2, 1, 1],
[2, 2, 1, 1],
[20, 20, 10, 10],
[20, 20, 10, 10],
],
]
)
Expand All @@ -102,10 +127,10 @@ def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: boo
else:
array = np.array(
[
[1, 1, 1, 1],
[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2],
[10, 10, 10, 10],
[20, 20, 20, 20],
[20, 20, 20, 20],
[20, 20, 20, 20],
]
)
model = Labels2DModel
Expand All @@ -122,6 +147,9 @@ def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: boo
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system)

labels_indices = _get_unique_label_values_as_index(element)
assert np.array_equal(centroids.index.values, labels_indices)

if coordinate_system == "global":
assert np.array_equal(centroids.compute().values, expected_centroids.values)
else:
Expand Down

0 comments on commit aa78ad6

Please sign in to comment.