Skip to content

Commit

Permalink
rasterize bins labels
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneDefauw committed Dec 17, 2024
1 parent 7d2be3e commit 5c3edee
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 62 deletions.
111 changes: 69 additions & 42 deletions src/spatialdata/_core/operations/rasterize_bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray import DataArray

from spatialdata._core.query.relational_query import get_values
from spatialdata._logging import logger
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, get_table_keys
from spatialdata.transformations import Affine, Sequence, get_transformation
Expand Down Expand Up @@ -54,10 +55,10 @@ def rasterize_bins(
If `None`, all the var names will be used, and the returned object will be lazily constructed.
Ignored if `return_region_as_labels` is `True`.
return_regions_as_labels
If `True` this function returns a lazy spatial image of shape `(c, y, x)` with dimension of `c` equal to
If `False` this function returns a lazy spatial image of shape `(c, y, x)` with dimension of `c` equal to
the number of key(s) specified in `value_key`,
or the number of var names in `table_name` if `value_key` is None.
If `False`, will return labels of shape `(y,x)`,
or the number of var names in `table_name` if `value_key` is `None`.
If `True`, will return labels of shape `(y,x)`,
which will be the raster equivalent of bins specified in `bins`.
Returns
Expand All @@ -81,64 +82,90 @@ def rasterize_bins(
"""
element = sdata[bins]
table = sdata.tables[table_name]
if not isinstance(element, GeoDataFrame | DaskDataFrame):
raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.")
if not isinstance(element, GeoDataFrame | DaskDataFrame | DataArray):
raise ValueError("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.")
if isinstance(element, DataArray):
if "c" in element.dims:
raise ValueError(
"If bins is a DataArray, it should hold labels. "
f"But found associated dimension containing 'c': {element.dims}."
)
if not np.issubdtype(element.dtype, np.integer):
raise ValueError(f"If bins is a DataArray, it should hold integers. Found dtype {element.dtype}.")

_, region_key, instance_key = get_table_keys(table)
if not table.obs[region_key].dtype == "category":
raise ValueError(f"Please convert `table.obs['{region_key}']` to a category series to improve performances")
unique_regions = table.obs[region_key].cat.categories
if len(unique_regions) > 1 or unique_regions[0] != bins:
if len(unique_regions) > 1:
raise ValueError(f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}.")
if unique_regions[0] != bins:
raise ValueError("The table should be associated with the specified bins.")

if isinstance(element, DataArray) and return_region_as_labels:
raise ValueError(
"The table should be associated with the specified bins. "
f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}."
f"bins is already a labels layer that annotates the table '{table_name}'. "
"Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image."
)

min_row, min_col = table.obs[row_key].min(), table.obs[col_key].min()
n_rows, n_cols = table.obs[row_key].max() - min_row + 1, table.obs[col_key].max() - min_col + 1
y = (table.obs[row_key] - min_row).values
x = (table.obs[col_key] - min_col).values

# get the transformation
if table.n_obs < 6:
raise ValueError("At least 6 bins are needed to estimate the transformation.")

random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
location_ids = table.obs[instance_key].iloc[random_indices].values
sub_df, sub_table = element.loc[location_ids], table[random_indices]

src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
if isinstance(sub_df, GeoDataFrame):
if isinstance(sub_df.iloc[0].geometry, Point):
sub_x = sub_df.geometry.x.values
sub_y = sub_df.geometry.y.values
else:
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
sub_x = sub_df.centroid.x
sub_y = sub_df.centroid.y
if isinstance(element, DataArray):
transformations = get_transformation(element, get_all=True)
else:
assert isinstance(sub_df, DaskDataFrame)
sub_x = sub_df.x.compute().values
sub_y = sub_df.y.compute().values
dst = np.stack([sub_x, sub_y], axis=1)

to_bins = Sequence(
[
Affine(
estimate_transform(ttype="affine", src=src, dst=dst).params,
input_axes=("x", "y"),
output_axes=("x", "y"),
)
]
)
bins_transformations = get_transformation(element, get_all=True)
# get the transformation
if table.n_obs < 6:
raise ValueError("At least 6 bins are needed to estimate the transformation.")

random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
location_ids = table.obs[instance_key].iloc[random_indices].values
sub_df, sub_table = element.loc[location_ids], table[random_indices]

src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
if isinstance(sub_df, GeoDataFrame):
if isinstance(sub_df.iloc[0].geometry, Point):
sub_x = sub_df.geometry.x.values
sub_y = sub_df.geometry.y.values
else:
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
sub_x = sub_df.centroid.x
sub_y = sub_df.centroid.y
else:
assert isinstance(sub_df, DaskDataFrame)
sub_x = sub_df.x.compute().values
sub_y = sub_df.y.compute().values
dst = np.stack([sub_x, sub_y], axis=1)

to_bins = Sequence(
[
Affine(
estimate_transform(ttype="affine", src=src, dst=dst).params,
input_axes=("x", "y"),
output_axes=("x", "y"),
)
]
)
bins_transformations = get_transformation(element, get_all=True)

assert isinstance(bins_transformations, dict)
assert isinstance(bins_transformations, dict)

transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}
transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}

if return_region_as_labels:
dtype = _get_uint_dtype(table.obs[instance_key].max())
_min_value = table.obs[instance_key].min()
if _min_value == 0:
logger.info(
f"Minimum value of the instance key column ('table.obs[{instance_key}]') is 0. "
"Since the label 0 is reserved for the background, "
"both the instance key column in 'table.obs' "
f"and the index of the annotating element '{bins}' is incremented by 1."
)
table.obs[instance_key] += 1
element.index += 1
labels_element = np.zeros((n_rows, n_cols), dtype=dtype)
# make labels layer that can visualy represent the cells
labels_element[y, x] = table.obs[instance_key].values.T
Expand Down
151 changes: 131 additions & 20 deletions tests/core/operations/test_rasterize_bins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import re

import numpy as np
import pytest
from anndata import AnnData
Expand All @@ -13,7 +15,7 @@
from spatialdata._core.operations.rasterize_bins import rasterize_bins
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models.models import Labels2DModel, PointsModel, ShapesModel, TableModel
from spatialdata.models.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
from spatialdata.transformations.transformations import Scale

RNG = default_rng(0)
Expand All @@ -30,7 +32,8 @@ def _get_bins_data(n: int) -> tuple[ArrayLike, ArrayLike, ArrayLike]:

@pytest.mark.parametrize("geometry", ["points", "circles", "squares"])
@pytest.mark.parametrize("value_key", [None, "instance_id", ["gene0", "gene1"]])
def test_rasterize_bins(geometry: str, value_key: str | list[str] | None):
@pytest.mark.parametrize("return_region_as_labels", [True, False])
def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return_region_as_labels: bool):
n = 10
data, x, y = _get_bins_data(n)
scale = Scale([2.0], axes=("x",))
Expand Down Expand Up @@ -64,28 +67,78 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None):
col_key="col_index",
row_key="row_index",
value_key=value_key,
return_region_as_labels=return_region_as_labels,
)
points_extent = get_extent(points)
raster_extent = get_extent(rasterized)
# atol can be set tighter when https://github.com/scverse/spatialdata/issues/165 is addressed
assert are_extents_equal(points_extent, raster_extent, atol=2)

# if regions are returned as labels, we can annotate the table with 'rasterized',
# which is a labels layer containing the bins, and then run rasterize_bins again
# but now with return_region_as_labels set to False to get a lazy image.
if return_region_as_labels:
labels_name = "labels"
sdata[labels_name] = rasterized
adata = sdata["table"]
adata.obs["region"] = labels_name
adata.obs["region"] = adata.obs["region"].astype("category")
del adata.uns[TableModel.ATTRS_KEY]
adata = TableModel.parse(
adata,
region=labels_name,
region_key="region",
instance_key="instance_id",
)
del sdata["table"]
sdata["table"] = adata
# this fails because table already annotated by labels layer
with pytest.raises(
ValueError,
match="bins is already a labels layer that annotates the table 'table'. "
"Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image.",
):
_ = rasterize_bins(
sdata,
bins="labels",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key=value_key,
return_region_as_labels=True,
)

# but we want to be able to create the lazy raster even if the table is already annotated by a labels layer
rasterized = rasterize_bins(
sdata,
bins="labels",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key=value_key,
return_region_as_labels=False,
)
raster_extent = get_extent(rasterized)
assert are_extents_equal(points_extent, raster_extent, atol=2)


def test_rasterize_bins_invalid():
n = 2
data, x, y = _get_bins_data(n)
points = PointsModel.parse(data)
obs = DataFrame(
data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y}
)
table = TableModel.parse(
AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs),
region="points",
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData.init_from_elements({"points": points}, tables={"table": table})
def _get_sdata(n: int):
data, x, y = _get_bins_data(n)
points = PointsModel.parse(data)
obs = DataFrame(
data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y}
)
table = TableModel.parse(
AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs),
region="points",
region_key="region",
instance_key="instance_id",
)
return SpatialData.init_from_elements({"points": points}, tables={"table": table})

# sdata with not enough bins (2*2) to estimate transformation
sdata = _get_sdata(n=2)
# not enough points
with pytest.raises(ValueError, match="At least 6 bins are needed to estimate the transformation."):
_ = rasterize_bins(
Expand All @@ -98,6 +151,8 @@ def test_rasterize_bins_invalid():
)

# the matrix should be a csc_matrix or a full matrix; in particular not a csr_matrix
sdata = _get_sdata(n=3)
table = sdata.tables["table"]
table.X = csr_matrix(table.X)
with pytest.raises(
ValueError,
Expand All @@ -120,8 +175,24 @@ def test_rasterize_bins_invalid():
table.obs["region"] = regions
with pytest.raises(
ValueError,
match="The table should be associated with the specified bins. Found multiple regions annotated by the table: "
"points, shapes.",
match="Found multiple regions annotated by the table: " "points, shapes.",
):
_ = rasterize_bins(
sdata=sdata,
bins="points",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key="instance_id",
)
# table annotating wrong element
sdata = _get_sdata(n=3)
table = sdata.tables["table"]
table.obs["region"] = "shapes"
table.obs["region"] = table.obs["region"].astype("category")
with pytest.raises(
ValueError,
match="The table should be associated with the specified bins.",
):
_ = rasterize_bins(
sdata=sdata,
Expand All @@ -133,6 +204,8 @@ def test_rasterize_bins_invalid():
)

# region_key should be categorical
sdata = _get_sdata(n=3)
table = sdata.tables["table"]
table.obs["region"] = table.obs["region"].astype(str)
with pytest.raises(ValueError, match="Please convert `table.obs.*` to a category series to improve performances"):
_ = rasterize_bins(
Expand All @@ -144,11 +217,49 @@ def test_rasterize_bins_invalid():
value_key="instance_id",
)

# the element to rasterize should be a GeoDataFrame or a DaskDataFrame
image = Labels2DModel.parse(RNG.normal(size=(n, n)))
# the element to rasterize should be a GeoDataFrame, a DaskDataFrame or a DataArray holding labels
sdata = _get_sdata(n=3)
with pytest.raises(
ValueError,
match="The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.",
):
_ = rasterize_bins(
sdata=sdata,
bins="table",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key="instance_id",
)

# if bins is a DataArray it should contain labels
image = Image2DModel.parse(RNG.integers(low=0, high=10, size=(1, 3, 3)), dims=("c", "y", "x"))
del sdata["points"]
sdata["points"] = image
with pytest.raises(
ValueError,
match=re.escape(
f"If bins is a DataArray, it should hold labels. "
f"But found associated dimension containing 'c': {image.dims}."
),
):
_ = rasterize_bins(
sdata=sdata,
bins="points",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key="instance_id",
)

# if bins is a DataArray, it should hold integers
image = Labels2DModel.parse(RNG.normal(size=(3, 3)), dims=("y", "x"))
del sdata["points"]
sdata["points"] = image
with pytest.raises(ValueError, match="The bins should be a GeoDataFrame or a DaskDataFrame."):
with pytest.raises(
ValueError,
match=f"If bins is a DataArray, it should hold integers. Found dtype {image.dtype}.",
):
_ = rasterize_bins(
sdata=sdata,
bins="points",
Expand Down

0 comments on commit 5c3edee

Please sign in to comment.