diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 271efbb4..9d85c79f 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -10,8 +10,10 @@ import pandas as pd from anndata import AnnData from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage from pandas import CategoricalDtype from scipy.sparse import issparse +from spatial_image import SpatialImage from torch.utils.data import Dataset from spatialdata._core.centroids import get_centroids @@ -267,6 +269,14 @@ def _preprocess( assert np.all([i in self.tiles_coords for i in dims_]) self.dims = list(dims_) + @staticmethod + def _ensure_single_scale(data: SpatialImage | MultiscaleSpatialImage) -> SpatialImage: + if isinstance(data, SpatialImage): + return data + if isinstance(data, MultiscaleSpatialImage): + return SpatialImage(next(iter(data["scale0"].ds.values()))) + raise ValueError(f"Expected a SpatialImage or MultiscaleSpatialImage, got {type(data)}.") + def _get_return( self, return_annot: str | list[str] | None, @@ -279,11 +289,14 @@ def _get_return( return_annot = [return_annot] if isinstance(return_annot, str) else return_annot # return tuple of (tile, table) if np.all([i in self.dataset_table.obs for i in return_annot]): - return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) + return lambda x, tile: ( + self._ensure_single_scale(tile), + self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1), + ) if np.all([i in self.dataset_table.var_names for i in return_annot]): if issparse(self.dataset_table.X): - return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A) - return lambda x, tile: (tile, self.dataset_table[x, return_annot].X) + return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X.A) + return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X) raise ValueError( f"If `return_annot` is a `str`, it must be a column name in the table or a variable name in the table. " f"If it is a `list` of `str`, each element should be as above, and they should all be entirely in obs " @@ -292,10 +305,12 @@ def _get_return( # return spatialdata consisting of the image tile and, if available, the associated table if table_name: return lambda x, tile: SpatialData( - images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)}, table=self.dataset_table[x], ) - return lambda x, tile: SpatialData(images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}) + return lambda x, tile: SpatialData( + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)} + ) def __len__(self) -> int: return len(self.dataset_index) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 40fd1563..cfd16312 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -132,13 +132,7 @@ def test_multiscale_images(self, sdata_blobs, rasterize: bool, return_annot): ) if return_annot is None: sdata_tile = ds[0] - if rasterize: - # rasterize transforms teh multiscale image into a single scale image - tile = sdata_tile["blobs_multiscale_image"] - else: - tile = next(iter(sdata_tile["blobs_multiscale_image"]["scale0"].ds.values())) + tile = sdata_tile["blobs_multiscale_image"] else: tile, annot = ds[0] - if not rasterize: - tile = next(iter(tile["scale0"].ds.values())) assert tile.shape == (3, 10, 10)