diff --git a/docs/api.md b/docs/api.md index bbe05f8d..48632a6c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -41,6 +41,7 @@ Operations on `SpatialData` objects. to_polygons aggregate map_raster + relabel_sequential ``` ### Operations Utilities diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 8b0664ab..01f6c27f 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -48,6 +48,7 @@ "save_transformations", "get_dask_backing_files", "are_extents_equal", + "relabel_sequential", "map_raster", "deepcopy", ] @@ -58,7 +59,7 @@ from spatialdata._core.concatenate import concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations.aggregate import aggregate -from spatialdata._core.operations.map import map_raster +from spatialdata._core.operations.map import map_raster, relabel_sequential from spatialdata._core.operations.rasterize import rasterize from spatialdata._core.operations.rasterize_bins import rasterize_bins from spatialdata._core.operations.transform import transform diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index 2ba70401..0f5a380a 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -1,17 +1,22 @@ from __future__ import annotations +import math +import operator from collections.abc import Callable, Iterable, Mapping +from functools import reduce from types import MappingProxyType from typing import TYPE_CHECKING, Any import dask.array as da +import numpy as np from dask.array.overlap import coerce_depth from xarray import DataArray, DataTree +from spatialdata._types import IntArrayLike from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims from spatialdata.transformations import get_transformation -__all__ = ["map_raster"] +__all__ = ["map_raster", "relabel_sequential"] def map_raster( @@ -24,6 +29,7 @@ def 
map_raster( c_coords: Iterable[int] | Iterable[str] | None = None, dims: tuple[str, ...] | None = None, transformations: dict[str, Any] | None = None, + relabel: bool = True, **kwargs: Any, ) -> DataArray: """ @@ -68,6 +74,13 @@ def map_raster( transformations The transformations of the output data. If not provided, the transformations of the input data are copied to the output data. It should be specified if the callable changes the data transformations. + relabel + Whether to relabel the blocks of the output data. + This option is ignored when the output data is not a labels layer (i.e., when `dims` contains `c`). + It is recommended to enable relabeling if `func` returns labels that are not unique across chunks. + Relabeling will be done by performing a bit shift. When a cell or entity to be labeled is split between two + adjacent chunks, the current implementation does not assign the same label across blocks. + See https://github.com/scverse/spatialdata/pull/664 for discussion. kwargs Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`. Ignored if `blockwise` is set to `False`. 
@@ -130,6 +143,9 @@ def map_raster( assert isinstance(d, dict) transformations = d + if "c" not in dims and relabel: + arr = _relabel(arr) + model_kwargs = { "chunks": arr.chunksize, "c_coords": c_coords, @@ -138,3 +154,100 @@ def map_raster( } model = get_raster_model_from_data_dims(dims) return model.parse(arr, **model_kwargs) + + +def _relabel(arr: da.Array) -> da.Array: + if not np.issubdtype(arr.dtype, np.integer): + raise ValueError(f"Relabeling is only supported for arrays of type {np.integer}.") + num_blocks = arr.numblocks + + shift = (math.prod(num_blocks) - 1).bit_length() + + meta = np.empty((0,) * arr.ndim, dtype=arr.dtype) + + def _relabel_block( + block: IntArrayLike, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int + ) -> IntArrayLike: + def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int: + if len(num_blocks) != len(block_id): + raise ValueError("num_blocks and block_id must have the same length") + block_num = 0 + for i in range(len(num_blocks)): + multiplier = reduce(operator.mul, num_blocks[i + 1 :], 1) + block_num += block_id[i] * multiplier + return block_num + + available_bits = np.iinfo(block.dtype).max.bit_length() + max_bits_block = int(block.max()).bit_length() + + if max_bits_block + shift > available_bits: + # Note: because of no harmonization across blocks, adjusting number of chunks lowers the required bits. + raise ValueError( + f"Relabel was set to True, but " + f"the number of bits required to represent the labels in the block ({max_bits_block}) " + f"+ required shift ({shift}) exceeds the available_bits ({available_bits}). In other words, " + f"the number of labels exceeds the number of integers that can be represented by the dtype " + "of the individual blocks. " + "To solve this issue, please consider the following solutions:" + " 1. Rechunking using a larger chunk size, lowering the number of blocks and thereby" + " lowering the value of required shift." + " 2. 
Cast to a data type with a higher maximum value " + " 3. Perform sequential relabeling of the dask array using `relabel_sequential` in `spatialdata`," + " potentially lowering the maximum value of a label (though number of distinct labels values " + " stays the same). For example if the unique labels values are `[0, 1, 1000]`, after the " + " sequential relabeling the unique labels value will be `[0, 1, 2]`, thus requiring fewer bits " + " to store the labels." + ) + + block_num = _calculate_block_num(block_id=block_id, num_blocks=num_blocks) + + mask = block > 0 + block[mask] = (block[mask] << shift) | block_num + + return block + + return da.map_blocks( + _relabel_block, + arr, + dtype=arr.dtype, + num_blocks=num_blocks, + shift=shift, + meta=meta, + ) + + +def relabel_sequential(arr: da.Array) -> da.Array: + """ + Relabels integers in a Dask array sequentially. + + This function assigns sequential labels to the integers in a Dask array starting from 1. + For example, if the unique values in the input array are [0, 5, 9], + they will be relabeled to [0, 1, 2] respectively. + Note that currently if a cell or entity to be labeled is split across adjacent chunks the same label is not + assigned to the cell across blocks. See discussion https://github.com/scverse/spatialdata/pull/664. + + Parameters + ---------- + arr + input array. + + Returns + ------- + The relabeled array. + """ + if not np.issubdtype(arr.dtype, np.integer): + raise ValueError(f"Sequential relabeling is only supported for arrays of type {np.integer}.") + + unique_labels = da.unique(arr).compute() + if 0 not in unique_labels: + # otherwise first non zero label would be relabeled to 0 + unique_labels = np.insert(unique_labels, 0, 0) + + max_label = unique_labels[-1] + + new_labeling = da.full(max_label + 1, -1, dtype=arr.dtype) + + # Note that both sides are ordered as da.unique returns an ordered array. 
+ new_labeling[unique_labels] = da.arange(len(unique_labels), dtype=arr.dtype) + + return da.map_blocks(operator.getitem, new_labeling, arr, dtype=arr.dtype, chunks=arr.chunks) diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index c0f22b32..3c527163 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -7,8 +7,11 @@ from numpy.typing import DTypeLike, NDArray ArrayLike = NDArray[np.float64] + IntArrayLike = NDArray[np.int64] # or any np.integer + except (ImportError, TypeError): ArrayLike = np.ndarray # type: ignore[misc] + IntArrayLike = np.ndarray # type: ignore[misc] DTypeLike = np.dtype # type: ignore[misc, assignment] Raster_T = DataArray | DataTree diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index 71b34d70..b3fd3165 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -1,10 +1,12 @@ +import math import re +import dask.array as da import numpy as np import pytest from xarray import DataArray -from spatialdata._core.operations.map import map_raster +from spatialdata._core.operations.map import map_raster, relabel_sequential from spatialdata.transformations import Translation, get_transformation, set_transformation @@ -28,6 +30,11 @@ def _multiply_to_labels(arr, parameter=10): return arr[0].astype(np.int32) +def _to_constant(arr, constant): + arr[arr > 0] = constant + return arr + + @pytest.mark.parametrize( "depth", [ @@ -47,6 +54,7 @@ def test_map_raster(sdata_blobs, depth, element_name): func_kwargs=func_kwargs, c_coords=None, depth=depth, + relabel=False, ) assert isinstance(se, DataArray) @@ -162,6 +170,7 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis): chunks=chunks, drop_axis=drop_axis, dims=("y", "x"), + relabel=False, ) data = sdata_blobs[img_layer].data.compute() @@ -249,3 +258,117 @@ def test_invalid_map_raster(sdata_blobs): c_coords=["c"], depth=(0, 60, 60), ) + + +def test_map_raster_relabel(sdata_blobs): + 
constant = 2047 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + se = map_raster( + sdata_blobs[element_name].chunk((100, 100)), + func=_to_constant, + func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) + + # check if labels in different blocks are all mapped to a different value + assert isinstance(se, DataArray) + se.data.compute() + a = set() + for chunk in se.data.to_delayed().flatten(): + chunk = chunk.compute() + b = set(np.unique(chunk)) + b.remove(0) + assert not b.intersection(a) + a.update(b) + # 9 blocks, each block contains 'constant' left shifted by (9-1).bit_length() + block_num. + shift = (math.prod(se.data.numblocks) - 1).bit_length() + assert a == set(range(constant << shift, (constant << shift) + math.prod(se.data.numblocks))) + + +def test_map_raster_relabel_fail(sdata_blobs): + constant = 2048 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + + # Testing the case of having insufficient number of bits. 
+ with pytest.raises( + ValueError, + match=re.escape("Relabel was set to True, but"), + ): + se = map_raster( + sdata_blobs[element_name].chunk((100, 100)), + func=_to_constant, + func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) + + se.data.compute() + + constant = 2047 + func_kwargs = {"constant": constant} + + element_name = "blobs_labels" + with pytest.raises( + ValueError, + match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."), + ): + se = map_raster( + sdata_blobs[element_name].astype(float).chunk((100, 100)), + func=_to_constant, + func_kwargs=func_kwargs, + c_coords=None, + depth=None, + relabel=True, + ) + + +def test_relabel_sequential(sdata_blobs): + def _is_sequential(arr): + if arr.ndim != 1: + raise ValueError("Input array must be one-dimensional") + sorted_arr = np.sort(arr) + expected_sequence = np.arange(sorted_arr[0], sorted_arr[0] + len(sorted_arr)) + return np.array_equal(sorted_arr, expected_sequence) + + arr = sdata_blobs["blobs_labels"].data.rechunk(100) + + arr_relabeled = relabel_sequential(arr) + + labels_relabeled = da.unique(arr_relabeled).compute() + labels_original = da.unique(arr).compute() + + assert labels_relabeled.shape == labels_original.shape + assert _is_sequential(labels_relabeled) + + # test some edge cases + arr = da.asarray(np.array([0])) + assert np.array_equal(relabel_sequential(arr).compute(), np.array([0])) + + arr = da.asarray(np.array([1])) + assert np.array_equal(relabel_sequential(arr).compute(), np.array([1])) + + arr = da.asarray(np.array([2])) + assert np.array_equal(relabel_sequential(arr).compute(), np.array([1])) + + arr = da.asarray(np.array([2, 0])) + assert np.array_equal(relabel_sequential(arr).compute(), np.array([1, 0])) + + arr = da.asarray(np.array([0, 9, 5])) + assert np.array_equal(relabel_sequential(arr).compute(), np.array([0, 2, 1])) + + arr = da.asarray(np.array([4, 1, 3])) + assert np.array_equal(relabel_sequential(arr).compute(), 
np.array([3, 1, 2])) + + +def test_relabel_sequential_fails(sdata_blobs): + with pytest.raises( + ValueError, match=re.escape(f"Sequential relabeling is only supported for arrays of type {np.integer}.") + ): + relabel_sequential(sdata_blobs["blobs_labels"].data.astype(float))