diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 64db66df..a202aa07 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -14,19 +14,22 @@ from scipy.sparse import issparse from torch.utils.data import Dataset +from spatialdata._core.centroids import get_centroids +from spatialdata._core.operations.transform import transform +from spatialdata._core.operations.vectorize import to_circles +from spatialdata._core.query.relational_query import _get_unique_label_values_as_index, join_sdata_spatialelement_table from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, Image3DModel, + Labels2DModel, + Labels3DModel, PointsModel, - ShapesModel, - TableModel, get_axes_names, get_model, + get_table_keys, ) -from spatialdata.transformations import get_transformation -from spatialdata.transformations.transformations import BaseTransformation +from spatialdata.transformations import BaseTransformation, get_transformation, set_transformation __all__ = ["ImageTilesDataset"] @@ -52,11 +55,15 @@ class ImageTilesDataset(Dataset): A mapping between regions and coordinate systems. The coordinate systems are used to transform both the centroid coordinates of the regions and the images. tile_scale - It is a 1D scaling factor applied to the regions. - For example: + This parameter is used to determine the size (width and height) of the tiles. + Each tile will have size in units equal to tile_scale times the diameter of the circle that approximates (=same + area) the region that defines the tile. - - if `shapes` are circles, the radius is scaled by `tile_scale`. - - if `shapes` are polygons/multipolygon, the perimeter of the polygon is scaled by `tile_scale`. + For example, suppose the regions to be multiscale labels; this is how the tiles are created: + + 1) for each tile, each label region is approximated with a circle with the same area of the label region. + 2) The tile is then created as having the width/height equal to the diameter of the circle, + multiplied by `tile_scale`. If `tile_dim_in_units` is passed, `tile_scale` is ignored. tile_dim_in_units @@ -116,8 +123,8 @@ def __init__( from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn - self._validate(sdata, regions_to_images, regions_to_coordinate_systems, table_name) - self._preprocess(tile_scale, tile_dim_in_units) + self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) + self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) self._crop_image: Callable[..., Any] = ( partial( @@ -127,7 +134,7 @@ def __init__( if rasterize else bounding_box_query # type: ignore[assignment] ) - self._return = self._get_return(return_annotations) + self._return = self._get_return(return_annotations, table_name) self.transform = transform def _validate( @@ -135,10 +142,13 @@ def _validate( sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], - table_name: str | None = None, + return_annotations: str | list[str] | None, + table_name: str | None, ) -> None: """Validate input parameters.""" self.sdata = sdata + if return_annotations is not None and table_name is None: + raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") # check that the regions specified in the two dicts are the same assert set(regions_to_images.keys()) == set( @@ -146,16 +156,16 @@ def _validate( ), "The keys in `regions_to_images` and `regions_to_coordinate_systems` must be the same." self.regions = list(regions_to_coordinate_systems.keys()) # all regions for the dataloader - cs_region_image = [] # list of tuples (coordinate_system, region, image) - for region_key in self.regions: - image_key = regions_to_images[region_key] + cs_region_image: list[tuple[str, str, str]] = [] # list of tuples (coordinate_system, region, image) + for region_name in self.regions: + image_name = regions_to_images[region_name] # get elements - region_elem = sdata[region_key] - image_elem = sdata[image_key] + region_elem = sdata[region_name] + image_elem = sdata[image_name] # check that the elements are supported - if get_model(region_elem) not in [PointsModel]: + if get_model(region_elem) == PointsModel: raise ValueError( "`regions_element` must be a shapes or labels element, points are currently not supported." ) @@ -163,82 +173,148 @@ def _validate( raise ValueError("`images_element` must be an image element.") # check that the coordinate systems are valid for the elements - cs = regions_to_coordinate_systems[region_key] - region_trans = get_transformation(region_elem, cs, get_all=True) - image_trans = get_transformation(image_elem, cs, get_all=True) + cs = regions_to_coordinate_systems[region_name] + region_trans = get_transformation(region_elem, get_all=True) + image_trans = get_transformation(image_elem, get_all=True) + assert isinstance(region_trans, dict) + assert isinstance(image_trans, dict) if cs in region_trans and cs in image_trans: - cs_region_image.append((cs, region_key, image_key)) + cs_region_image.append((cs, region_name, image_name)) else: raise ValueError( - f"The coordinate system `{cs}` is not valid for the region `{region_key}` and image `{image_key}`." + f"The coordinate system `{cs}` is not valid for the region `{region_name}` and image `{image_name}`" + "." ) - # TODOOOOOOOOOOOOOOOOOOOO: join table - - self._region_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - self._instance_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - if not isinstance(sdata.tables["table"].obs[self._region_key].dtype, CategoricalDtype): - raise TypeError( - f"The `regions_element` column `{self._region_key}` in the table must be a categorical dtype. " - f"Please convert it." - ) - # available_regions = sdata.tables["table"].obs[self._region_key].cat.categories - self.dataset_table = self.sdata.tables["table"][ - self.sdata.tables["table"].obs[self._region_key].isin(self.regions) - ] # filtered table for the data loader - self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) + if table_name is not None: + _, region_key, instance_key = get_table_keys(sdata.tables[table_name]) + if get_model(region_elem) in [Labels2DModel, Labels3DModel]: + indices = _get_unique_label_values_as_index(region_elem).tolist() + else: + indices = region_elem.index.tolist() + table = sdata.tables[table_name] + if not isinstance(sdata.tables["table"].obs[region_key].dtype, CategoricalDtype): + raise TypeError( + f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. " + f"Please convert it." + ) + instance_column = table.obs[table.obs[region_key] == region_name][instance_key].tolist() + if not set(indices).issubset(instance_column): + raise RuntimeError( + "Some of the instances in the region element are not annotated by the table. Instances of the " + f"regions element: {indices}. Instances in the table: {instance_column}. You can remove some " + "instances from the region element, add them to the table, or set `table_name` to `None` (if " + "the table annotation can be excluded from the dataset)." + ) + self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_name, image_name) def _preprocess( self, - tile_scale: float = 1.0, - tile_dim_in_units: float | None = None, + tile_scale: float, + tile_dim_in_units: float | None, + rasterize: bool, + table_name: str | None, ) -> None: """Preprocess the dataset.""" + if table_name is not None: + _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) + filtered_table = self.sdata.tables["table"][ + self.sdata.tables["table"].obs[region_key].isin(self.regions) + ] # filtered table for the data loader + index_df = [] tile_coords_df = [] dims_l = [] - shapes_l = [] - table = self.sdata.tables["table"] - for cs, region, image in self._cs_region_image: - # get dims and transformations for the region element - dims = get_axes_names(self.sdata[region]) - dims_l.append(dims) - t = get_transformation(self.sdata[region], cs) - assert isinstance(t, BaseTransformation) - - # get instances from region - inst = table.obs[table.obs[self._region_key] == region][self._instance_key].values - - # subset the regions by instances - subset_region = self.sdata[region].iloc[inst] - # get coordinates of centroids and extent for tiles - tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units) + tables_l = [] + for cs, region_name, image_name in self._cs_region_image: + circles = to_circles(self.sdata[region_name]) + dims_l.append(get_axes_names(circles)) + + tile_coords = _get_tile_coords( + circles=circles, + cs=cs, + rasterize=rasterize, + tile_scale=tile_scale, + tile_dim_in_units=tile_dim_in_units, + ) tile_coords_df.append(tile_coords) - # get shapes - shapes_l.append(self.sdata[region]) - - # get index dictionary, with `instance_id`, `cs`, `region`, and `image` + inst = circles.index.values df = pd.DataFrame({self.INSTANCE_KEY: inst}) df[self.CS_KEY] = cs - df[self.REGION_KEY] = region - df[self.IMAGE_KEY] = image + df[self.REGION_KEY] = region_name + df[self.IMAGE_KEY] = image_name index_df.append(df) + if table_name is not None: + table_subset = filtered_table[filtered_table.obs[region_key] == region_name] + circles_sdata = SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy()) + _, table = join_sdata_spatialelement_table( + circles_sdata, region_name, table_name, how="left", match_rows="left" + ) + # get index dictionary, with `instance_id`, `cs`, `region`, and `image` + tables_l.append(table) + # concatenate and assign to self - self.dataset_index = pd.concat(index_df).reset_index(drop=True) self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) - # get table filtered by regions - self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)] - + self.dataset_index = pd.concat(index_df).reset_index(drop=True) assert len(self.tiles_coords) == len(self.dataset_index) + if table_name: + self.dataset_table = AnnData.concatenate(*tables_l) + assert len(self.tiles_coords) == len(self.dataset_table) + dims_ = set(chain(*dims_l)) assert np.all([i in self.tiles_coords for i in dims_]) self.dims = list(dims_) + # get table filtered by regions + # self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)] + + # index_df = [] + # tile_coords_df = [] + # dims_l = [] + # shapes_l = [] + # for cs, region, image in self._cs_region_image: + # # get dims and transformations for the region element + # dims = get_axes_names(self.sdata[region]) + # dims_l.append(dims) + # t = get_transformation(self.sdata[region], cs) + # assert isinstance(t, BaseTransformation) + # + # # get instances from region + # inst = table.obs[table.obs[self._region_key] == region][self._instance_key].values + # + # # subset the regions by instances + # subset_region = self.sdata[region].iloc[inst] + # # get coordinates of centroids and extent for tiles + # tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units) + # tile_coords_df.append(tile_coords) + # + # # get shapes + # shapes_l.append(self.sdata[region]) + # + # # get index dictionary, with `instance_id`, `cs`, `region`, and `image` + # df = pd.DataFrame({self.INSTANCE_KEY: inst}) + # df[self.CS_KEY] = cs + # df[self.REGION_KEY] = region + # df[self.IMAGE_KEY] = image + # index_df.append(df) + # + # # concatenate and assign to self + # self.dataset_index = pd.concat(index_df).reset_index(drop=True) + # self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) + # # get table filtered by regions + # self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)] + # + # assert len(self.tiles_coords) == len(self.dataset_index) + # dims_ = set(chain(*dims_l)) + # assert np.all([i in self.tiles_coords for i in dims_]) + # self.dims = list(dims_) + def _get_return( self, return_annot: str | list[str] | None, + table_name: str | None, ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: """Get function to return values from the table of the dataset.""" if return_annot is not None: @@ -253,14 +329,17 @@ def _get_return( return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A) return lambda x, tile: (tile, self.dataset_table[x, return_annot].X) raise ValueError( - f"`return_annot` must be a column name in the table or a variable name in the table. " - f"Got {return_annot}." + f"If `return_annot` is a `str`, it must be a column name in the table or a variable name in the table. " + f"If it is a `list` of `str`, each element should be as above, and they should all be entirely in obs " + f"or entirely in var. Got {return_annot}." ) - # return spatialdata consisting of the image tile and the associated table - return lambda x, tile: SpatialData( - images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, - table=self.dataset_table[x], - ) + # return spatialdata consisting of the image tile and, if available, the associated table + if table_name: + return lambda x, tile: SpatialData( + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, + table=self.dataset_table[x], + ) + return lambda x, tile: SpatialData(images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}) def __len__(self) -> int: return len(self.dataset_index) @@ -390,51 +469,72 @@ def dims(self, dims: list[str]) -> None: def _get_tile_coords( - elem: GeoDataFrame, - transformation: BaseTransformation, - dims: tuple[str, ...], + circles: GeoDataFrame, + cs: str, + rasterize: bool, + # elem: GeoDataFrame, + # transformation: BaseTransformation, + # dims: tuple[str, ...], tile_scale: float | None = None, tile_dim_in_units: float | None = None, ) -> pd.DataFrame: """Get the (transformed) centroid of the region and the extent.""" - # get centroids and transform them - centroids = elem.centroid.get_coordinates().values - aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) - centroids = _affine_matrix_multiplication(aff, centroids) - - # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` - if tile_dim_in_units is None: - if elem.iloc[0, 0].geom_type == "Point": - extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale - elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: - extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale - else: - raise ValueError("Only point and polygon shapes are supported.") + transform(circles, to_coordinate_system=cs) if tile_dim_in_units is not None: - if isinstance(tile_dim_in_units, (float, int)): - extent = np.repeat(tile_dim_in_units, len(centroids)) - else: - raise TypeError( - f"`tile_dim_in_units` must be a `float`, `int`, `list`, `tuple` or `np.ndarray`, " - f"not {type(tile_dim_in_units)}." - ) - if len(extent) != len(centroids): - raise ValueError( - f"the number of elements in the region ({len(extent)}) does not match" - f" the number of instances ({len(centroids)})." - ) - - # transform extent - # TODO: review this, what is being dropped by the transformation? - aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) - extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis]) - - # get min and max coordinates - min_coordinates = np.array(centroids) - extent / 2 - max_coordinates = np.array(centroids) + extent / 2 - + circles.radius = tile_dim_in_units / 2 + else: + circles.radius *= tile_scale + # if rasterize is True, the tile dim is determined from the diameter of the circles in cs; else we need to + # transform the circles to the intrinsic coordinate system of the element + if not rasterize: + transformation = get_transformation(circles, to_coordinate_system=cs) + assert isinstance(transformation, BaseTransformation) + back_transformation = transformation.inverse() + set_transformation(circles, back_transformation, to_coordinate_system="intrinsic_of_element") + transform(circles, to_coordinate_system="intrinsic_of_element") + # extent, aka the tile size + extent = (circles.radius * 2).values.reshape(-1, 1) + # get centroids and transform them + # centroids = elem.centroid.get_coordinates().values + # aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) + # centroids = _affine_matrix_multiplication(aff, centroids) + # + # # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` + # if tile_dim_in_units is None: + # if elem.iloc[0, 0].geom_type == "Point": + # extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale + # elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: + # extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale + # else: + # raise ValueError("Only point and polygon shapes are supported.") + # if tile_dim_in_units is not None: + # if isinstance(tile_dim_in_units, (float, int)): + # extent = np.repeat(tile_dim_in_units, len(centroids)) + # else: + # raise TypeError( + # f"`tile_dim_in_units` must be a `float`, `int`, `list`, `tuple` or `np.ndarray`, " + # f"not {type(tile_dim_in_units)}." + # ) + # if len(extent) != len(centroids): + # raise ValueError( + # f"the number of elements in the region ({len(extent)}) does not match" + # f" the number of instances ({len(centroids)})." + # ) + # + # # transform extent + # # TODO: review this, what is being dropped by the transformation? + # aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) + centroids_points = get_centroids(circles) + axes = get_axes_names(centroids_points) + centroids_numpy = centroids_points.compute().values + # extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis]) + # + # # get min and max coordinates + min_coordinates = np.array(centroids_numpy) - extent / 2 + max_coordinates = np.array(centroids_numpy) + extent / 2 + # # return a dataframe with columns e.g. ["x", "y", "extent", "minx", "miny", "maxx", "maxy"] return pd.DataFrame( - np.hstack([centroids, extent, min_coordinates, max_coordinates]), - columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], + np.hstack([centroids_numpy, extent, min_coordinates, max_coordinates]), + columns=list(axes) + ["extent"] + ["min" + ax for ax in axes] + ["max" + ax for ax in axes], ) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 546a0927..72e7ac0d 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -1,10 +1,11 @@ """SpatialData datasets.""" -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union import numpy as np import pandas as pd import scipy +from anndata import AnnData from dask.dataframe.core import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage @@ -15,6 +16,7 @@ from spatial_image import SpatialImage from spatialdata._core.operations.aggregate import aggregate +from spatialdata._core.query.relational_query import _get_unique_label_values_as_index from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike @@ -342,3 +344,22 @@ def _generate_random_points(self, n: int, bbox: tuple[int, int]) -> list[Point]: point = Point(x, y) points.append(point) return points + + +BlobsTypes = Literal[ + "blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons" +] + + +def blobs_annotating_element(name: BlobsTypes) -> SpatialData: + sdata = blobs(length=50) + if name in ["blobs_labels", "blobs_multiscale_labels"]: + instance_id = _get_unique_label_values_as_index(sdata[name]).tolist() + else: + instance_id = sdata[name].index.tolist() + n = len(instance_id) + new_table = AnnData(shape=(n, 0), obs={"region": [name for _ in range(n)], "instance_id": instance_id}) + new_table = TableModel.parse(new_table, region=name, region_key="region", instance_key="instance_id") + del sdata.tables["table"] + sdata["table"] = new_table + return sdata diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 9c8e7c60..4829842a 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -1,12 +1,7 @@ -import contextlib - import numpy as np -import pandas as pd import pytest -from anndata import AnnData -from spatialdata._core.spatialdata import SpatialData from spatialdata.dataloader import ImageTilesDataset -from spatialdata.models import TableModel +from spatialdata.datasets import blobs_annotating_element class TestImageTilesDataset: @@ -15,101 +10,109 @@ class TestImageTilesDataset: "regions_element", ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], ) - def test_validation(self, sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multiscale_labels"] or image_element == "blobs_multiscale_image": - cm = pytest.raises(NotImplementedError) - elif regions_element in ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]: - cm = pytest.raises(ValueError) + @pytest.mark.parametrize("table", [True, False]) + def test_validation(self, sdata_blobs, image_element: str, regions_element: str, table: bool): + if table: + sdata = blobs_annotating_element(regions_element) else: - cm = contextlib.nullcontext() - with cm: - _ = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - regions_to_coordinate_systems={regions_element: "global"}, - ) + sdata = sdata_blobs + del sdata_blobs.tables["table"] + _ = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: image_element}, + regions_to_coordinate_systems={regions_element: "global"}, + table_name="table" if table else None, + return_annotations="instance_id" if table else None, + ) - @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) - @pytest.mark.parametrize("raster", [True, False]) - def test_default(self, sdata_blobs, regions_element, raster): - raster_kwargs = {"target_unit_to_pixels": 2} if raster else {} + @pytest.mark.parametrize( + "regions_element", + ["blobs_circles", "blobs_polygons", "blobs_multipolygons", "blobs_labels", "blobs_multiscale_labels"], + ) + @pytest.mark.parametrize("rasterize", [True, False]) + def test_default(self, sdata_blobs, regions_element, rasterize): + rasterize_kwargs = {"target_unit_to_pixels": 2} if rasterize else {} - sdata = self._annotate_shapes(sdata_blobs, regions_element) + sdata = blobs_annotating_element(regions_element) ds = ImageTilesDataset( sdata=sdata, - rasterize=raster, + rasterize=rasterize, regions_to_images={regions_element: "blobs_image"}, regions_to_coordinate_systems={regions_element: "global"}, - rasterize_kwargs=raster_kwargs, + rasterize_kwargs=rasterize_kwargs, + table_name="table", ) sdata_tile = ds[0] tile = sdata_tile.images.values().__iter__().__next__() if regions_element == "blobs_circles": - if raster: - assert tile.shape == (3, 50, 50) + if rasterize: + assert tile.shape == (3, 20, 20) else: - assert tile.shape == (3, 25, 25) + assert tile.shape == (3, 10, 10) elif regions_element == "blobs_polygons": - if raster: - assert tile.shape == (3, 164, 164) + if rasterize: + assert tile.shape == (3, 6, 6) else: - assert tile.shape == (3, 82, 82) + assert tile.shape == (3, 3, 3) elif regions_element == "blobs_multipolygons": - if raster: - assert tile.shape == (3, 329, 329) + if rasterize: + assert tile.shape == (3, 9, 9) else: - assert tile.shape == (3, 165, 164) + assert tile.shape == (3, 5, 4) + elif regions_element == "blobs_labels" or regions_element == "blobs_multiscale_labels": + if rasterize: + assert tile.shape == (3, 16, 16) + else: + assert tile.shape == (3, 8, 8) else: raise ValueError(f"Unexpected regions_element: {regions_element}") # extent has units in pixel so should be the same as tile shape - if raster: + if rasterize: assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] else: - if regions_element != "blobs_multipolygons": - assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] - else: - assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] - assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) + # here we have a tolerance of 1 pixel because the size of the tile depends on the values of the centroids + # and of the extenta and here we keep the test simple. + # For example, if the centroid is 0.5 and the extent is 0.1, the tile will be 1 pixel since the extent will + # span 0.4 to 0.6; but if the centroid is 0.95 now the tile will be 2 pixels + assert np.ceil(ds.tiles_coords.extent.unique()[0]) in [tile.shape[1], tile.shape[1] + 1] + assert np.all(sdata_tile["table"].obs.columns == ds.sdata["table"].obs.columns) assert list(sdata_tile.images.keys())[0] == "blobs_image" - @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) - @pytest.mark.parametrize("return_annot", ["region", ["region", "instance_id"]]) + @pytest.mark.parametrize( + "regions_element", + ["blobs_circles", "blobs_polygons", "blobs_multipolygons", "blobs_labels", "blobs_multiscale_labels"], + ) + @pytest.mark.parametrize("return_annot", [None, "region", ["region", "instance_id"]]) def test_return_annot(self, sdata_blobs, regions_element, return_annot): - sdata = self._annotate_shapes(sdata_blobs, regions_element) + sdata = blobs_annotating_element(regions_element) ds = ImageTilesDataset( sdata=sdata, regions_to_images={regions_element: "blobs_image"}, regions_to_coordinate_systems={regions_element: "global"}, return_annotations=return_annot, + table_name="table", ) - - tile, annot = ds[0] + if return_annot is None: + sdata_tile = ds[0] + tile = sdata_tile["blobs_image"] + else: + tile, annot = ds[0] if regions_element == "blobs_circles": - assert tile.shape == (3, 25, 25) + assert tile.shape == (3, 10, 10) elif regions_element == "blobs_polygons": - assert tile.shape == (3, 82, 82) + assert tile.shape == (3, 3, 3) elif regions_element == "blobs_multipolygons": - assert tile.shape == (3, 165, 164) + assert tile.shape == (3, 5, 4) + elif regions_element == "blobs_labels" or regions_element == "blobs_multiscale_labels": + assert tile.shape == (3, 8, 8) else: raise ValueError(f"Unexpected regions_element: {regions_element}") # extent has units in pixel so should be the same as tile shape - if regions_element != "blobs_multipolygons": - assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] - else: - assert round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] - return_annot = [return_annot] if isinstance(return_annot, str) else return_annot - assert annot.shape[1] == len(return_annot) - - # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation - def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: - new_table = AnnData( - X=np.random.default_rng(0).random((len(sdata[shape]), 10)), - obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), - ) - new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") - del sdata.table - sdata.table = new_table - return sdata + # see comment in the test above explaining why we have a tolerance of 1 pixel + assert np.ceil(ds.tiles_coords.extent.unique()[0]) in [tile.shape[1], tile.shape[1] + 1] + if return_annot is not None: + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + assert annot.shape[1] == len(return_annot)