relabel block #664

Merged: 28 commits, Nov 27, 2024

Commits (28)
91518e0
relabel block
ArneDefauw Aug 7, 2024
206c445
dtype check array
ArneDefauw Aug 7, 2024
ef933d8
fix mypy
ArneDefauw Aug 7, 2024
4d0ea37
type IntArray
ArneDefauw Aug 8, 2024
633a132
add sequential relabeling helper function
ArneDefauw Aug 8, 2024
fc7c4ac
Merge branch 'main' into map_raster_relabeling
giovp Sep 2, 2024
a5ac4ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2024
cad5ce3
reintroduce DataArray
giovp Sep 2, 2024
5b0d3db
Merge branch 'main' into map_raster_relabeling
melonora Nov 27, 2024
fa95e0d
revert import removal
melonora Nov 27, 2024
e7cbc83
Update src/spatialdata/_core/operations/map.py
ArneDefauw Nov 27, 2024
599fddf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
57c2caa
adjusted error message
melonora Nov 27, 2024
b62c0da
Merge branch 'map_raster_relabeling' of https://github.com/ArneDefauw…
melonora Nov 27, 2024
fc31fe4
fix test
melonora Nov 27, 2024
a57e23b
make relabel_sequential public
melonora Nov 27, 2024
e81f66d
fix test
melonora Nov 27, 2024
23a7778
adjust docstring
melonora Nov 27, 2024
8a5b2e3
adjust docstring, add comment
melonora Nov 27, 2024
cf8f9fe
IntArray fix type
melonora Nov 27, 2024
24a9fbb
add comment
melonora Nov 27, 2024
4340d9c
add edge cases to test
melonora Nov 27, 2024
bfc4d74
Merge branch 'main' into map_raster_relabeling
melonora Nov 27, 2024
bbf3c63
remove default arg
melonora Nov 27, 2024
0dba194
Merge branch 'map_raster_relabeling' of https://github.com/ArneDefauw…
melonora Nov 27, 2024
afb7bd8
change docstring
LucaMarconato Nov 27, 2024
301655a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
4f06d59
minor renaming
LucaMarconato Nov 27, 2024
1 change: 1 addition & 0 deletions docs/api.md
@@ -41,6 +41,7 @@ Operations on `SpatialData` objects.
to_polygons
aggregate
map_raster
relabel_sequential
```

### Operations Utilities
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
@@ -48,6 +48,7 @@
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
"relabel_sequential",
"map_raster",
"deepcopy",
]
@@ -58,7 +59,7 @@
from spatialdata._core.concatenate import concatenate
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations.aggregate import aggregate
from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.map import map_raster, relabel_sequential
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.operations.rasterize_bins import rasterize_bins
from spatialdata._core.operations.transform import transform
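With this change both helpers are importable from the package root. A minimal usage sketch (assuming an environment with this branch of `spatialdata` installed):

```python
# Both functions are re-exported from the package root by the diff above.
from spatialdata import map_raster, relabel_sequential
```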
115 changes: 114 additions & 1 deletion src/spatialdata/_core/operations/map.py
@@ -1,17 +1,22 @@
from __future__ import annotations

import math
import operator
from collections.abc import Callable, Iterable, Mapping
from functools import reduce
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import dask.array as da
import numpy as np
from dask.array.overlap import coerce_depth
from xarray import DataArray, DataTree

from spatialdata._types import IntArrayLike
from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation

__all__ = ["map_raster"]
__all__ = ["map_raster", "relabel_sequential"]


def map_raster(
@@ -24,6 +29,7 @@ def map_raster(
c_coords: Iterable[int] | Iterable[str] | None = None,
dims: tuple[str, ...] | None = None,
transformations: dict[str, Any] | None = None,
relabel: bool = True,
**kwargs: Any,
) -> DataArray:
"""
@@ -68,6 +74,13 @@ def map_raster(
transformations
The transformations of the output data. If not provided, the transformations of the input data are copied to the
output data. It should be specified if the callable changes the data transformations.
relabel
Whether to relabel the blocks of the output data.
This option is ignored when the output data is not a labels layer (i.e., when `dims` contains `c`).
It is recommended to enable relabeling if `func` returns labels that are not unique across chunks.
Relabeling will be done by performing a bit shift. When a cell or entity to be labeled is split between two
adjacent chunks, the current implementation does not assign the same label across blocks.
See https://github.com/scverse/spatialdata/pull/664 for discussion.
kwargs
Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`.
Ignored if `blockwise` is set to `False`.
@@ -130,6 +143,9 @@ def map_raster(
assert isinstance(d, dict)
transformations = d

if "c" not in dims and relabel:
arr = _relabel(arr)

model_kwargs = {
"chunks": arr.chunksize,
"c_coords": c_coords,
@@ -138,3 +154,100 @@
}
model = get_raster_model_from_data_dims(dims)
return model.parse(arr, **model_kwargs)


def _relabel(arr: da.Array) -> da.Array:
if not np.issubdtype(arr.dtype, np.integer):
raise ValueError(f"Relabeling is only supported for arrays of type {np.integer}.")
num_blocks = arr.numblocks

shift = (math.prod(num_blocks) - 1).bit_length()

meta = np.empty((0,) * arr.ndim, dtype=arr.dtype)

def _relabel_block(
block: IntArrayLike, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int
) -> IntArrayLike:
def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int:
if len(num_blocks) != len(block_id):
raise ValueError("num_blocks and block_id must have the same length")
block_num = 0
for i in range(len(num_blocks)):
multiplier = reduce(operator.mul, num_blocks[i + 1 :], 1)
block_num += block_id[i] * multiplier
return block_num

available_bits = np.iinfo(block.dtype).max.bit_length()
max_bits_block = int(block.max()).bit_length()

if max_bits_block + shift > available_bits:
# Note: because labels are not harmonized across blocks, lowering the number of chunks reduces the required shift.
raise ValueError(
f"Relabel was set to True, but "
f"the number of bits required to represent the labels in the block ({max_bits_block}) "
f"+ required shift ({shift}) exceeds the available bits ({available_bits}). In other words, "
f"the number of labels exceeds the number of integers that can be represented by the dtype "
"of the individual blocks. "
"To solve this issue, please consider the following solutions:"
" 1. Rechunk using a larger chunk size, lowering the number of blocks and thereby"
" lowering the value of the required shift."
" 2. Cast to a data type with a higher maximum value."
" 3. Perform sequential relabeling of the dask array using `relabel_sequential` in `spatialdata`,"
" potentially lowering the maximum value of a label (though the number of distinct label values"
" stays the same). For example, if the unique label values are `[0, 1, 1000]`, after "
" sequential relabeling the unique label values will be `[0, 1, 2]`, thus requiring fewer bits "
" to store the labels."
)

block_num = _calculate_block_num(block_id=block_id, num_blocks=num_blocks)

mask = block > 0
block[mask] = (block[mask] << shift) | block_num

return block

return da.map_blocks(
_relabel_block,
arr,
dtype=arr.dtype,
num_blocks=num_blocks,
shift=shift,
meta=meta,
)


def relabel_sequential(arr: da.Array) -> da.Array:
"""
Relabels integers in a Dask array sequentially.

This function assigns sequential labels to the integers in a Dask array starting from 1.
For example, if the unique values in the input array are [0, 9, 5],
they will be relabeled to [0, 2, 1] respectively.
Note that, currently, if a cell or entity to be labeled is split across adjacent chunks, the same label is not
assigned to the cell across blocks. See the discussion at https://github.com/scverse/spatialdata/pull/664.

Parameters
----------
arr
Input array.

Returns
-------
The relabeled array.
"""
if not np.issubdtype(arr.dtype, np.integer):
raise ValueError(f"Sequential relabeling is only supported for arrays of type {np.integer}.")

unique_labels = da.unique(arr).compute()
if 0 not in unique_labels:
# otherwise first non zero label would be relabeled to 0
unique_labels = np.insert(unique_labels, 0, 0)

max_label = unique_labels[-1]

new_labeling = da.full(max_label + 1, -1, dtype=arr.dtype)

# Note that both sides are ordered as da.unique returns an ordered array.
new_labeling[unique_labels] = da.arange(len(unique_labels), dtype=arr.dtype)

return da.map_blocks(operator.getitem, new_labeling, arr, dtype=arr.dtype, chunks=arr.chunks)
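As a rough, self-contained sketch of the bit-shift scheme that `_relabel` applies per block (illustrative NumPy code, not part of the diff; the `blocks` data and variable names are made up): each non-zero label is shifted left by the number of bits needed to encode the largest block index and then OR-ed with that block's flattened index, so identical labels coming from different chunks can no longer collide.

```python
import numpy as np

# Hypothetical illustration of the per-block relabeling described in the docstring above.
# Two independently segmented "chunks" that both reuse label 1.
blocks = [
    np.array([0, 1, 1], dtype=np.uint16),
    np.array([1, 0, 1], dtype=np.uint16),
]

num_blocks = len(blocks)
shift = (num_blocks - 1).bit_length()  # bits needed to encode the block index

relabeled = []
for block_num, block in enumerate(blocks):
    out = block.copy()
    mask = out > 0
    # Move the original label into the high bits and store the block index in
    # the low bits, so labels from different blocks map to distinct values.
    out[mask] = (out[mask] << shift) | block_num
    relabeled.append(out)

print(relabeled)  # [array([0, 2, 2], dtype=uint16), array([3, 0, 3], dtype=uint16)]
```

The resulting values can then be compacted back to a dense range with `relabel_sequential`, as the error message above suggests.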
3 changes: 3 additions & 0 deletions src/spatialdata/_types.py
@@ -7,8 +7,11 @@
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float64]
IntArrayLike = NDArray[np.int64] # or any np.integer

except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
IntArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc, assignment]

Raster_T = DataArray | DataTree
125 changes: 124 additions & 1 deletion tests/core/operations/test_map.py
@@ -1,10 +1,12 @@
import math
import re

import dask.array as da
import numpy as np
import pytest
from xarray import DataArray

from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.map import map_raster, relabel_sequential
from spatialdata.transformations import Translation, get_transformation, set_transformation


@@ -28,6 +30,11 @@ def _multiply_to_labels(arr, parameter=10):
return arr[0].astype(np.int32)


def _to_constant(arr, constant):
arr[arr > 0] = constant
return arr


@pytest.mark.parametrize(
"depth",
[
Expand All @@ -47,6 +54,7 @@ def test_map_raster(sdata_blobs, depth, element_name):
func_kwargs=func_kwargs,
c_coords=None,
depth=depth,
relabel=False,
)

assert isinstance(se, DataArray)
@@ -162,6 +170,7 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis):
chunks=chunks,
drop_axis=drop_axis,
dims=("y", "x"),
relabel=False,
)

data = sdata_blobs[img_layer].data.compute()
@@ -249,3 +258,117 @@ def test_invalid_map_raster(sdata_blobs):
c_coords=["c"],
depth=(0, 60, 60),
)


def test_map_raster_relabel(sdata_blobs):
constant = 2047
func_kwargs = {"constant": constant}

element_name = "blobs_labels"
se = map_raster(
sdata_blobs[element_name].chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)

# check if labels in different blocks are all mapped to a different value
assert isinstance(se, DataArray)
se.data.compute()
a = set()
for chunk in se.data.to_delayed().flatten():
chunk = chunk.compute()
b = set(np.unique(chunk))
b.remove(0)
assert not b.intersection(a)
a.update(b)
# 9 blocks; each block contains `constant` left-shifted by (9-1).bit_length() bits, OR-ed with the block index.
shift = (math.prod(se.data.numblocks) - 1).bit_length()
assert a == set(range(constant << shift, (constant << shift) + math.prod(se.data.numblocks)))


def test_map_raster_relabel_fail(sdata_blobs):
constant = 2048
func_kwargs = {"constant": constant}

element_name = "blobs_labels"

# Testing the case of having insufficient number of bits.
with pytest.raises(
ValueError,
match=re.escape("Relabel was set to True, but"),
):
se = map_raster(
sdata_blobs[element_name].chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)

se.data.compute()

constant = 2047
func_kwargs = {"constant": constant}

element_name = "blobs_labels"
with pytest.raises(
ValueError,
match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."),
):
se = map_raster(
sdata_blobs[element_name].astype(float).chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)


def test_relabel_sequential(sdata_blobs):
def _is_sequential(arr):
if arr.ndim != 1:
raise ValueError("Input array must be one-dimensional")
sorted_arr = np.sort(arr)
expected_sequence = np.arange(sorted_arr[0], sorted_arr[0] + len(sorted_arr))
return np.array_equal(sorted_arr, expected_sequence)

arr = sdata_blobs["blobs_labels"].data.rechunk(100)

arr_relabeled = relabel_sequential(arr)

labels_relabeled = da.unique(arr_relabeled).compute()
labels_original = da.unique(arr).compute()

assert labels_relabeled.shape == labels_original.shape
assert _is_sequential(labels_relabeled)

# test some edge cases
arr = da.asarray(np.array([0]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([0]))

arr = da.asarray(np.array([1]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1]))

arr = da.asarray(np.array([2]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1]))

arr = da.asarray(np.array([2, 0]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1, 0]))

arr = da.asarray(np.array([0, 9, 5]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([0, 2, 1]))

arr = da.asarray(np.array([4, 1, 3]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([3, 1, 2]))


def test_relabel_sequential_fails(sdata_blobs):
with pytest.raises(
ValueError, match=re.escape(f"Sequential relabeling is only supported for arrays of type {np.integer}.")
):
relabel_sequential(sdata_blobs["blobs_labels"].data.astype(float))
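For completeness, a small usage sketch of `relabel_sequential` on a toy Dask array (the input values are made up; the expected output follows from the sorted-unique mapping described in the docstring):

```python
import dask.array as da
import numpy as np

from spatialdata import relabel_sequential

# Toy label array with gaps between the label values (0 is background).
arr = da.from_array(np.array([0, 9, 5, 9], dtype=np.int64), chunks=2)

result = relabel_sequential(arr).compute()
print(result)  # [0 2 1 2]: 5 -> 1 and 9 -> 2, preserving the original ordering
```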