From 5c3edee642bad1eaedbddf2097430315b964bb1a Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Tue, 17 Dec 2024 13:17:01 +0100 Subject: [PATCH] rasterize bins labels --- .../_core/operations/rasterize_bins.py | 111 ++++++++----- tests/core/operations/test_rasterize_bins.py | 151 +++++++++++++++--- 2 files changed, 200 insertions(+), 62 deletions(-) diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py index 87c11740..f2a034fb 100644 --- a/src/spatialdata/_core/operations/rasterize_bins.py +++ b/src/spatialdata/_core/operations/rasterize_bins.py @@ -14,6 +14,7 @@ from xarray import DataArray from spatialdata._core.query.relational_query import get_values +from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, get_table_keys from spatialdata.transformations import Affine, Sequence, get_transformation @@ -54,10 +55,10 @@ def rasterize_bins( If `None`, all the var names will be used, and the returned object will be lazily constructed. Ignored if `return_region_as_labels` is `True`. return_regions_as_labels - If `True` this function returns a lazy spatial image of shape `(c, y, x)` with dimension of `c` equal to + If `False` this function returns a lazy spatial image of shape `(c, y, x)` with dimension of `c` equal to the number of key(s) specified in `value_key`, - or the number of var names in `table_name` if `value_key` is None. - If `False`, will return labels of shape `(y,x)`, + or the number of var names in `table_name` if `value_key` is `None`. + If `True`, will return labels of shape `(y,x)`, which will be the raster equivalent of bins specified in `bins`. Returns @@ -81,17 +82,30 @@ def rasterize_bins( """ element = sdata[bins] table = sdata.tables[table_name] - if not isinstance(element, GeoDataFrame | DaskDataFrame): - raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.") + if not isinstance(element, GeoDataFrame | DaskDataFrame | DataArray): + raise ValueError("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.") + if isinstance(element, DataArray): + if "c" in element.dims: + raise ValueError( + "If bins is a DataArray, it should hold labels. " + f"But found associated dimension containing 'c': {element.dims}." + ) + if not np.issubdtype(element.dtype, np.integer): + raise ValueError(f"If bins is a DataArray, it should hold integers. Found dtype {element.dtype}.") _, region_key, instance_key = get_table_keys(table) if not table.obs[region_key].dtype == "category": raise ValueError(f"Please convert `table.obs['{region_key}']` to a category series to improve performances") unique_regions = table.obs[region_key].cat.categories - if len(unique_regions) > 1 or unique_regions[0] != bins: + if len(unique_regions) > 1: + raise ValueError(f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}.") + if unique_regions[0] != bins: + raise ValueError("The table should be associated with the specified bins.") + + if isinstance(element, DataArray) and return_region_as_labels: raise ValueError( - "The table should be associated with the specified bins. " - f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}." + f"bins is already a labels layer that annotates the table '{table_name}'. " + "Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image." ) min_row, min_col = table.obs[row_key].min(), table.obs[col_key].min() @@ -99,46 +113,59 @@ def rasterize_bins( y = (table.obs[row_key] - min_row).values x = (table.obs[col_key] - min_col).values - # get the transformation - if table.n_obs < 6: - raise ValueError("At least 6 bins are needed to estimate the transformation.") - - random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True) - location_ids = table.obs[instance_key].iloc[random_indices].values - sub_df, sub_table = element.loc[location_ids], table[random_indices] - - src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1) - if isinstance(sub_df, GeoDataFrame): - if isinstance(sub_df.iloc[0].geometry, Point): - sub_x = sub_df.geometry.x.values - sub_y = sub_df.geometry.y.values - else: - assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon) - sub_x = sub_df.centroid.x - sub_y = sub_df.centroid.y + if isinstance(element, DataArray): + transformations = get_transformation(element, get_all=True) else: - assert isinstance(sub_df, DaskDataFrame) - sub_x = sub_df.x.compute().values - sub_y = sub_df.y.compute().values - dst = np.stack([sub_x, sub_y], axis=1) - - to_bins = Sequence( - [ - Affine( - estimate_transform(ttype="affine", src=src, dst=dst).params, - input_axes=("x", "y"), - output_axes=("x", "y"), - ) - ] - ) - bins_transformations = get_transformation(element, get_all=True) + # get the transformation + if table.n_obs < 6: + raise ValueError("At least 6 bins are needed to estimate the transformation.") + + random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True) + location_ids = table.obs[instance_key].iloc[random_indices].values + sub_df, sub_table = element.loc[location_ids], table[random_indices] + + src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1) + if isinstance(sub_df, GeoDataFrame): + if isinstance(sub_df.iloc[0].geometry, Point): + sub_x = sub_df.geometry.x.values + sub_y = sub_df.geometry.y.values + else: + assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon) + sub_x = sub_df.centroid.x + sub_y = sub_df.centroid.y + else: + assert isinstance(sub_df, DaskDataFrame) + sub_x = sub_df.x.compute().values + sub_y = sub_df.y.compute().values + dst = np.stack([sub_x, sub_y], axis=1) + + to_bins = Sequence( + [ + Affine( + estimate_transform(ttype="affine", src=src, dst=dst).params, + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + ] + ) + bins_transformations = get_transformation(element, get_all=True) - assert isinstance(bins_transformations, dict) + assert isinstance(bins_transformations, dict) - transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()} + transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()} if return_region_as_labels: dtype = _get_uint_dtype(table.obs[instance_key].max()) + _min_value = table.obs[instance_key].min() + if _min_value == 0: + logger.info( + f"Minimum value of the instance key column ('table.obs[{instance_key}]') is 0. " + "Since the label 0 is reserved for the background, " + "both the instance key column in 'table.obs' " + f"and the index of the annotating element '{bins}' is incremented by 1." + ) + table.obs[instance_key] += 1 + element.index += 1 labels_element = np.zeros((n_rows, n_cols), dtype=dtype) # make labels layer that can visualy represent the cells labels_element[y, x] = table.obs[instance_key].values.T diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index a8eb0b03..913d02cf 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -1,5 +1,7 @@ from __future__ import annotations +import re + import numpy as np import pytest from anndata import AnnData @@ -13,7 +15,7 @@ from spatialdata._core.operations.rasterize_bins import rasterize_bins from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike -from spatialdata.models.models import Labels2DModel, PointsModel, ShapesModel, TableModel +from spatialdata.models.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel from spatialdata.transformations.transformations import Scale RNG = default_rng(0) @@ -30,7 +32,8 @@ def _get_bins_data(n: int) -> tuple[ArrayLike, ArrayLike, ArrayLike]: @pytest.mark.parametrize("geometry", ["points", "circles", "squares"]) @pytest.mark.parametrize("value_key", [None, "instance_id", ["gene0", "gene1"]]) -def test_rasterize_bins(geometry: str, value_key: str | list[str] | None): +@pytest.mark.parametrize("return_region_as_labels", [True, False]) +def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return_region_as_labels: bool): n = 10 data, x, y = _get_bins_data(n) scale = Scale([2.0], axes=("x",)) @@ -64,28 +67,78 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None): col_key="col_index", row_key="row_index", value_key=value_key, + return_region_as_labels=return_region_as_labels, ) points_extent = get_extent(points) raster_extent = get_extent(rasterized) # atol can be set tighter when https://github.com/scverse/spatialdata/issues/165 is addressed assert are_extents_equal(points_extent, raster_extent, atol=2) + # if regions are returned as labels, we can annotate the table with 'rasterized', + # which is a labels layer containing the bins, and then run rasterize_bins again + # but now with return_region_as_labels set to False to get a lazy image. + if return_region_as_labels: + labels_name = "labels" + sdata[labels_name] = rasterized + adata = sdata["table"] + adata.obs["region"] = labels_name + adata.obs["region"] = adata.obs["region"].astype("category") + del adata.uns[TableModel.ATTRS_KEY] + adata = TableModel.parse( + adata, + region=labels_name, + region_key="region", + instance_key="instance_id", + ) + del sdata["table"] + sdata["table"] = adata + # this fails because table already annotated by labels layer + with pytest.raises( + ValueError, + match="bins is already a labels layer that annotates the table 'table'. " + "Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image.", + ): + _ = rasterize_bins( + sdata, + bins="labels", + table_name="table", + col_key="col_index", + row_key="row_index", + value_key=value_key, + return_region_as_labels=True, + ) + + # but we want to be able to create the lazy raster even if the table is already annotated by a labels layer + rasterized = rasterize_bins( + sdata, + bins="labels", + table_name="table", + col_key="col_index", + row_key="row_index", + value_key=value_key, + return_region_as_labels=False, + ) + raster_extent = get_extent(rasterized) + assert are_extents_equal(points_extent, raster_extent, atol=2) + def test_rasterize_bins_invalid(): - n = 2 - data, x, y = _get_bins_data(n) - points = PointsModel.parse(data) - obs = DataFrame( - data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} - ) - table = TableModel.parse( - AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs), - region="points", - region_key="region", - instance_key="instance_id", - ) - sdata = SpatialData.init_from_elements({"points": points}, tables={"table": table}) + def _get_sdata(n: int): + data, x, y = _get_bins_data(n) + points = PointsModel.parse(data) + obs = DataFrame( + data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} + ) + table = TableModel.parse( + AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs), + region="points", + region_key="region", + instance_key="instance_id", + ) + return SpatialData.init_from_elements({"points": points}, tables={"table": table}) + # sdata with not enough bins (2*2) to estimate transformation + sdata = _get_sdata(n=2) # not enough points with pytest.raises(ValueError, match="At least 6 bins are needed to estimate the transformation."): _ = rasterize_bins( @@ -98,6 +151,8 @@ def test_rasterize_bins_invalid(): ) # the matrix should be a csc_matrix or a full matrix; in particular not a csr_matrix + sdata = _get_sdata(n=3) + table = sdata.tables["table"] table.X = csr_matrix(table.X) with pytest.raises( ValueError, @@ -120,8 +175,24 @@ def test_rasterize_bins_invalid(): table.obs["region"] = regions with pytest.raises( ValueError, - match="The table should be associated with the specified bins. Found multiple regions annotated by the table: " - "points, shapes.", + match="Found multiple regions annotated by the table: " "points, shapes.", + ): + _ = rasterize_bins( + sdata=sdata, + bins="points", + table_name="table", + col_key="col_index", + row_key="row_index", + value_key="instance_id", + ) + # table annotating wrong element + sdata = _get_sdata(n=3) + table = sdata.tables["table"] + table.obs["region"] = "shapes" + table.obs["region"] = table.obs["region"].astype("category") + with pytest.raises( + ValueError, + match="The table should be associated with the specified bins.", ): _ = rasterize_bins( sdata=sdata, @@ -133,6 +204,8 @@ def test_rasterize_bins_invalid(): ) # region_key should be categorical + sdata = _get_sdata(n=3) + table = sdata.tables["table"] table.obs["region"] = table.obs["region"].astype(str) with pytest.raises(ValueError, match="Please convert `table.obs.*` to a category series to improve performances"): _ = rasterize_bins( @@ -144,11 +217,49 @@ def test_rasterize_bins_invalid(): value_key="instance_id", ) - # the element to rasterize should be a GeoDataFrame or a DaskDataFrame - image = Labels2DModel.parse(RNG.normal(size=(n, n))) + # the element to rasterize should be a GeoDataFrame, a DaskDataFrame or a DataArray holding labels + sdata = _get_sdata(n=3) + with pytest.raises( + ValueError, + match="The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.", + ): + _ = rasterize_bins( + sdata=sdata, + bins="table", + table_name="table", + col_key="col_index", + row_key="row_index", + value_key="instance_id", + ) + + # if bins is a DataArray it should contain labels + image = Image2DModel.parse(RNG.integers(low=0, high=10, size=(1, 3, 3)), dims=("c", "y", "x")) + del sdata["points"] + sdata["points"] = image + with pytest.raises( + ValueError, + match=re.escape( + f"If bins is a DataArray, it should hold labels. " + f"But found associated dimension containing 'c': {image.dims}." + ), + ): + _ = rasterize_bins( + sdata=sdata, + bins="points", + table_name="table", + col_key="col_index", + row_key="row_index", + value_key="instance_id", + ) + + # if bins is a DataArray, it should hold integers + image = Labels2DModel.parse(RNG.normal(size=(3, 3)), dims=("y", "x")) del sdata["points"] sdata["points"] = image - with pytest.raises(ValueError, match="The bins should be a GeoDataFrame or a DaskDataFrame."): + with pytest.raises( + ValueError, + match=f"If bins is a DataArray, it should hold integers. Found dtype {image.dtype}.", + ): _ = rasterize_bins( sdata=sdata, bins="points",