Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visium hd #211

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
76 changes: 72 additions & 4 deletions src/spatialdata_io/readers/visium_hd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from geopandas import GeoDataFrame
from imageio import imread as imread2
from multiscale_spatial_image import MultiscaleSpatialImage
from numpy.random import default_rng
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata import SpatialData, rasterize_bins
from spatialdata.models import Image2DModel, ShapesModel, TableModel
from spatialdata.transformations import (
Affine,
Expand All @@ -31,6 +32,8 @@
from spatialdata_io._constants._constants import VisiumHDKeys
from spatialdata_io._docs import inject_docs

RNG = default_rng(0)


@inject_docs(vx=VisiumHDKeys)
def visium_hd(
Expand All @@ -39,6 +42,7 @@ def visium_hd(
filtered_counts_file: bool = True,
bin_size: int | list[int] | None = None,
bins_as_squares: bool = True,
annotate_table_by_labels: bool = False,
fullres_image_file: str | Path | None = None,
load_all_images: bool = False,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
Expand Down Expand Up @@ -67,6 +71,9 @@ def visium_hd(
bins_as_squares
If `True`, the bins are represented as squares. If `False`, the bins are represented as circles. For a correct
visualization one should use squares.
annotate_table_by_labels
If `True`, the tables will annotate labels layers representing the bins, if `False`, the tables will annotate
shapes layer.
fullres_image_file
Path to the full-resolution image. By default the image is searched in the ``{vx.MICROSCOPE_IMAGE!r}``
directory.
Expand All @@ -89,6 +96,7 @@ def visium_hd(
tables = {}
shapes = {}
images: dict[str, Any] = {}
labels: dict[str, Any] = {}

if dataset_id is None:
dataset_id = _infer_dataset_id(path)
Expand Down Expand Up @@ -189,7 +197,14 @@ def _get_bins(path_bins: Path) -> list[str]:
VisiumHDKeys.LOCATIONS_X,
]
)
assert isinstance(coords.index, pd.RangeIndex)
dtype = _get_uint_dtype(coords.index.stop)

coords = coords.reset_index().rename(columns={"index": VisiumHDKeys.INSTANCE_KEY})
coords[VisiumHDKeys.INSTANCE_KEY] = coords[VisiumHDKeys.INSTANCE_KEY].astype(dtype)

coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True)

coords_filtered = coords.loc[adata.obs.index]
adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True)
# compatibility to legacy squidpy
Expand All @@ -202,7 +217,6 @@ def _get_bins(path_bins: Path) -> list[str]:
],
inplace=True,
)
adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata))

# scaling
transform_original = Identity()
Expand Down Expand Up @@ -247,7 +261,6 @@ def _get_bins(path_bins: Path) -> list[str]:
GeoDataFrame(geometry=squares_series), transformations=transformations
)

# parse table
adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name
adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category")

Expand Down Expand Up @@ -347,7 +360,46 @@ def _get_bins(path_bins: Path) -> list[str]:
affine1 = transform_matrices["spot_colrow_to_microscope_colrow"]
set_transformation(image, Sequence([affine0, affine1]), "global")

return SpatialData(tables=tables, images=images, shapes=shapes)
sdata = SpatialData(tables=tables, images=images, shapes=shapes, labels=labels)

if annotate_table_by_labels:
for bin_size_str in bin_sizes:

shapes_name = dataset_id + "_" + bin_size_str

# add labels layer (rasterized bins).
labels_name = f"{dataset_id}_{bin_size_str}_labels"

labels_element = rasterize_bins(
sdata,
bins=shapes_name,
table_name=bin_size_str,
row_key=VisiumHDKeys.ARRAY_ROW,
col_key=VisiumHDKeys.ARRAY_COL,
value_key=None,
return_region_as_labels=True,
)

sdata[labels_name] = labels_element

adata = sdata[bin_size_str]

adata.obs[VisiumHDKeys.REGION_KEY] = labels_name
adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category")

del adata.uns[TableModel.ATTRS_KEY]

adata = TableModel.parse(
adata,
region=labels_name,
region_key=str(VisiumHDKeys.REGION_KEY),
instance_key=str(VisiumHDKeys.INSTANCE_KEY),
)

del sdata[bin_size_str]
sdata[bin_size_str] = adata

return sdata


def _infer_dataset_id(path: Path) -> str:
Expand Down Expand Up @@ -422,3 +474,19 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any])
transform_matrices[key.value] = _get_affine(data)

return transform_matrices


def _get_uint_dtype(value: int) -> str:
max_uint64 = np.iinfo(np.uint64).max
max_uint32 = np.iinfo(np.uint32).max
max_uint16 = np.iinfo(np.uint16).max

if max_uint16 >= value:
dtype = "uint16"
elif max_uint32 >= value:
dtype = "uint32"
elif max_uint64 >= value:
dtype = "uint64"
else:
raise ValueError(f"Maximum cell number is {value}. Values higher than {max_uint64} are not supported.")
return dtype
Loading