From a3a1a52e7d7d22d26ee94aed2e5b025356c79bd7 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Mon, 20 Feb 2023 10:14:05 +0100 Subject: [PATCH 01/16] add sdata-> data dict transform --- spatialdata/_dl/transforms.py | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 spatialdata/_dl/transforms.py diff --git a/spatialdata/_dl/transforms.py b/spatialdata/_dl/transforms.py new file mode 100644 index 00000000..ed334597 --- /dev/null +++ b/spatialdata/_dl/transforms.py @@ -0,0 +1,36 @@ +from typing import Any + +from spatialdata import SpatialData + + +class SpatialDataToDataDict: + def __init__(self, data_mapping: dict[str, str]): + self.data_mapping = data_mapping + + @staticmethod + def _parse_spatial_data_path(sdata_path: str) -> tuple[str, str]: + """Convert a path in a SpatialData object to + the element type and element name.""" + path_components = sdata_path.split("/") + assert len(path_components) == 2 + + element_type = path_components[0] + element_name = path_components[1] + + return element_type, element_name + + def __call__(self, sdata: SpatialData) -> dict[str, Any]: + data_dict = {} + for sdata_path, data_dict_key in self.data_mapping.items(): + # get data item from the SpatialData object and add it to the data dictig + element_type, element_name = self._parse_spatial_data_path(sdata_path) + + if element_type == "table": + element = getattr(sdata, element_type) + else: + element_dict = getattr(sdata, element_type) + element = element_dict[element_name] + + data_dict[data_dict_key] = element + + return data_dict From b5fa7c639067948a7a2ab17b9beaccbccf82c434 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Mon, 20 Feb 2023 10:16:30 +0100 Subject: [PATCH 02/16] add initial dataset --- spatialdata/_dl/__init__.py | 0 spatialdata/_dl/datasets.py | 56 +++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 spatialdata/_dl/__init__.py create mode 100644 spatialdata/_dl/datasets.py diff --git a/spatialdata/_dl/__init__.py b/spatialdata/_dl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py new file mode 100644 index 00000000..9233c77e --- /dev/null +++ b/spatialdata/_dl/datasets.py @@ -0,0 +1,56 @@ +from typing import Callable, Optional + +from torch.utils.data import Dataset + +from spatialdata import SpatialData +from spatialdata._core._spatial_query import BoundingBoxRequest +from spatialdata._types import ArrayLike + + +class SpotCropDataset(Dataset): + def __init__( + self, + sdata: SpatialData, + coordinates: ArrayLike, + transform: Optional[Callable[[SpatialData], SpatialData]] = None, + ): + self.sdata = sdata + self.transform = transform + + self.coordinates = coordinates + self.radius = 5 + + def _get_bounding_box_coordinates(self, spot_index: int) -> tuple[ArrayLike, ArrayLike]: + """Get the coordinates for the corners of the bounding box of that encompasses a given spot. + + Parameters + ---------- + spot_index + The row index of the spot. + + Returns + ------- + min_coordinate + The minimum coordinate of the bounding box. + max_coordinate + The maximum coordinate of the bounding box. 
+ """ + centroid = self.coordinates[spot_index] + min_coordinate = centroid - self.radius + max_coordinate = centroid + self.radius + + return min_coordinate, max_coordinate + + def __len___(self) -> int: + return len(self.coordiantes) + + def _getitem__(self, idx: int) -> SpatialData: + min_coordinate, max_coordinate = self._get_bounding_box_coordinates(spot_index=idx) + + request = BoundingBoxRequest(min_coordinate=min_coordinate, max_coordinate=max_coordinate, axes=("y", "x")) + sdata_item = self.sdata.query.bounding_box(request=request) + + if self.transform is not None: + sdata_item = self.transform(sdata_item) + + return sdata_item From 479d7e25fe3f27b51ce884d565fa659fba811278 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Mon, 20 Feb 2023 11:00:52 +0100 Subject: [PATCH 03/16] fix typos --- spatialdata/_dl/datasets.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index 9233c77e..e05ebc89 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -12,13 +12,14 @@ def __init__( self, sdata: SpatialData, coordinates: ArrayLike, + radius: int = 5, transform: Optional[Callable[[SpatialData], SpatialData]] = None, ): self.sdata = sdata self.transform = transform self.coordinates = coordinates - self.radius = 5 + self.radius = radius def _get_bounding_box_coordinates(self, spot_index: int) -> tuple[ArrayLike, ArrayLike]: """Get the coordinates for the corners of the bounding box of that encompasses a given spot. @@ -41,10 +42,10 @@ def _get_bounding_box_coordinates(self, spot_index: int) -> tuple[ArrayLike, Arr return min_coordinate, max_coordinate - def __len___(self) -> int: - return len(self.coordiantes) + def __len__(self) -> int: + return len(self.coordinates) - def _getitem__(self, idx: int) -> SpatialData: + def __getitem__(self, idx: int) -> SpatialData: min_coordinate, max_coordinate = self._get_bounding_box_coordinates(spot_index=idx) request = BoundingBoxRequest(min_coordinate=min_coordinate, max_coordinate=max_coordinate, axes=("y", "x")) From 42706ffa974e13d208037c6c2b6fd96e6fcd9f8d Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Mon, 20 Feb 2023 16:08:12 +0100 Subject: [PATCH 04/16] add shapes to dataset --- spatialdata/_dl/datasets.py | 40 +++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index e05ebc89..6efe66d9 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -1,9 +1,11 @@ from typing import Callable, Optional +import numpy as np from torch.utils.data import Dataset from spatialdata import SpatialData from spatialdata._core._spatial_query import BoundingBoxRequest +from spatialdata._core.core_utils import get_dims from spatialdata._types import ArrayLike @@ -11,24 +13,19 @@ class SpotCropDataset(Dataset): def __init__( self, sdata: SpatialData, - coordinates: ArrayLike, - radius: int = 5, + spots_element_key: str, transform: Optional[Callable[[SpatialData], SpatialData]] = None, ): self.sdata = sdata + self.spots_element_key = spots_element_key self.transform = transform - self.coordinates = coordinates - self.radius = radius + self.min_coordinates, self.max_coordinates, self.spots_dims = self._get_bounding_box_coordinates() + self.n_spots = len(self.min_coordinates) - def _get_bounding_box_coordinates(self, spot_index: int) -> tuple[ArrayLike, ArrayLike]: + def _get_bounding_box_coordinates(self) -> tuple[ArrayLike, 
ArrayLike, tuple[str, ...]]: """Get the coordinates for the corners of the bounding box of that encompasses a given spot. - Parameters - ---------- - spot_index - The row index of the spot. - Returns ------- min_coordinate @@ -36,19 +33,28 @@ def _get_bounding_box_coordinates(self, spot_index: int) -> tuple[ArrayLike, Arr max_coordinate The maximum coordinate of the bounding box. """ - centroid = self.coordinates[spot_index] - min_coordinate = centroid - self.radius - max_coordinate = centroid + self.radius + spots_element = self.sdata.shapes[self.spots_element_key] + spots_dims = get_dims(spots_element) + + centroids = [] + for dim_name in spots_dims: + centroids.append(getattr(spots_element["geometry"], dim_name).to_numpy()) + centroids_array = np.column_stack(centroids) + radius = np.expand_dims(spots_element["radius"].to_numpy(), axis=1) + + min_coordinates = (centroids_array - radius).astype(int) + max_coordinates = (centroids_array + radius).astype(int) - return min_coordinate, max_coordinate + return min_coordinates, max_coordinates, spots_dims def __len__(self) -> int: - return len(self.coordinates) + return self.n_spots def __getitem__(self, idx: int) -> SpatialData: - min_coordinate, max_coordinate = self._get_bounding_box_coordinates(spot_index=idx) + min_coordinate = self.min_coordinates[idx] + max_coordinate = self.max_coordinates[idx] - request = BoundingBoxRequest(min_coordinate=min_coordinate, max_coordinate=max_coordinate, axes=("y", "x")) + request = BoundingBoxRequest(min_coordinate=min_coordinate, max_coordinate=max_coordinate, axes=self.spots_dims) sdata_item = self.sdata.query.bounding_box(request=request) if self.transform is not None: From e0bb5d874c0726d31dbd47b6911bb3733ca7bb62 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Wed, 22 Feb 2023 13:58:08 +0100 Subject: [PATCH 05/16] start multislide --- spatialdata/_dl/datasets.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index 6efe66d9..96d3cb53 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -13,16 +13,20 @@ class SpotCropDataset(Dataset): def __init__( self, sdata: SpatialData, - spots_element_key: str, + spots_element_keys: list[str], transform: Optional[Callable[[SpatialData], SpatialData]] = None, ): self.sdata = sdata - self.spots_element_key = spots_element_key + self.spots_element_keys = spots_element_keys self.transform = transform self.min_coordinates, self.max_coordinates, self.spots_dims = self._get_bounding_box_coordinates() self.n_spots = len(self.min_coordinates) + def _get_centroids_and_metadata(self) -> None: + for key in self.spots_element_keys: + print(key) + def _get_bounding_box_coordinates(self) -> tuple[ArrayLike, ArrayLike, tuple[str, ...]]: """Get the coordinates for the corners of the bounding box of that encompasses a given spot. @@ -33,7 +37,7 @@ def _get_bounding_box_coordinates(self) -> tuple[ArrayLike, ArrayLike, tuple[str max_coordinate The maximum coordinate of the bounding box. 
""" - spots_element = self.sdata.shapes[self.spots_element_key] + spots_element = self.sdata.shapes[self.spots_element_keys[0]] spots_dims = get_dims(spots_element) centroids = [] From 9337549f1c5d26f8833cc1170435b8dc371836dd Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:31:00 +0100 Subject: [PATCH 06/16] wip, need to merge with rasterize branch --- examples/dev-examples/image_tiles.py | 107 +++++++++++++++++ spatialdata/_core/_spatialdata.py | 41 +++++++ spatialdata/_core/data_extent.py | 32 ++++++ spatialdata/_dl/datasets.py | 128 ++++++++++++++------- tests/_core/test_spatialdata_operations.py | 44 +++++++ 5 files changed, 311 insertions(+), 41 deletions(-) create mode 100644 examples/dev-examples/image_tiles.py create mode 100644 spatialdata/_core/data_extent.py diff --git a/examples/dev-examples/image_tiles.py b/examples/dev-examples/image_tiles.py new file mode 100644 index 00000000..d0f7c2ca --- /dev/null +++ b/examples/dev-examples/image_tiles.py @@ -0,0 +1,107 @@ +## +# from https://gist.github.com/kevinyamauchi/77f986889b7626db4ab3c1075a3a3e5e + +import numpy as np +from matplotlib import pyplot as plt +from skimage import draw + +import spatialdata as sd +from spatialdata._dl.datasets import ImageTilesDataset +from spatialdata._dl.transforms import SpatialDataToDataDict + +## +coordinates = np.array([[10, 10], [20, 20], [50, 30], [90, 70], [20, 80]]) + +radius = 5 + +colors = np.array( + [ + [102, 194, 165], + [252, 141, 98], + [141, 160, 203], + [231, 138, 195], + [166, 216, 84], + ] +) + +## +# make an image with spots +image = np.zeros((100, 100, 3), dtype=np.uint8) + +for spot_color, centroid in zip(colors, coordinates): + rr, cc = draw.disk(centroid, radius=radius) + + for color_index in range(3): + channel_dims = color_index * np.ones((len(rr),), dtype=int) + image[rr, cc, channel_dims] = spot_color[color_index] + +# plt.imshow(image) +# plt.show() + +## +sd_image = sd.Image2DModel.parse(image, dims=("y", "x", "c")) + +# circles coordinates are xy, so we flip them here. 
+circles = sd.ShapesModel.parse(coordinates[:, [1, 0]], radius=radius, geometry=0) +sdata = sd.SpatialData(images={"image": sd_image}, shapes={"spots": circles}) +sdata + +## +ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={"/shapes/spots": "/images/image"}, + tile_dim_in_units=5, + tile_dim_in_pixels=32, + target_coordinate_system="global", + data_dict_transform=None, +) + +print(f"this dataset as {len(ds)} items") + +## +# we can use the __getitem__ interface to get one of the sample crops +print(ds[0]) + + +## +# now we plot all of the crops +def plot_sdata_dataset(ds): + n_samples = len(ds) + fig, axs = plt.subplots(1, n_samples) + + for index, sample in enumerate(ds): + if isinstance(sample, sd.SpatialData): + im = sample.images["image"] + else: + im = sample["image"] + axs[index].imshow(im.transpose("y", "x", "c")) + + +plot_sdata_dataset(ds) + +## +# we can also use transforms to automatically extract the relevant data +# into a datadictionary + +# map the SpatialData path to a data dict key +data_mapping = {"images/image": "image"} + +# make the transform +ds_transform = ImageTilesDataset( + sdata=sdata, + spots_element_key="spots", + transform=SpatialDataToDataDict(data_mapping=data_mapping), +) + +print(f"this dataset as {len(ds_transform)} items") + +## +# now the samples are a dictionary with key "image" and the item is the +# image array +# this is useful because it is the expected format for many of the +# +ds_transform[0] + +## +# plot of each sample in the dataset +plot_sdata_dataset(ds_transform) diff --git a/spatialdata/_core/_spatialdata.py b/spatialdata/_core/_spatialdata.py index 417a222e..85f07af0 100644 --- a/spatialdata/_core/_spatialdata.py +++ b/spatialdata/_core/_spatialdata.py @@ -2,6 +2,7 @@ import hashlib import os +import re from collections.abc import Generator from pathlib import Path from types import MappingProxyType @@ -1104,6 +1105,46 @@ def _gen_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None for k, v in d.items(): yield element_type, k, v + def __getitem__(self, item: str) -> SpatialElement | AnnData: + # this regex match the following: + # /images/ehi + # /images/ehi + # /labels/ehi + # /table + # + # but not: + # /images/ + # images/ehi + # /iimages/ehi + # /images/ehi/ + # /images/ehi/ehi + # table + # /ttable + # /table/ + # /table/ehi + regex = r"^/(\bimages\b(?=/)|\blabels\b(?=/)|\bpoints\b(?=/)|\bshapes\b(?=/)|\btable\b(?=$))(/[a-zA-Z0-9_]+)?$" + match = re.match(regex, item) + if match: + element_type = match.group(1) + if element_type == "table": + element_name = None + else: + element_name = match.group(2)[1:] + else: + raise ValueError(f"{item} does not match any element in the SpatialData object") + if element_type == "table": + return self.table + elif element_type == "images": + return self.images[element_name] + elif element_type == "labels": + return self.labels[element_name] + elif element_type == "points": + return self.points[element_name] + elif element_type == "shapes": + return self.shapes[element_name] + else: + raise ValueError(f"Unknown element type {element_type}") + class QueryManager: """Perform queries on SpatialData objects""" diff --git a/spatialdata/_core/data_extent.py b/spatialdata/_core/data_extent.py new file mode 100644 index 00000000..15f44113 --- /dev/null +++ b/spatialdata/_core/data_extent.py @@ -0,0 +1,32 @@ +"""This file contains functions to compute the bounding box describing the extent of a spatial element, +or of a specific region in the SpatialElement object.""" 
+import numpy as np +from geopandas import GeoDataFrame + +from spatialdata._core.core_utils import get_dims +from spatialdata._types import ArrayLike + + +def _get_bounding_box_of_circle_elements(self, shapes: GeoDataFrame) -> tuple[ArrayLike, ArrayLike, tuple[str, ...]]: + """Get the coordinates for the corners of the bounding box of that encompasses a given spot. + + Returns + ------- + min_coordinate + The minimum coordinate of the bounding box. + max_coordinate + The maximum coordinate of the bounding box. + """ + spots_element = self.sdata.shapes[self.spots_element_keys[0]] + spots_dims = get_dims(spots_element) + + centroids = [] + for dim_name in spots_dims: + centroids.append(getattr(spots_element["geometry"], dim_name).to_numpy()) + centroids_array = np.column_stack(centroids) + radius = np.expand_dims(spots_element["radius"].to_numpy(), axis=1) + + min_coordinates = (centroids_array - radius).astype(int) + max_coordinates = (centroids_array + radius).astype(int) + + return min_coordinates, max_coordinates, spots_dims diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index 96d3cb53..4b6ab542 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -1,67 +1,113 @@ from typing import Callable, Optional import numpy as np +from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage from torch.utils.data import Dataset from spatialdata import SpatialData -from spatialdata._core._spatial_query import BoundingBoxRequest from spatialdata._core.core_utils import get_dims -from spatialdata._types import ArrayLike +from spatialdata._core.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + ShapesModel, + get_schema, +) -class SpotCropDataset(Dataset): +class ImageTilesDataset(Dataset): def __init__( self, sdata: SpatialData, - spots_element_keys: list[str], - transform: Optional[Callable[[SpatialData], SpatialData]] = None, + regions_to_images: dict[str, str], + tile_dim_in_units: float, + tile_dim_in_pixels: int, + target_coordinate_system: str = "global", + data_dict_transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None, ): self.sdata = sdata - self.spots_element_keys = spots_element_keys - self.transform = transform - - self.min_coordinates, self.max_coordinates, self.spots_dims = self._get_bounding_box_coordinates() - self.n_spots = len(self.min_coordinates) + self.regions_to_images = regions_to_images + self.tile_dim_in_units = tile_dim_in_units + self.tile_dim_in_pixels = tile_dim_in_pixels + self.data_dict_transform = data_dict_transform + self.target_coordinate_system = target_coordinate_system + + self.n_spots_dict = self._compute_n_spots_dict() + self.n_spots = sum(self.n_spots_dict.values()) + + def _validate_regions_to_images(self) -> None: + for region_key, image_key in self.regions_to_images.items(): + regions_element = self.sdata[region_key] + images_element = self.sdata[image_key] + # we could allow also for points + if not get_schema(regions_element) in [ShapesModel, Labels2DModel, Labels3DModel]: + raise ValueError(f"regions_element must be a shapes element or a labels element") + if not get_schema(images_element) in [Image2DModel, Image3DModel]: + raise ValueError(f"images_element must be an image element") + + def _compute_n_spots_dict(self) -> dict[str, int]: + n_spots_dict = {} + for region_key in self.regions_to_images.keys(): + element = self.sdata[region_key] + # we could allow also points + if 
isinstance(element, GeoDataFrame): + n_spots_dict[region_key] = len(element) + elif isinstance(element, SpatialImage): + raise NotImplementedError("labels not supported yet") + elif isinstance(element, MultiscaleSpatialImage): + raise NotImplementedError("labels not supported yet") + else: + raise ValueError(f"element must be a geodataframe or a spatial image") + return n_spots_dict def _get_centroids_and_metadata(self) -> None: for key in self.spots_element_keys: print(key) - def _get_bounding_box_coordinates(self) -> tuple[ArrayLike, ArrayLike, tuple[str, ...]]: - """Get the coordinates for the corners of the bounding box of that encompasses a given spot. - - Returns - ------- - min_coordinate - The minimum coordinate of the bounding box. - max_coordinate - The maximum coordinate of the bounding box. - """ - spots_element = self.sdata.shapes[self.spots_element_keys[0]] - spots_dims = get_dims(spots_element) - - centroids = [] - for dim_name in spots_dims: - centroids.append(getattr(spots_element["geometry"], dim_name).to_numpy()) - centroids_array = np.column_stack(centroids) - radius = np.expand_dims(spots_element["radius"].to_numpy(), axis=1) - - min_coordinates = (centroids_array - radius).astype(int) - max_coordinates = (centroids_array + radius).astype(int) - - return min_coordinates, max_coordinates, spots_dims + def _get_region_info_for_index(self, index: int) -> tuple[str, int]: + # TODO: this implmenetation can be improved + i = 0 + for region_key, n_spots in self.n_spots_dict.items(): + if index < i + n_spots: + return region_key, index - i + i += n_spots + raise ValueError(f"index {index} is out of range") def __len__(self) -> int: return self.n_spots def __getitem__(self, idx: int) -> SpatialData: - min_coordinate = self.min_coordinates[idx] - max_coordinate = self.max_coordinates[idx] - - request = BoundingBoxRequest(min_coordinate=min_coordinate, max_coordinate=max_coordinate, axes=self.spots_dims) - sdata_item = self.sdata.query.bounding_box(request=request) - - if self.transform is not None: - sdata_item = self.transform(sdata_item) + regions_name, region_index = self._get_region_info_for_index(idx) + regions = self.sdata[regions_name] + # TODO: here we just need to compute the centroids, we probably want to move this functionality to a different file + if isinstance(regions, GeoDataFrame): + get_dims(regions) + region = regions.iloc[region_index] + # the function coords.xy is just accessing _coords, and wrapping it with extra information, so we access + # it directly + centroid = region.geometry.coords._coords[0] + elif isinstance(regions, SpatialImage): + raise NotImplementedError("labels not supported yet") + elif isinstance(regions, MultiscaleSpatialImage): + raise NotImplementedError("labels not supported yet") + else: + raise ValueError(f"element must be shapes or labels") + np.array(centroid) - self.tile_dim_in_units / 2 + np.array(centroid) + self.tile_dim_in_units / 2 + + # tile = rasterize + # request = BoundingBoxRequest( + # target_coordinate_system=self.target_coordinate_system, + # axes=self.spots_dims, + # min_coordinate=min_coordinate, + # max_coordinate=max_coordinate, + # ) + # sdata_item = self.sdata.query.bounding_box(**request.to_dict()) + # + # if self.transform is not None: + # sdata_item = self.transform(sdata_item) return sdata_item diff --git a/tests/_core/test_spatialdata_operations.py b/tests/_core/test_spatialdata_operations.py index c21970ec..3a41cfa7 100644 --- a/tests/_core/test_spatialdata_operations.py +++ 
b/tests/_core/test_spatialdata_operations.py @@ -179,3 +179,47 @@ def test_locate_spatial_element(full_sdata): full_sdata.images["image2d_again"] = im with pytest.raises(ValueError): full_sdata._locate_spatial_element(im) + + +@pytest.mark.parametrize( + "input_string, expected_element_type", + [ + ("/images/image2d", "images"), + ("/images/image2d_multiscale", "images"), + ("/images/image2d_xarray", "images"), + ("/images/image2d_multiscale_xarray", "images"), + ("/images/image3d_numpy", "images"), + ("/images/image3d_multiscale_numpy", "images"), + ("/images/image3d_xarray", "images"), + ("/images/image3d_multiscale_xarray", "images"), + ("/labels/labels2d", "labels"), + ("/labels/labels2d_multiscale", "labels"), + ("/labels/labels2d_xarray", "labels"), + ("/labels/labels2d_multiscale_xarray", "labels"), + ("/labels/labels3d_numpy", "labels"), + ("/labels/labels3d_multiscale_numpy", "labels"), + ("/labels/labels3d_xarray", "labels"), + ("/labels/labels3d_multiscale_xarray", "labels"), + ("/points/points_0", "points"), + ("/points/points_0_1", "points"), + ("/shapes/poly", "shapes"), + ("/shapes/multipoly", "shapes"), + ("/shapes/circles", "shapes"), + ("/table", "table"), + ("images/image2d", None), + ("/iimages/image2d", None), + ("/images/image2d/", None), + ("/images/image2d/a", None), + ("/table/a", None), + ], +) +def test_get_item(full_sdata, input_string, expected_element_type): + if expected_element_type is None: + with pytest.raises(ValueError): + _ = full_sdata[input_string] + else: + element = full_sdata[input_string] + if expected_element_type == "table": + assert id(element) == id(full_sdata.table) + else: + assert id(element) == id(full_sdata.__getattribute__(expected_element_type)[input_string.split("/")[-1]]) From 89023905c4497545a032facc7c9ee7c47a0fa4df Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Mon, 6 Mar 2023 12:28:53 +0100 Subject: [PATCH 07/16] wip tiling --- examples/dev-examples/image_tiles.py | 63 +++++++++++++--------------- spatialdata/_dl/datasets.py | 39 ++++++++++------- 2 files changed, 54 insertions(+), 48 deletions(-) diff --git a/examples/dev-examples/image_tiles.py b/examples/dev-examples/image_tiles.py index d0f7c2ca..b537bd28 100644 --- a/examples/dev-examples/image_tiles.py +++ b/examples/dev-examples/image_tiles.py @@ -7,7 +7,6 @@ import spatialdata as sd from spatialdata._dl.datasets import ImageTilesDataset -from spatialdata._dl.transforms import SpatialDataToDataDict ## coordinates = np.array([[10, 10], [20, 20], [50, 30], [90, 70], [20, 80]]) @@ -50,7 +49,7 @@ ds = ImageTilesDataset( sdata=sdata, regions_to_images={"/shapes/spots": "/images/image"}, - tile_dim_in_units=5, + tile_dim_in_units=10, tile_dim_in_pixels=32, target_coordinate_system="global", data_dict_transform=None, @@ -69,39 +68,37 @@ def plot_sdata_dataset(ds): n_samples = len(ds) fig, axs = plt.subplots(1, n_samples) - for index, sample in enumerate(ds): - if isinstance(sample, sd.SpatialData): - im = sample.images["image"] - else: - im = sample["image"] - axs[index].imshow(im.transpose("y", "x", "c")) + for i, (image, region, index) in enumerate(ds): + axs[i].imshow(image.transpose("y", "x", "c")) + plt.show() plot_sdata_dataset(ds) -## -# we can also use transforms to automatically extract the relevant data -# into a datadictionary - -# map the SpatialData path to a data dict key -data_mapping = {"images/image": "image"} - -# make the transform -ds_transform = ImageTilesDataset( - sdata=sdata, - spots_element_key="spots", 
- transform=SpatialDataToDataDict(data_mapping=data_mapping), -) - -print(f"this dataset as {len(ds_transform)} items") - -## -# now the samples are a dictionary with key "image" and the item is the -# image array -# this is useful because it is the expected format for many of the +# TODO: code to be restored when the transforms will use the bounding box query +# ## +# # we can also use transforms to automatically extract the relevant data +# # into a datadictionary # -ds_transform[0] - -## -# plot of each sample in the dataset -plot_sdata_dataset(ds_transform) +# # map the SpatialData path to a data dict key +# data_mapping = {"images/image": "image"} +# +# # make the transform +# ds_transform = ImageTilesDataset( +# sdata=sdata, +# spots_element_key="spots", +# transform=SpatialDataToDataDict(data_mapping=data_mapping), +# ) +# +# print(f"this dataset as {len(ds_transform)} items") +# +# ## +# # now the samples are a dictionary with key "image" and the item is the +# # image array +# # this is useful because it is the expected format for many of the +# # +# ds_transform[0] +# +# ## +# # plot of each sample in the dataset +# plot_sdata_dataset(ds_transform) diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index 4b6ab542..78e64d73 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -7,6 +7,7 @@ from torch.utils.data import Dataset from spatialdata import SpatialData +from spatialdata._core._rasterize import rasterize from spatialdata._core.core_utils import get_dims from spatialdata._core.models import ( Image2DModel, @@ -28,6 +29,9 @@ def __init__( target_coordinate_system: str = "global", data_dict_transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None, ): + # TODO: we can extend this code to support: + # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) + # - use the bounding box query instead of the raster function if the user wants self.sdata = sdata self.regions_to_images = regions_to_images self.tile_dim_in_units = tile_dim_in_units @@ -63,10 +67,6 @@ def _compute_n_spots_dict(self) -> dict[str, int]: raise ValueError(f"element must be a geodataframe or a spatial image") return n_spots_dict - def _get_centroids_and_metadata(self) -> None: - for key in self.spots_element_keys: - print(key) - def _get_region_info_for_index(self, index: int) -> tuple[str, int]: # TODO: this implmenetation can be improved i = 0 @@ -79,12 +79,14 @@ def _get_region_info_for_index(self, index: int) -> tuple[str, int]: def __len__(self) -> int: return self.n_spots - def __getitem__(self, idx: int) -> SpatialData: + def __getitem__(self, idx: int) -> tuple[SpatialImage, str, int]: + if idx >= self.n_spots: + raise IndexError() regions_name, region_index = self._get_region_info_for_index(idx) regions = self.sdata[regions_name] # TODO: here we just need to compute the centroids, we probably want to move this functionality to a different file if isinstance(regions, GeoDataFrame): - get_dims(regions) + dims = get_dims(regions) region = regions.iloc[region_index] # the function coords.xy is just accessing _coords, and wrapping it with extra information, so we access # it directly @@ -95,19 +97,26 @@ def __getitem__(self, idx: int) -> SpatialData: raise NotImplementedError("labels not supported yet") else: raise ValueError(f"element must be shapes or labels") - np.array(centroid) - self.tile_dim_in_units / 2 - np.array(centroid) + self.tile_dim_in_units / 2 + min_coordinate = 
np.array(centroid) - self.tile_dim_in_units / 2 + max_coordinate = np.array(centroid) + self.tile_dim_in_units / 2 - # tile = rasterize + raster = self.sdata[self.regions_to_images[regions_name]] + tile = rasterize( + raster, + axes=dims, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=self.target_coordinate_system, + target_width=self.tile_dim_in_pixels, + ) + # TODO: as explained in the TODO in the __init__(), we want to let the user also use the bounding box query instaed of the rasterization + # the return function of this function would change, so we need to decide if instead having an extra Tile dataset class + # from spatialdata._core._spatial_query import BoundingBoxRequest # request = BoundingBoxRequest( # target_coordinate_system=self.target_coordinate_system, - # axes=self.spots_dims, + # axes=dims, # min_coordinate=min_coordinate, # max_coordinate=max_coordinate, # ) # sdata_item = self.sdata.query.bounding_box(**request.to_dict()) - # - # if self.transform is not None: - # sdata_item = self.transform(sdata_item) - - return sdata_item + return tile, regions_name, region_index From a307f1b50816cc33a036850d07640331dd6975e2 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Wed, 8 Mar 2023 23:45:45 +0100 Subject: [PATCH 08/16] tiling still wip, but usable --- examples/dev-examples/image_tiles.py | 4 +- .../spatial_query_and_rasterization.py | 73 ++++++++---- spatialdata/_core/_rasterize.py | 110 ++++++++++++------ spatialdata/_core/_spatial_query.py | 2 +- spatialdata/_core/_spatialdata.py | 2 +- spatialdata/_core/models.py | 2 +- spatialdata/_dl/datasets.py | 5 +- 7 files changed, 135 insertions(+), 63 deletions(-) diff --git a/examples/dev-examples/image_tiles.py b/examples/dev-examples/image_tiles.py index 61be1d74..ea9af907 100644 --- a/examples/dev-examples/image_tiles.py +++ b/examples/dev-examples/image_tiles.py @@ -48,7 +48,7 @@ ## ds = ImageTilesDataset( sdata=sdata, - regions_to_images={"/shapes/spots": "/images/image"}, + regions_to_images={"spots": "image"}, tile_dim_in_units=10, tile_dim_in_pixels=32, target_coordinate_system="global", @@ -70,7 +70,7 @@ def plot_sdata_dataset(ds: ImageTilesDataset) -> None: for i, (image, region, index) in enumerate(ds): axs[i].imshow(image.transpose("y", "x", "c")) - axs[i].set_title(f"region: {region}, index: {index}") + axs[i].set_title(f"{region}, {index}") plt.show() diff --git a/examples/dev-examples/spatial_query_and_rasterization.py b/examples/dev-examples/spatial_query_and_rasterization.py index d92a765c..b4358e7c 100644 --- a/examples/dev-examples/spatial_query_and_rasterization.py +++ b/examples/dev-examples/spatial_query_and_rasterization.py @@ -1,6 +1,4 @@ import numpy as np -from multiscale_spatial_image import MultiscaleSpatialImage -from spatial_image import SpatialImage from spatialdata import Labels2DModel from spatialdata._core._spatial_query import bounding_box_query @@ -9,10 +7,10 @@ remove_transformation, set_transformation, ) -from spatialdata._core.transformations import Affine +from spatialdata._core.transformations import Affine, Scale, Sequence -def _visualize_crop_affine_labels_2d() -> None: +def _visualize_crop_affine_labels_2d(): """ This examples show how the bounding box spatial query works for data that has been rotated. 
@@ -31,34 +29,41 @@ def _visualize_crop_affine_labels_2d() -> None: requested crop, as exaplained above) 5) then enable "3 cropped rotated processed", this shows the data that we wanted to query in the first place, in the target coordinate system ("rotated"). This is probaly the data you care about if for instance you want to - use tiles for deep learning. Note that for obtaning this answer there is also a better function (not available at - the time of this writing): rasterize(), which is faster and more accurate, so it should be used instead. The - function rasterize() transforms all the coordinates of the data into the target coordinate system, and it returns - only SpatialImage objects. So it has different use cases than the bounding box query. - 6) finally switch to the "global" coordinate_system. This is, for how we constructed the example, showing the + use tiles for deep learning. + 6) Note that for obtaning the previous answer there is also a better function rasterize(). + This is what "4 rasterized" shows, which is faster and more accurate, so it should be used instead. The function + rasterize() transforms all the coordinates of the data into the target coordinate system, and it returns only + SpatialImage objects. So it has different use cases than the bounding box query. BUG: Note that it is not pixel + perfect. I think this is due to the difference between considering the origin of a pixel its center or its corner. + 7) finally switch to the "global" coordinate_system. This is, for how we constructed the example, showing the original image as it would appear its intrinsic coordinate system (since the transformation that maps the original image to "global" is an identity. It then shows how the data showed at the point 5), localizes in the original image. 
""" ## # in this test let's try some affine transformations, we could do that also for the other tests + # image = scipy.misc.face()[:100, :100, :].copy() image = np.random.randint(low=10, high=100, size=(100, 100)) + multiscale_image = np.repeat(np.repeat(image, 4, axis=0), 4, axis=1) + # y: [5, 9], x: [0, 4] has value 1 image[50:, :50] = 2 + # labels_element = Image2DModel.parse(image, dims=('y', 'x', 'c')) labels_element = Labels2DModel.parse(image) + affine = Affine( + np.array( + [ + [np.cos(np.pi / 6), np.sin(-np.pi / 6), 0], + [np.sin(np.pi / 6), np.cos(np.pi / 6), 0], + [0, 0, 1], + ] + ), + input_axes=("x", "y"), + output_axes=("x", "y"), + ) set_transformation( labels_element, - Affine( - np.array( - [ - [np.cos(np.pi / 6), np.sin(-np.pi / 6), 20], - [np.sin(np.pi / 6), np.cos(np.pi / 6), 0], - [0, 0, 1], - ] - ), - input_axes=("x", "y"), - output_axes=("x", "y"), - ), + affine, "rotated", ) @@ -91,9 +96,6 @@ def _visualize_crop_affine_labels_2d() -> None: if labels_result_rotated is not None: d["2 cropped_rotated"] = labels_result_rotated - assert isinstance(labels_result_rotated, SpatialImage) or isinstance( - labels_result_rotated, MultiscaleSpatialImage - ) transform = labels_result_rotated.attrs["transform"]["rotated"] transform_rotated_processed = transform.transform(labels_result_rotated, maintain_positioning=True) transform_rotated_processed_recropped = bounding_box_query( @@ -106,7 +108,32 @@ def _visualize_crop_affine_labels_2d() -> None: d["3 cropped_rotated_processed_recropped"] = transform_rotated_processed_recropped remove_transformation(labels_result_rotated, "global") + multiscale_image[200:, :200] = 2 + # multiscale_labels = Labels2DModel.parse(multiscale_image) + multiscale_labels = Labels2DModel.parse(multiscale_image, scale_factors=[2, 2, 2, 2]) + sequence = Sequence([Scale([0.25, 0.25], axes=("x", "y")), affine]) + set_transformation(multiscale_labels, sequence, "rotated") + + from spatialdata._core._rasterize import rasterize + + rasterized = rasterize( + multiscale_labels, + axes=("y", "x"), + min_coordinate=np.array([25, 25]), + max_coordinate=np.array([75, 100]), + target_coordinate_system="rotated", + target_width=300, + ) + d["4 rasterized"] = rasterized + sdata = SpatialData(labels=d) + + # to see only what matters when debugging https://github.com/scverse/spatialdata/issues/165 + del sdata.labels["1 cropped_global"] + del sdata.labels["2 cropped_rotated"] + del sdata.labels["3 cropped_rotated_processed_recropped"] + del sdata.labels["0 original"].attrs["transform"]["global"] + Interactive(sdata) ## diff --git a/spatialdata/_core/_rasterize.py b/spatialdata/_core/_rasterize.py index e5165c83..d8541566 100644 --- a/spatialdata/_core/_rasterize.py +++ b/spatialdata/_core/_rasterize.py @@ -32,7 +32,6 @@ get_schema, ) from spatialdata._core.transformations import ( - Affine, BaseTransformation, Scale, Sequence, @@ -247,7 +246,7 @@ def _get_xarray_data_to_rasterize( min_coordinate: Union[list[Number], ArrayLike], max_coordinate: Union[list[Number], ArrayLike], target_sizes: dict[str, Optional[float]], - corrected_affine: Affine, + target_coordinate_system: str, ) -> tuple[DataArray, Optional[Scale]]: """ Returns the DataArray to rasterize along with its eventual scale factor (if from a pyramid level) from either a @@ -289,6 +288,13 @@ def _get_xarray_data_to_rasterize( xdata = next(iter(v)) assert set(get_spatial_axes(tuple(xdata.sizes.keys()))) == set(axes) + corrected_affine, _ = _get_corrected_affine_matrix( + data=SpatialImage(xdata), + axes=axes, + 
min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=target_coordinate_system, + ) m = corrected_affine.inverse().matrix # type: ignore[attr-defined] m_linear = m[:-1, :-1] m_translation = m[:-1, -1] @@ -300,7 +306,7 @@ def _get_xarray_data_to_rasterize( assert tuple(bb_corners.axis.data.tolist()) == axes bb_in_xdata = bb_corners.data @ m_linear + m_translation bb_in_xdata_sizes = { - ax: bb_in_xdata[axes.index(ax)].max() - bb_in_xdata[axes.index(ax)].min() for ax in axes + ax: bb_in_xdata[:, axes.index(ax)].max() - bb_in_xdata[:, axes.index(ax)].min() for ax in axes } for ax in axes: # TLDR; the sqrt selects a pyramid level in which the requested bounding box is a bit larger than the @@ -311,9 +317,10 @@ def _get_xarray_data_to_rasterize( # inverse-transformed bounding box. The sqrt comes from the ratio of the side of a square, # and the maximum diagonal of a square containing the original square, if the original square is # rotated. - if bb_in_xdata_sizes[ax] * np.sqrt(len(axes)) < target_sizes[ax]: + if bb_in_xdata_sizes[ax] < target_sizes[ax] * np.sqrt(len(axes)): break else: + # when this code is reached, latest_scale is selected break assert latest_scale is not None xdata = next(iter(data[latest_scale].values())) @@ -327,6 +334,40 @@ def _get_xarray_data_to_rasterize( return xdata, pyramid_scale +def _get_corrected_affine_matrix( + data: SpatialImage | MultiscaleSpatialImage, + axes: tuple[str, ...], + min_coordinate: ArrayLike, + max_coordinate: ArrayLike, + target_coordinate_system: str, +) -> tuple[ArrayLike, tuple[str, ...]]: + """ + TODO: docstring + """ + transformation = get_transformation(data, target_coordinate_system) + get_dims(data) + assert isinstance(transformation, BaseTransformation) + affine = _get_affine_for_element(data, transformation) + target_axes_unordered = affine.output_axes + assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}] + target_axes: tuple[str, ...] + if "z" in target_axes_unordered: + if "c" in target_axes_unordered: + target_axes = ("c", "z", "y", "x") + else: + target_axes = ("z", "y", "x") + else: + if "c" in target_axes_unordered: + target_axes = ("c", "y", "x") + else: + target_axes = ("y", "x") + target_spatial_axes = get_spatial_axes(target_axes) + assert len(target_spatial_axes) == len(min_coordinate) + assert len(target_spatial_axes) == len(max_coordinate) + corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes) + return corrected_affine, target_axes + + @rasterize.register(SpatialImage) @rasterize.register(MultiscaleSpatialImage) def _( @@ -359,29 +400,6 @@ def _( "z": target_depth, } - # get inverse transformation - transformation = get_transformation(data, target_coordinate_system) - dims = get_dims(data) - assert isinstance(transformation, BaseTransformation) - affine = _get_affine_for_element(data, transformation) - target_axes_unordered = affine.output_axes - assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}] - target_axes: tuple[str, ...] 
- if "z" in target_axes_unordered: - if "c" in target_axes_unordered: - target_axes = ("c", "z", "y", "x") - else: - target_axes = ("z", "y", "x") - else: - if "c" in target_axes_unordered: - target_axes = ("c", "y", "x") - else: - target_axes = ("y", "x") - target_spatial_axes = get_spatial_axes(target_axes) - assert len(target_spatial_axes) == len(min_coordinate) - assert len(target_spatial_axes) == len(max_coordinate) - corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes) - bb_sizes = {ax: max_coordinate[axes.index(ax)] - min_coordinate[axes.index(ax)] for ax in axes} scale_vector = [bb_sizes[ax] / target_sizes[ax] for ax in axes] scale = Scale(scale_vector, axes=axes) @@ -395,7 +413,8 @@ def _( min_coordinate=min_coordinate, max_coordinate=max_coordinate, target_sizes=target_sizes, - corrected_affine=corrected_affine, + target_coordinate_system=target_coordinate_system, + # corrected_affine=corrected_affine, ) if pyramid_scale is not None: @@ -403,17 +422,27 @@ def _( else: extra = [] + # get inverse transformation + corrected_affine, target_axes = _get_corrected_affine_matrix( + data=data, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=target_coordinate_system, + ) + half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x")) sequence = Sequence( [ - half_pixel_offset.inverse(), + # half_pixel_offset.inverse(), scale, translation, corrected_affine.inverse(), - half_pixel_offset, + # half_pixel_offset, ] + extra ) + dims = get_dims(data) matrix = sequence.to_affine_matrix(input_axes=target_axes, output_axes=dims) # get output shape @@ -433,12 +462,11 @@ def _( if schema == Labels2DModel or schema == Labels3DModel: kwargs = {"prefilter": False, "order": 0} elif schema == Image2DModel or schema == Image3DModel: + # kwargs = {"prefilter": True} kwargs = {} else: raise ValueError(f"Unsupported schema {schema}") - # TODO: adjust matrix - # TODO: add c # resample the image transformed_dask = dask_image.ndinterp.affine_transform( xdata.data, @@ -447,12 +475,28 @@ def _( # output_chunks=xdata.data.chunks, **kwargs, ) + # ## + # # debug code + # crop = xdata.sel( + # { + # "x": slice(min_coordinate[axes.index("x")], max_coordinate[axes.index("x")]), + # "y": slice(min_coordinate[axes.index("y")], max_coordinate[axes.index("y")]), + # } + # ) + # import matplotlib.pyplot as plt + # plt.figure(figsize=(20, 10)) + # plt.subplot(1, 2, 1) + # plt.imshow(crop.transpose("y", "x", "c").data) + # plt.subplot(1, 2, 2) + # plt.imshow(DataArray(transformed_dask, dims=xdata.dims).transpose("y", "x", "c").data) + # plt.show() + # ## assert isinstance(transformed_dask, DaskArray) transformed_data = schema.parse(transformed_dask, dims=xdata.dims) # type: ignore[call-arg,arg-type] if target_coordinate_system != "global": remove_transformation(transformed_data, "global") - sequence = Sequence([half_pixel_offset.inverse(), scale, translation]) + sequence = Sequence([half_pixel_offset.inverse(), scale, translation, half_pixel_offset]) set_transformation(transformed_data, sequence, target_coordinate_system) transformed_data = compute_coordinates(transformed_data) diff --git a/spatialdata/_core/_spatial_query.py b/spatialdata/_core/_spatial_query.py index c3d52553..a1b3f52b 100644 --- a/spatialdata/_core/_spatial_query.py +++ b/spatialdata/_core/_spatial_query.py @@ -301,7 +301,7 @@ def _( ) new_elements[element_type] = queried_elements - if filter_table: + if filter_table and sdata.table is not None: to_keep = 
np.array([False] * len(sdata.table)) region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] diff --git a/spatialdata/_core/_spatialdata.py b/spatialdata/_core/_spatialdata.py index 52f7e275..180f642f 100644 --- a/spatialdata/_core/_spatialdata.py +++ b/spatialdata/_core/_spatialdata.py @@ -741,7 +741,7 @@ def add_points( def add_shapes( self, name: str, - shapes: AnnData, + shapes: GeoDataFrame, overwrite: bool = False, ) -> None: """ diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py index d6825e82..4c1971eb 100644 --- a/spatialdata/_core/models.py +++ b/spatialdata/_core/models.py @@ -413,7 +413,7 @@ def _( if not isinstance(gc, GeometryCollection): raise ValueError(f"`{data}` does not contain a `GeometryCollection`.") geo_df = GeoDataFrame({"geometry": gc.geoms}) - if isinstance(geo_df["geometry"][0], Point): + if isinstance(geo_df["geometry"].iloc[0], Point): if radius is None: raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") geo_df[cls.RADIUS_KEY] = radius diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index 2b5ea287..d873b4d7 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -27,7 +27,7 @@ def __init__( tile_dim_in_units: float, tile_dim_in_pixels: int, target_coordinate_system: str = "global", - data_dict_transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None, + transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None, ): # TODO: we can extend this code to support: # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) @@ -36,7 +36,7 @@ def __init__( self.regions_to_images = regions_to_images self.tile_dim_in_units = tile_dim_in_units self.tile_dim_in_pixels = tile_dim_in_pixels - self.data_dict_transform = data_dict_transform + self.transform = transform self.target_coordinate_system = target_coordinate_system self.n_spots_dict = self._compute_n_spots_dict() @@ -109,6 +109,7 @@ def __getitem__(self, idx: int) -> tuple[SpatialImage, str, int]: target_coordinate_system=self.target_coordinate_system, target_width=self.tile_dim_in_pixels, ) + # TODO: as explained in the TODO in the __init__(), we want to let the user also use the bounding box query instaed of the rasterization # the return function of this function would change, so we need to decide if instead having an extra Tile dataset class # from spatialdata._core._spatial_query import BoundingBoxRequest From 2125bec3fc4acae19bd64e7654f9abb08de57e42 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 9 Mar 2023 00:23:14 +0100 Subject: [PATCH 09/16] fixed mypy --- examples/dev-examples/image_tiles.py | 1 - examples/dev-examples/spatial_query_and_rasterization.py | 4 +++- spatialdata/_core/_rasterize.py | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/dev-examples/image_tiles.py b/examples/dev-examples/image_tiles.py index ea9af907..16789915 100644 --- a/examples/dev-examples/image_tiles.py +++ b/examples/dev-examples/image_tiles.py @@ -52,7 +52,6 @@ tile_dim_in_units=10, tile_dim_in_pixels=32, target_coordinate_system="global", - data_dict_transform=None, ) print(f"this dataset as {len(ds)} items") diff --git a/examples/dev-examples/spatial_query_and_rasterization.py b/examples/dev-examples/spatial_query_and_rasterization.py index b4358e7c..0c932f4f 100644 --- 
a/examples/dev-examples/spatial_query_and_rasterization.py +++ b/examples/dev-examples/spatial_query_and_rasterization.py @@ -1,4 +1,5 @@ import numpy as np +from spatial_image import SpatialImage from spatialdata import Labels2DModel from spatialdata._core._spatial_query import bounding_box_query @@ -10,7 +11,7 @@ from spatialdata._core.transformations import Affine, Scale, Sequence -def _visualize_crop_affine_labels_2d(): +def _visualize_crop_affine_labels_2d() -> None: """ This examples show how the bounding box spatial query works for data that has been rotated. @@ -96,6 +97,7 @@ def _visualize_crop_affine_labels_2d(): if labels_result_rotated is not None: d["2 cropped_rotated"] = labels_result_rotated + assert isinstance(labels_result_rotated, SpatialImage) transform = labels_result_rotated.attrs["transform"]["rotated"] transform_rotated_processed = transform.transform(labels_result_rotated, maintain_positioning=True) transform_rotated_processed_recropped = bounding_box_query( diff --git a/spatialdata/_core/_rasterize.py b/spatialdata/_core/_rasterize.py index 58b36717..10bc2e56 100644 --- a/spatialdata/_core/_rasterize.py +++ b/spatialdata/_core/_rasterize.py @@ -32,6 +32,7 @@ get_schema, ) from spatialdata._core.transformations import ( + Affine, BaseTransformation, Scale, Sequence, @@ -335,12 +336,12 @@ def _get_xarray_data_to_rasterize( def _get_corrected_affine_matrix( - data: SpatialImage | MultiscaleSpatialImage, + data: Union[SpatialImage, MultiscaleSpatialImage], axes: tuple[str, ...], min_coordinate: ArrayLike, max_coordinate: ArrayLike, target_coordinate_system: str, -) -> tuple[ArrayLike, tuple[str, ...]]: +) -> tuple[Affine, tuple[str, ...]]: """ TODO: docstring """ From 9c2cf75b37a9e5a89d1cfa55b59f13a418842514 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 9 Mar 2023 17:22:16 +0100 Subject: [PATCH 10/16] type fix --- spatialdata/_core/_spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spatialdata/_core/_spatial_query.py b/spatialdata/_core/_spatial_query.py index f4c974d5..9c3db662 100644 --- a/spatialdata/_core/_spatial_query.py +++ b/spatialdata/_core/_spatial_query.py @@ -333,7 +333,7 @@ def _( @bounding_box_query.register(SpatialImage) @bounding_box_query.register(MultiscaleSpatialImage) def _( - image: SpatialImage, + image: Union[SpatialImage, MultiscaleSpatialImage], axes: tuple[str, ...], min_coordinate: Union[list[Number], ArrayLike], max_coordinate: Union[list[Number], ArrayLike], From 656616be94c05448b1d62c317c7a9eeb04534f2b Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 9 Mar 2023 20:33:52 +0100 Subject: [PATCH 11/16] fixed bug with xarray coordinates in multiscale, fixed wrong centroids when tiling --- spatialdata/_core/core_utils.py | 34 ++++++++------------------------- spatialdata/_dl/datasets.py | 13 ++++++++++--- spatialdata/utils.py | 10 ++++++++++ 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py index fa94e3ef..c1bc710f 100644 --- a/spatialdata/_core/core_utils.py +++ b/spatialdata/_core/core_utils.py @@ -385,36 +385,18 @@ def _get_scale(transforms: dict[str, Any]) -> Scale: @compute_coordinates.register(MultiscaleSpatialImage) def _(data: MultiscaleSpatialImage) -> MultiscaleSpatialImage: - def _compute_coords(n0: int, scale_f: float, n: int) -> ArrayLike: - scaled_max = n0 / scale_f - if n > 1: - offset = scaled_max / 
(2.0 * (n - 1)) - else: - offset = 0 - return np.linspace(0, scaled_max, n) + offset - spatial_coords = [ax for ax in get_dims(data) if ax in ["x", "y", "z"]] img_name = list(data["scale0"].data_vars.keys())[0] out = {} - for name, dt in data.items(): - if name == "scale0": - coords: dict[str, ArrayLike] = { - d: np.arange(data[name].sizes[d], dtype=np.float_) + 0.5 for d in spatial_coords - } - out[name] = dt[img_name].assign_coords(coords) - else: - scale = _get_scale(dt[img_name].attrs["transform"]) - scalef = scale.scale - assert len(spatial_coords) == len(scalef), "Mismatch between coordinates and scales." # type: ignore[arg-type] - new_coords = {} - for ax, s in zip(spatial_coords, scalef): - new_coords[ax] = _compute_coords( - n0=data["scale0"].sizes[ax], - scale_f=s, - n=data[name].sizes[ax], - ) - out[name] = dt[img_name].assign_coords(new_coords) + new_coords = {} + for ax in spatial_coords: + max_dim = data["scale0"].sizes[ax] + n = dt.sizes[ax] + offset = max_dim / n / 2 + coords = np.linspace(0, max_dim, n + 1)[:-1] + offset + new_coords[ax] = coords + out[name] = dt[img_name].assign_coords(new_coords) msi = MultiscaleSpatialImage.from_dict(d=out) # this is to trigger the validation of the dims _ = get_dims(msi) diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dl/datasets.py index d873b4d7..a95e97dd 100644 --- a/spatialdata/_dl/datasets.py +++ b/spatialdata/_dl/datasets.py @@ -8,6 +8,7 @@ from spatialdata import SpatialData from spatialdata._core._rasterize import rasterize +from spatialdata._core._spatialdata_ops import get_transformation from spatialdata._core.core_utils import get_dims from spatialdata._core.models import ( Image2DModel, @@ -17,6 +18,8 @@ ShapesModel, get_schema, ) +from spatialdata._core.transformations import BaseTransformation +from spatialdata.utils import affine_matrix_multiplication class ImageTilesDataset(Dataset): @@ -90,15 +93,19 @@ def __getitem__(self, idx: int) -> tuple[SpatialImage, str, int]: region = regions.iloc[region_index] # the function coords.xy is just accessing _coords, and wrapping it with extra information, so we access # it directly - centroid = region.geometry.coords._coords[0] + centroid = np.atleast_2d(region.geometry.coords._coords[0]) + t = get_transformation(regions, self.target_coordinate_system) + assert isinstance(t, BaseTransformation) + aff = t.to_affine_matrix(input_axes=dims, output_axes=dims) + transformed_centroid = np.squeeze(affine_matrix_multiplication(aff, centroid), 0) elif isinstance(regions, SpatialImage): raise NotImplementedError("labels not supported yet") elif isinstance(regions, MultiscaleSpatialImage): raise NotImplementedError("labels not supported yet") else: raise ValueError("element must be shapes or labels") - min_coordinate = np.array(centroid) - self.tile_dim_in_units / 2 - max_coordinate = np.array(centroid) + self.tile_dim_in_units / 2 + min_coordinate = np.array(transformed_centroid) - self.tile_dim_in_units / 2 + max_coordinate = np.array(transformed_centroid) + self.tile_dim_in_units / 2 raster = self.sdata[self.regions_to_images[regions_name]] tile = rasterize( diff --git a/spatialdata/utils.py b/spatialdata/utils.py index fe165d10..c0886f20 100644 --- a/spatialdata/utils.py +++ b/spatialdata/utils.py @@ -269,3 +269,13 @@ def natural_keys(text: str) -> list[Union[int, str]]: (See Toothy's implementation in the comments) """ return [atoi(c) for c in re.split(r"(\d+)", text)] + + +def affine_matrix_multiplication(matrix: ArrayLike, data: ArrayLike) -> ArrayLike: + assert len(data.shape) 
== 2 + assert matrix.shape[1] - 1 == data.shape[1] + vector_part = matrix[:-1, :-1] + offset_part = matrix[:-1, -1] + result = data @ vector_part.T + offset_part + assert result.shape[0] == data.shape[0] + return result # type: ignore[no-any-return] From db62b717a8d240479fe28b06f4b341625c09f36a Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Tue, 14 Mar 2023 11:29:20 +0100 Subject: [PATCH 12/16] Apply suggestions from code review Co-authored-by: Giovanni Palla <25887487+giovp@users.noreply.github.com> --- spatialdata/_core/_rasterize.py | 2 -- spatialdata/_core/_spatialdata.py | 6 +----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/spatialdata/_core/_rasterize.py b/spatialdata/_core/_rasterize.py index 10bc2e56..f9f974b8 100644 --- a/spatialdata/_core/_rasterize.py +++ b/spatialdata/_core/_rasterize.py @@ -415,7 +415,6 @@ def _( max_coordinate=max_coordinate, target_sizes=target_sizes, target_coordinate_system=target_coordinate_system, - # corrected_affine=corrected_affine, ) if pyramid_scale is not None: @@ -463,7 +462,6 @@ def _( if schema == Labels2DModel or schema == Labels3DModel: kwargs = {"prefilter": False, "order": 0} elif schema == Image2DModel or schema == Image3DModel: - # kwargs = {"prefilter": True} kwargs = {} else: raise ValueError(f"Unsupported schema {schema}") diff --git a/spatialdata/_core/_spatialdata.py b/spatialdata/_core/_spatialdata.py index 180f642f..e409e3ee 100644 --- a/spatialdata/_core/_spatialdata.py +++ b/spatialdata/_core/_spatialdata.py @@ -1210,13 +1210,9 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: ------- The element. """ - found = [] for _, element_name, element in self._gen_elements(): if element_name == item: - found.append(element) - assert len(found) <= 1 - if len(found) == 1: - return found[0] + return element else: raise KeyError(f"Could not find element with name {item!r}") From edf71bd29067361986b234a1499634931758f72e Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Tue, 14 Mar 2023 11:53:25 +0100 Subject: [PATCH 13/16] implemented suggestions from code review --- .pre-commit-config.yaml | 2 +- examples/dev-examples/README.md | 3 - examples/dev-examples/image_tiles.py | 104 ------------- .../spatial_query_and_rasterization.py | 144 ------------------ pyproject.toml | 3 + spatialdata/__init__.py | 8 + spatialdata/_core/_rasterize.py | 21 +-- spatialdata/{_dl => _dataloader}/__init__.py | 0 spatialdata/{_dl => _dataloader}/datasets.py | 26 +++- .../{_dl => _dataloader}/transforms.py | 0 10 files changed, 45 insertions(+), 266 deletions(-) delete mode 100644 examples/dev-examples/README.md delete mode 100644 examples/dev-examples/image_tiles.py delete mode 100644 examples/dev-examples/spatial_query_and_rasterization.py rename spatialdata/{_dl => _dataloader}/__init__.py (100%) rename spatialdata/{_dl => _dataloader}/datasets.py (82%) rename spatialdata/{_dl => _dataloader}/transforms.py (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0dfaea5..ae9ae76e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ default_stages: - push minimum_pre_commit_version: 2.16.0 ci: - skip: [mypy] + skip: [] repos: - repo: https://github.com/psf/black rev: 23.1.0 diff --git a/examples/dev-examples/README.md b/examples/dev-examples/README.md deleted file mode 100644 index 352615fe..00000000 --- a/examples/dev-examples/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# 
dev-examples - -The examples are useful when developing parts of the codebase. They are not intended to be used as a reference for how to use the library. For that, please refer to the documentation, the example notebooks and the examples in the `example/` directory. diff --git a/examples/dev-examples/image_tiles.py b/examples/dev-examples/image_tiles.py deleted file mode 100644 index 16789915..00000000 --- a/examples/dev-examples/image_tiles.py +++ /dev/null @@ -1,104 +0,0 @@ -## -# from https://gist.github.com/kevinyamauchi/77f986889b7626db4ab3c1075a3a3e5e - -import numpy as np -from matplotlib import pyplot as plt -from skimage import draw - -import spatialdata as sd -from spatialdata._dl.datasets import ImageTilesDataset - -## -coordinates = np.array([[10, 10], [20, 20], [50, 30], [90, 70], [20, 80]]) - -radius = 5 - -colors = np.array( - [ - [102, 194, 165], - [252, 141, 98], - [141, 160, 203], - [231, 138, 195], - [166, 216, 84], - ] -) - -## -# make an image with spots -image = np.zeros((100, 100, 3), dtype=np.uint8) - -for spot_color, centroid in zip(colors, coordinates): - rr, cc = draw.disk(centroid, radius=radius) - - for color_index in range(3): - channel_dims = color_index * np.ones((len(rr),), dtype=int) - image[rr, cc, channel_dims] = spot_color[color_index] - -# plt.imshow(image) -# plt.show() - -## -sd_image = sd.Image2DModel.parse(image, dims=("y", "x", "c")) - -# circles coordinates are xy, so we flip them here. -circles = sd.ShapesModel.parse(coordinates[:, [1, 0]], radius=radius, geometry=0) -sdata = sd.SpatialData(images={"image": sd_image}, shapes={"spots": circles}) -sdata - -## -ds = ImageTilesDataset( - sdata=sdata, - regions_to_images={"spots": "image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", -) - -print(f"this dataset as {len(ds)} items") - -## -# we can use the __getitem__ interface to get one of the sample crops -print(ds[0]) - - -## -# now we plot all of the crops -def plot_sdata_dataset(ds: ImageTilesDataset) -> None: - n_samples = len(ds) - fig, axs = plt.subplots(1, n_samples) - - for i, (image, region, index) in enumerate(ds): - axs[i].imshow(image.transpose("y", "x", "c")) - axs[i].set_title(f"{region}, {index}") - plt.show() - - -plot_sdata_dataset(ds) - -# TODO: code to be restored when the transforms will use the bounding box query -# ## -# # we can also use transforms to automatically extract the relevant data -# # into a datadictionary -# -# # map the SpatialData path to a data dict key -# data_mapping = {"images/image": "image"} -# -# # make the transform -# ds_transform = ImageTilesDataset( -# sdata=sdata, -# spots_element_key="spots", -# transform=SpatialDataToDataDict(data_mapping=data_mapping), -# ) -# -# print(f"this dataset as {len(ds_transform)} items") -# -# ## -# # now the samples are a dictionary with key "image" and the item is the -# # image array -# # this is useful because it is the expected format for many of the -# # -# ds_transform[0] -# -# ## -# # plot of each sample in the dataset -# plot_sdata_dataset(ds_transform) diff --git a/examples/dev-examples/spatial_query_and_rasterization.py b/examples/dev-examples/spatial_query_and_rasterization.py deleted file mode 100644 index 0c932f4f..00000000 --- a/examples/dev-examples/spatial_query_and_rasterization.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -from spatial_image import SpatialImage - -from spatialdata import Labels2DModel -from spatialdata._core._spatial_query import bounding_box_query -from 
spatialdata._core._spatialdata_ops import ( - get_transformation, - remove_transformation, - set_transformation, -) -from spatialdata._core.transformations import Affine, Scale, Sequence - - -def _visualize_crop_affine_labels_2d() -> None: - """ - This examples show how the bounding box spatial query works for data that has been rotated. - - Notes - ----- - The bounding box query gives the data, from the intrinsic coordinate system, that is inside the bounding box of - the inverse-transformed query bounding box. - In this example I show this data, and I also show how to obtain the data back inside the original bounding box. - - To undertand the example I suggest to run it and then: - 1) select the "rotated" coordinate system from napari - 2) disable all the layers but "0 original" - 3) then enable "1 cropped global", this shows the data in the extrinsic coordinate system we care ("rotated"), - and the bounding box we want to query - 4) then enable "2 cropped rotated", this show the data that has been queries (this is a bounding box of the - requested crop, as exaplained above) - 5) then enable "3 cropped rotated processed", this shows the data that we wanted to query in the first place, - in the target coordinate system ("rotated"). This is probaly the data you care about if for instance you want to - use tiles for deep learning. - 6) Note that for obtaning the previous answer there is also a better function rasterize(). - This is what "4 rasterized" shows, which is faster and more accurate, so it should be used instead. The function - rasterize() transforms all the coordinates of the data into the target coordinate system, and it returns only - SpatialImage objects. So it has different use cases than the bounding box query. BUG: Note that it is not pixel - perfect. I think this is due to the difference between considering the origin of a pixel its center or its corner. - 7) finally switch to the "global" coordinate_system. This is, for how we constructed the example, showing the - original image as it would appear its intrinsic coordinate system (since the transformation that maps the - original image to "global" is an identity. It then shows how the data showed at the point 5), localizes in the - original image. 
- """ - ## - # in this test let's try some affine transformations, we could do that also for the other tests - # image = scipy.misc.face()[:100, :100, :].copy() - image = np.random.randint(low=10, high=100, size=(100, 100)) - multiscale_image = np.repeat(np.repeat(image, 4, axis=0), 4, axis=1) - - # y: [5, 9], x: [0, 4] has value 1 - image[50:, :50] = 2 - # labels_element = Image2DModel.parse(image, dims=('y', 'x', 'c')) - labels_element = Labels2DModel.parse(image) - affine = Affine( - np.array( - [ - [np.cos(np.pi / 6), np.sin(-np.pi / 6), 0], - [np.sin(np.pi / 6), np.cos(np.pi / 6), 0], - [0, 0, 1], - ] - ), - input_axes=("x", "y"), - output_axes=("x", "y"), - ) - set_transformation( - labels_element, - affine, - "rotated", - ) - - # bounding box: y: [5, 9], x: [0, 4] - labels_result_rotated = bounding_box_query( - labels_element, - axes=("y", "x"), - min_coordinate=np.array([25, 25]), - max_coordinate=np.array([75, 100]), - target_coordinate_system="rotated", - ) - labels_result_global = bounding_box_query( - labels_element, - axes=("y", "x"), - min_coordinate=np.array([25, 25]), - max_coordinate=np.array([75, 100]), - target_coordinate_system="global", - ) - from napari_spatialdata import Interactive - - from spatialdata import SpatialData - - old_transformation = get_transformation(labels_result_global, "global") - remove_transformation(labels_result_global, "global") - set_transformation(labels_result_global, old_transformation, "rotated") - d = { - "1 cropped_global": labels_result_global, - "0 original": labels_element, - } - if labels_result_rotated is not None: - d["2 cropped_rotated"] = labels_result_rotated - - assert isinstance(labels_result_rotated, SpatialImage) - transform = labels_result_rotated.attrs["transform"]["rotated"] - transform_rotated_processed = transform.transform(labels_result_rotated, maintain_positioning=True) - transform_rotated_processed_recropped = bounding_box_query( - transform_rotated_processed, - axes=("y", "x"), - min_coordinate=np.array([25, 25]), - max_coordinate=np.array([75, 100]), - target_coordinate_system="rotated", - ) - d["3 cropped_rotated_processed_recropped"] = transform_rotated_processed_recropped - remove_transformation(labels_result_rotated, "global") - - multiscale_image[200:, :200] = 2 - # multiscale_labels = Labels2DModel.parse(multiscale_image) - multiscale_labels = Labels2DModel.parse(multiscale_image, scale_factors=[2, 2, 2, 2]) - sequence = Sequence([Scale([0.25, 0.25], axes=("x", "y")), affine]) - set_transformation(multiscale_labels, sequence, "rotated") - - from spatialdata._core._rasterize import rasterize - - rasterized = rasterize( - multiscale_labels, - axes=("y", "x"), - min_coordinate=np.array([25, 25]), - max_coordinate=np.array([75, 100]), - target_coordinate_system="rotated", - target_width=300, - ) - d["4 rasterized"] = rasterized - - sdata = SpatialData(labels=d) - - # to see only what matters when debugging https://github.com/scverse/spatialdata/issues/165 - del sdata.labels["1 cropped_global"] - del sdata.labels["2 cropped_rotated"] - del sdata.labels["3 cropped_rotated_processed_recropped"] - del sdata.labels["0 original"].attrs["transform"]["global"] - - Interactive(sdata) - ## - - -if __name__ == "__main__": - _visualize_crop_affine_labels_2d() diff --git a/pyproject.toml b/pyproject.toml index c79389e2..bd3fe05c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,9 @@ test = [ "pytest", "pytest-cov", ] +optional = [ + "torch" +] [tool.coverage.run] source = ["spatialdata"] diff --git 
a/spatialdata/__init__.py b/spatialdata/__init__.py
index 668b34a6..26f9c761 100644
--- a/spatialdata/__init__.py
+++ b/spatialdata/__init__.py
@@ -1,4 +1,5 @@
 from importlib.metadata import version
+from typing import Union

 __version__ = version("spatialdata")

@@ -40,3 +41,10 @@
     TableModel,
 )
 from spatialdata._io.read import read_zarr
+
+try:
+    from spatialdata._dataloader.datasets import ImageTilesDataset
+except ImportError as e:
+    _error: Union[str, None] = str(e)
+else:
+    _error = None
diff --git a/spatialdata/_core/_rasterize.py b/spatialdata/_core/_rasterize.py
index 10bc2e56..545412c5 100644
--- a/spatialdata/_core/_rasterize.py
+++ b/spatialdata/_core/_rasterize.py
@@ -292,8 +292,6 @@ def _get_xarray_data_to_rasterize(
         corrected_affine, _ = _get_corrected_affine_matrix(
             data=SpatialImage(xdata),
             axes=axes,
-            min_coordinate=min_coordinate,
-            max_coordinate=max_coordinate,
             target_coordinate_system=target_coordinate_system,
         )
         m = corrected_affine.inverse().matrix  # type: ignore[attr-defined]
@@ -338,15 +336,20 @@ def _get_xarray_data_to_rasterize(
 def _get_corrected_affine_matrix(
     data: Union[SpatialImage, MultiscaleSpatialImage],
     axes: tuple[str, ...],
-    min_coordinate: ArrayLike,
-    max_coordinate: ArrayLike,
     target_coordinate_system: str,
 ) -> tuple[Affine, tuple[str, ...]]:
     """
-    TODO: docstring
+    Get the affine matrix that maps the intrinsic coordinates of the data to the target_coordinate_system, with two
+    additional restrictions:
+    - the domain is restricted to the axes specified in axes (i.e. the axes for which the bounding box is specified);
+    in particular, axes never contains c;
+    - the codomain is restricted to the spatial axes of the target coordinate system (i.e. excluding c).
+
+    We do this because:
+    - we don't need to consider c;
+    - when we create the target rasterized object, the axes need to be in the order required by the schema.
     """
     transformation = get_transformation(data, target_coordinate_system)
-    get_dims(data)
     assert isinstance(transformation, BaseTransformation)
     affine = _get_affine_for_element(data, transformation)
     target_axes_unordered = affine.output_axes
@@ -363,8 +366,8 @@ def _get_corrected_affine_matrix(
     else:
         target_axes = ("y", "x")
     target_spatial_axes = get_spatial_axes(target_axes)
-    assert len(target_spatial_axes) == len(min_coordinate)
-    assert len(target_spatial_axes) == len(max_coordinate)
+    assert len(target_spatial_axes) == len(axes)
+    assert len(target_spatial_axes) == len(axes)
     corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes)
     return corrected_affine, target_axes

@@ -427,8 +430,6 @@ def _(
     corrected_affine, target_axes = _get_corrected_affine_matrix(
         data=data,
         axes=axes,
-        min_coordinate=min_coordinate,
-        max_coordinate=max_coordinate,
         target_coordinate_system=target_coordinate_system,
     )

diff --git a/spatialdata/_dl/__init__.py b/spatialdata/_dataloader/__init__.py
similarity index 100%
rename from spatialdata/_dl/__init__.py
rename to spatialdata/_dataloader/__init__.py
diff --git a/spatialdata/_dl/datasets.py b/spatialdata/_dataloader/datasets.py
similarity index 82%
rename from spatialdata/_dl/datasets.py
rename to spatialdata/_dataloader/datasets.py
index a95e97dd..76b5fa93 100644
--- a/spatialdata/_dl/datasets.py
+++ b/spatialdata/_dataloader/datasets.py
@@ -1,5 +1,3 @@
-from typing import Callable, Optional
-
 import numpy as np
 from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
@@ -30,8 +28,28 @@ def __init__(
         tile_dim_in_units: float,
         tile_dim_in_pixels: int,
         target_coordinate_system: str = "global",
-        transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None,
+        # unused at the moment, see
+        # transform: Optional[Callable[[SpatialData], dict[str, SpatialImage]]] = None,
     ):
+        """
+        Torch Dataset that returns image tiles around regions from a SpatialData object.
+
+        Parameters
+        ----------
+        sdata
+            The SpatialData object containing the regions and images from which to extract the tiles.
+        regions_to_images
+            A dictionary mapping the key of the regions element we want to extract the tiles around to the key of the
+            images element we want to get the image data from.
+        tile_dim_in_units
+            The dimension of the requested tile in the units of the target coordinate system. This specifies the extent
+            of the image each tile queries; it is not related to the size in pixels of each returned tile.
+        tile_dim_in_pixels
+            The dimension of the requested tile in pixels. This specifies the size of the output tiles we will get,
+            independently of which extent of the image each tile covers.
+        target_coordinate_system
+            The coordinate system in which tile_dim_in_units is specified.
+        """
         # TODO: we can extend this code to support:
         #  - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling)
         #  - use the bounding box query instead of the raster function if the user wants
         self.sdata = sdata
         self.regions_to_images = regions_to_images
         self.tile_dim_in_units = tile_dim_in_units
         self.tile_dim_in_pixels = tile_dim_in_pixels
-        self.transform = transform
+        # self.transform = transform
         self.target_coordinate_system = target_coordinate_system

         self.n_spots_dict = self._compute_n_spots_dict()
diff --git a/spatialdata/_dl/transforms.py b/spatialdata/_dataloader/transforms.py
similarity index 100%
rename from spatialdata/_dl/transforms.py
rename to spatialdata/_dataloader/transforms.py

From 7a27ef072abc0d072a41c8df7761b30c0b52a717 Mon Sep 17 00:00:00 2001
From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com>
Date: Tue, 14 Mar 2023 12:06:38 +0100
Subject: [PATCH 14/16] fixed test

---
 tests/_core/test_spatialdata_operations.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/_core/test_spatialdata_operations.py b/tests/_core/test_spatialdata_operations.py
index dc6b8c8b..3ef2cd36 100644
--- a/tests/_core/test_spatialdata_operations.py
+++ b/tests/_core/test_spatialdata_operations.py
@@ -218,10 +218,12 @@ def test_locate_spatial_element(full_sdata):
 def test_get_item(points):
     assert id(points["points_0"]) == id(points.points["points_0"])

-    # this should be illegal: https://github.com/scverse/spatialdata/issues/176
-    points.images["points_0"] = Image2DModel.parse(np.array([[[1]]]), dims=("c", "y", "x"))
-    with pytest.raises(AssertionError):
-        _ = points["points_0"]
+    # removed this test after this change: https://github.com/scverse/spatialdata/pull/145#discussion_r1133122720
+    # to be uncommented/removed/modified after this is closed: https://github.com/scverse/spatialdata/issues/186
+    # # this should be illegal: https://github.com/scverse/spatialdata/issues/176
+    # points.images["points_0"] = Image2DModel.parse(np.array([[[1]]]), dims=("c", "y", "x"))
+    # with pytest.raises(AssertionError):
+    #     _ = points["points_0"]

     with pytest.raises(KeyError):
         _ = points["not_present"]

From 6f4aa7c7d0ecea7fd3f53d92f7a0dd6359d6e10b Mon Sep 17 00:00:00 2001
From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com>
Date: Tue, 14 Mar 2023 12:24:11 +0100
Subject: [PATCH 15/16] removed numpy==1.22 constraint for mypy

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c7eb7939..142a44b2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,7 @@ repos:
     rev: v1.1.1
     hooks:
       - id: mypy
-        additional_dependencies: [numpy==1.22.0, types-requests]
+        additional_dependencies: [numpy, types-requests]
         exclude: tests/|docs/|temp/|spatialdata/_core/reader.py
   - repo: https://github.com/asottile/yesqa
     rev: v1.4.0

From 41d818a9b9150c14a105be50adc7486ff61c2bb0 Mon Sep 17 00:00:00 2001
From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com>
Date: Tue, 14 Mar 2023 12:27:26 +0100
Subject: [PATCH 16/16] mypy now using numpy==1.24

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 142a44b2..ee1e4fa4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,7 @@ repos:
     rev: v1.1.1
     hooks:
       - id: mypy
-        additional_dependencies: [numpy, types-requests]
+        additional_dependencies: [numpy==1.24, types-requests]
         exclude: tests/|docs/|temp/|spatialdata/_core/reader.py
   - repo: https://github.com/asottile/yesqa
     rev: v1.4.0