Skip to content

Commit

Permalink
always returning a single-scale image
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato committed Mar 20, 2024
1 parent a119850 commit e7aef86
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
25 changes: 20 additions & 5 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import pandas as pd
from anndata import AnnData
from geopandas import GeoDataFrame
from multiscale_spatial_image import MultiscaleSpatialImage
from pandas import CategoricalDtype
from scipy.sparse import issparse
from spatial_image import SpatialImage
from torch.utils.data import Dataset

from spatialdata._core.centroids import get_centroids
Expand Down Expand Up @@ -267,6 +269,14 @@ def _preprocess(
assert np.all([i in self.tiles_coords for i in dims_])
self.dims = list(dims_)

@staticmethod
def _ensure_single_scale(data: SpatialImage | MultiscaleSpatialImage) -> SpatialImage:
    """
    Return a single-scale image for the given raster element.

    Parameters
    ----------
    data
        Either a ``SpatialImage``, which is returned as is, or a ``MultiscaleSpatialImage``, from which the
        first data variable stored under the ``"scale0"`` entry is extracted.

    Returns
    -------
    The input itself when it is already a ``SpatialImage``; otherwise a new ``SpatialImage`` wrapping the first
    data variable of ``data["scale0"].ds``.

    Raises
    ------
    ValueError
        If ``data`` is neither a ``SpatialImage`` nor a ``MultiscaleSpatialImage``.
    """
    # Already single-scale: hand the object back untouched.
    if isinstance(data, SpatialImage):
        return data
    # Anything that is not a multiscale image at this point is unsupported.
    if not isinstance(data, MultiscaleSpatialImage):
        raise ValueError(f"Expected a SpatialImage or MultiscaleSpatialImage, got {type(data)}.")
    # Take the first data variable stored in the "scale0" dataset
    # (presumably the full-resolution level -- confirm against the multiscale model).
    scale0_dataset = data["scale0"].ds
    first_variable = next(iter(scale0_dataset.values()))
    return SpatialImage(first_variable)

def _get_return(
self,
return_annot: str | list[str] | None,
Expand All @@ -279,11 +289,14 @@ def _get_return(
return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
# return tuple of (tile, table)
if np.all([i in self.dataset_table.obs for i in return_annot]):
return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1))
return lambda x, tile: (
self._ensure_single_scale(tile),
self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1),
)
if np.all([i in self.dataset_table.var_names for i in return_annot]):
if issparse(self.dataset_table.X):
return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A)
return lambda x, tile: (tile, self.dataset_table[x, return_annot].X)
return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X.A)
return lambda x, tile: (self._ensure_single_scale(tile), self.dataset_table[x, return_annot].X)
raise ValueError(
f"If `return_annot` is a `str`, it must be a column name in the table or a variable name in the table. "
f"If it is a `list` of `str`, each element should be as above, and they should all be entirely in obs "
Expand All @@ -292,10 +305,12 @@ def _get_return(
# return spatialdata consisting of the image tile and, if available, the associated table
if table_name:
return lambda x, tile: SpatialData(
images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile},
images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)},
table=self.dataset_table[x],
)
return lambda x, tile: SpatialData(images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile})
return lambda x, tile: SpatialData(
images={self.dataset_index.iloc[x][self.IMAGE_KEY]: self._ensure_single_scale(tile)}
)

def __len__(self) -> int:
return len(self.dataset_index)
Expand Down
8 changes: 1 addition & 7 deletions tests/dataloader/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,7 @@ def test_multiscale_images(self, sdata_blobs, rasterize: bool, return_annot):
)
if return_annot is None:
sdata_tile = ds[0]
if rasterize:
# rasterize transforms the multiscale image into a single scale image
tile = sdata_tile["blobs_multiscale_image"]
else:
tile = next(iter(sdata_tile["blobs_multiscale_image"]["scale0"].ds.values()))
tile = sdata_tile["blobs_multiscale_image"]
else:
tile, annot = ds[0]
if not rasterize:
tile = next(iter(tile["scale0"].ds.values()))
assert tile.shape == (3, 10, 10)

0 comments on commit e7aef86

Please sign in to comment.