diff --git a/src/spatialdata_io/readers/visium_hd.py b/src/spatialdata_io/readers/visium_hd.py index 2c1a19d..c910ccf 100644 --- a/src/spatialdata_io/readers/visium_hd.py +++ b/src/spatialdata_io/readers/visium_hd.py @@ -16,8 +16,9 @@ from geopandas import GeoDataFrame from imageio import imread as imread2 from multiscale_spatial_image import MultiscaleSpatialImage +from numpy.random import default_rng from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata import SpatialData, rasterize_bins from spatialdata.models import Image2DModel, ShapesModel, TableModel from spatialdata.transformations import ( Affine, @@ -31,6 +32,8 @@ from spatialdata_io._constants._constants import VisiumHDKeys from spatialdata_io._docs import inject_docs +RNG = default_rng(0) + @inject_docs(vx=VisiumHDKeys) def visium_hd( @@ -39,6 +42,7 @@ def visium_hd( filtered_counts_file: bool = True, bin_size: int | list[int] | None = None, bins_as_squares: bool = True, + annotate_table_by_labels: bool = False, fullres_image_file: str | Path | None = None, load_all_images: bool = False, imread_kwargs: Mapping[str, Any] = MappingProxyType({}), @@ -67,6 +71,9 @@ def visium_hd( bins_as_squares If `True`, the bins are represented as squares. If `False`, the bins are represented as circles. For a correct visualization one should use squares. + annotate_table_by_labels + If `True`, the tables will annotate labels layers representing the bins, if `False`, the tables will annotate + shapes layer. fullres_image_file Path to the full-resolution image. By default the image is searched in the ``{vx.MICROSCOPE_IMAGE!r}`` directory. @@ -89,6 +96,7 @@ def visium_hd( tables = {} shapes = {} images: dict[str, Any] = {} + labels: dict[str, Any] = {} if dataset_id is None: dataset_id = _infer_dataset_id(path) @@ -189,7 +197,14 @@ def _get_bins(path_bins: Path) -> list[str]: VisiumHDKeys.LOCATIONS_X, ] ) + assert isinstance(coords.index, pd.RangeIndex) + dtype = _get_uint_dtype(coords.index.stop) + + coords = coords.reset_index().rename(columns={"index": VisiumHDKeys.INSTANCE_KEY}) + coords[VisiumHDKeys.INSTANCE_KEY] = coords[VisiumHDKeys.INSTANCE_KEY].astype(dtype) + coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True) + coords_filtered = coords.loc[adata.obs.index] adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True) # compatibility to legacy squidpy @@ -202,7 +217,6 @@ def _get_bins(path_bins: Path) -> list[str]: ], inplace=True, ) - adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata)) # scaling transform_original = Identity() @@ -247,7 +261,6 @@ def _get_bins(path_bins: Path) -> list[str]: GeoDataFrame(geometry=squares_series), transformations=transformations ) - # parse table adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") @@ -347,7 +360,46 @@ def _get_bins(path_bins: Path) -> list[str]: affine1 = transform_matrices["spot_colrow_to_microscope_colrow"] set_transformation(image, Sequence([affine0, affine1]), "global") - return SpatialData(tables=tables, images=images, shapes=shapes) + sdata = SpatialData(tables=tables, images=images, shapes=shapes, labels=labels) + + if annotate_table_by_labels: + for bin_size_str in bin_sizes: + + shapes_name = dataset_id + "_" + bin_size_str + + # add labels layer (rasterized bins). + labels_name = f"{dataset_id}_{bin_size_str}_labels" + + labels_element = rasterize_bins( + sdata, + bins=shapes_name, + table_name=bin_size_str, + row_key=VisiumHDKeys.ARRAY_ROW, + col_key=VisiumHDKeys.ARRAY_COL, + value_key=None, + return_region_as_labels=True, + ) + + sdata[labels_name] = labels_element + + adata = sdata[bin_size_str] + + adata.obs[VisiumHDKeys.REGION_KEY] = labels_name + adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category") + + del adata.uns[TableModel.ATTRS_KEY] + + adata = TableModel.parse( + adata, + region=labels_name, + region_key=str(VisiumHDKeys.REGION_KEY), + instance_key=str(VisiumHDKeys.INSTANCE_KEY), + ) + + del sdata[bin_size_str] + sdata[bin_size_str] = adata + + return sdata def _infer_dataset_id(path: Path) -> str: @@ -422,3 +474,19 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any]) transform_matrices[key.value] = _get_affine(data) return transform_matrices + + +def _get_uint_dtype(value: int) -> str: + max_uint64 = np.iinfo(np.uint64).max + max_uint32 = np.iinfo(np.uint32).max + max_uint16 = np.iinfo(np.uint16).max + + if max_uint16 >= value: + dtype = "uint16" + elif max_uint32 >= value: + dtype = "uint32" + elif max_uint64 >= value: + dtype = "uint64" + else: + raise ValueError(f"Maximum cell number is {value}. Values higher than {max_uint64} are not supported.") + return dtype