Skip to content

Commit

Permalink
Merge branch 'main' into cleanup_spatialimage
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato committed Jun 21, 2024
2 parents 1310b47 + eec6ab3 commit 77fd273
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning][].
- Added operation: `to_polygons()` @quentinblampey #560
- Extended `rasterize()` to support all the data types @quentinblampey #566
- Added operation: `rasterize_bins()` @quentinblampey #578
- Added operation: `map_raster()` to apply functions block-wise to raster data @ArneDefauw #588

### Minor

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Operations on `SpatialData` objects.
to_circles
to_polygons
aggregate
map_raster
```

### Operations Utilities
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
"map_raster",
"deepcopy",
]

Expand All @@ -40,6 +41,7 @@
from spatialdata._core.concatenate import concatenate
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations.aggregate import aggregate
from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.operations.rasterize_bins import rasterize_bins
from spatialdata._core.operations.transform import transform
Expand Down
68 changes: 39 additions & 29 deletions src/spatialdata/_core/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable

import dask.array as da
from dask.array.overlap import coerce_depth
Expand All @@ -19,7 +19,7 @@
def map_raster(
data: DataArray | DataTree,
func: Callable[[da.Array], da.Array],
fn_kwargs: Mapping[str, Any] = MappingProxyType({}),
func_kwargs: Mapping[str, Any] = MappingProxyType({}),
blockwise: bool = True,
depth: int | tuple[int, ...] | dict[int, int] | None = None,
chunks: tuple[tuple[int, ...], ...] | None = None,
Expand All @@ -29,46 +29,54 @@ def map_raster(
**kwargs: Any,
) -> DataArray:
"""
Apply a function to raster data.
Apply a callable to raster data.
Applies a `func` callable to raster data. If `blockwise` is set to `True`,
distributed processing will be achieved with:
- :func:`dask.array.map_overlap` if `depth` is not `None`
- :func:`dask.array.map_blocks`, if `depth` is `None`
otherwise `func` is applied to the full data.
Parameters
----------
data
The data to process. It can be a `DataArray` or `DataTree`. If it's a `DataTree`,
the function is applied to the first scale (full-resolution data).
The data to process. It can be a :class:`xarray.DataArray` or :class:`datatree.DataTree`.
If it's a `DataTree`, the callable is applied to the first scale (`scale0`, the full-resolution data).
func
The function to apply to the data.
fn_kwargs
Additional keyword arguments to pass to the function `func`.
The callable that is applied to the data.
func_kwargs
Additional keyword arguments to pass to the callable `func`.
blockwise
If `True`, distributed processing will be achieved with `dask.array.map_overlap`/`dask.array.map_blocks`,
otherwise the function is applied to the full data. If `False`, `depth` and `chunks` are ignored.
If `True`, `func` will be distributed with :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`,
otherwise `func` is applied to the full data. If `False`, `depth`, `chunks` and `kwargs` are ignored.
depth
If not `None`, distributed processing will be achieved with `dask.array.map_overlap`, otherwise with
`dask.array.map_blocks`. Specifies the overlap between chunks, i.e. the number of elements that each chunk
should share with its neighbor chunks. Please see `dask.array.map_overlap` for more information on the accepted
values.
Specifies the overlap between chunks, i.e. the number of elements that each chunk
should share with its neighboring chunks. If not `None`, distributed processing will be achieved with
:func:`dask.array.map_overlap`, otherwise with :func:`dask.array.map_blocks`.
chunks
Passed to `dask.array.map_overlap`/`dask.array.map_blocks` as `chunks`. Ignored if `blockwise` is `False`.
Chunk shape of resulting blocks if the function does not preserve the data shape. If not provided, the resulting
array is assumed to have the same chunk structure as the first input array.
E.g. ( (3,), (100,100), (100,100) ).
Chunk shape of resulting blocks if the callable does not preserve the data shape.
For example, if the input block has `shape: (3, 100, 100)` and the resulting block after the `map_raster`
call has `shape: (1, 100, 100)`, the argument `chunks` should be passed accordingly.
Passed to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`. Ignored if `blockwise` is `False`.
c_coords
The channel coordinates for the output data. If not provided, the channel coordinates of the input data are
used. It should be specified if the function changes the number of channels.
used. If the callable `func` is expected to change the number of channel coordinates,
this argument should be provided; otherwise it will default to `range(len(output_coords))`.
dims
The dimensions of the output data. If not provided, the dimensions of the input data are used. It must be
specified if the function changes the data dimensions.
E.g. ('c', 'y', 'x').
specified if the callable changes the data dimensions, e.g. `('c', 'y', 'x') -> ('y', 'x')`.
transformations
The transformations of the output data. If not provided, the transformations of the input data are copied to the
output data. It should be specified if the function changes the data transformations.
output data. It should be specified if the callable changes the data transformations.
kwargs
Additional keyword arguments to pass to `dask.array.map_overlap` or `dask.array.map_blocks`.
Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`.
Ignored if `blockwise` is set to `False`.
Returns
-------
The processed data as a `DataArray`.
The processed data as a :class:`xarray.DataArray`.
"""
if isinstance(data, DataArray):
arr = data.data
Expand All @@ -85,31 +93,33 @@ def map_raster(
kwargs["chunks"] = chunks

if not blockwise:
arr = func(arr, **fn_kwargs)
arr = func(arr, **func_kwargs)
else:
if depth is not None:
kwargs.setdefault("boundary", "reflect")

if not isinstance(depth, int) and len(depth) != arr.ndim:
raise ValueError(
f"Depth {depth} is provided for {len(depth)} dimensions. "
f"Please (only) provide depth for {arr.ndim} dimensions."
f"Please provide depth for {arr.ndim} dimensions."
)
kwargs["depth"] = coerce_depth(arr.ndim, depth)
map_func = da.map_overlap
else:
map_func = da.map_blocks

arr = map_func(func, arr, **fn_kwargs, **kwargs, dtype=arr.dtype)
arr = map_func(func, arr, **func_kwargs, **kwargs, dtype=arr.dtype)

dims = dims if dims is not None else get_axes_names(data)
if model not in (Labels2DModel, Labels3DModel):
c_coords = c_coords if c_coords is not None else get_channels(data)
if c_coords is None:
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channels(data)) else get_channels(data)
else:
c_coords = None
if transformations is None:
d = get_transformation(data, get_all=True)
assert isinstance(d, dict)
if TYPE_CHECKING:
assert isinstance(d, dict)
transformations = d

model_kwargs = {
Expand Down
46 changes: 23 additions & 23 deletions tests/core/operations/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ def test_map_raster(sdata_blobs, depth, element_name):
if element_name == "blobs_labels" and depth is not None:
depth = (60, 60)

fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}
se = map_raster(
sdata_blobs[element_name],
func=_multiply,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
c_coords=None,
depth=depth,
)

assert isinstance(se, DataArray)
data = sdata_blobs[element_name].data.compute()
res = se.data.compute()
assert np.array_equal(data * fn_kwargs["parameter"], res)
assert np.array_equal(data * func_kwargs["parameter"], res)


@pytest.mark.parametrize(
Expand All @@ -58,27 +58,27 @@ def test_map_raster(sdata_blobs, depth, element_name):
)
def test_map_raster_multiscale(sdata_blobs, depth):
img_layer = "blobs_multiscale_image"
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}
se = map_raster(
sdata_blobs[img_layer],
func=_multiply,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
c_coords=None,
depth=depth,
)

data = sdata_blobs[img_layer]["scale0"]["image"].data.compute()
res = se.data.compute()
assert np.array_equal(data * fn_kwargs["parameter"], res)
assert np.array_equal(data * func_kwargs["parameter"], res)


def test_map_raster_no_blockwise(sdata_blobs):
img_layer = "blobs_image"
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}
se = map_raster(
sdata_blobs[img_layer],
func=_multiply,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
blockwise=False,
c_coords=None,
depth=None,
Expand All @@ -87,17 +87,17 @@ def test_map_raster_no_blockwise(sdata_blobs):
assert isinstance(se, DataArray)
data = sdata_blobs[img_layer].data.compute()
res = se.data.compute()
assert np.array_equal(data * fn_kwargs["parameter"], res)
assert np.array_equal(data * func_kwargs["parameter"], res)


def test_map_raster_output_chunks(sdata_blobs):
depth = 60
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}
output_channels = ["test"]
se = map_raster(
sdata_blobs["blobs_image"].chunk((3, 100, 100)),
func=_multiply_alter_c,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
chunks=(
(1,),
(100 + 2 * depth, 96 + 2 * depth, 60 + 2 * depth),
Expand All @@ -111,12 +111,12 @@ def test_map_raster_output_chunks(sdata_blobs):
assert np.array_equal(np.array(output_channels), se.c.data)
data = sdata_blobs["blobs_image"].data.compute()
res = se.data.compute()
assert np.array_equal(data[0] * fn_kwargs["parameter"], res[0])
assert np.array_equal(data[0] * func_kwargs["parameter"], res[0])


@pytest.mark.parametrize("img_layer", ["blobs_image", "blobs_multiscale_image"])
def test_map_transformation(sdata_blobs, img_layer):
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}
target_coordinate_system = "my_other_space0"
transformation = Translation(translation=[10, 12], axes=["y", "x"])

Expand All @@ -126,7 +126,7 @@ def test_map_transformation(sdata_blobs, img_layer):
se = map_raster(
se,
func=_multiply,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
blockwise=False,
c_coords=None,
depth=None,
Expand All @@ -136,12 +136,12 @@ def test_map_transformation(sdata_blobs, img_layer):

def test_map_squeeze_z(full_sdata):
img_layer = "image3d_numpy"
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}

se = map_raster(
full_sdata[img_layer].chunk((3, 2, 64, 64)),
func=_multiply_squeeze_z,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
chunks=((3,), (64,), (64,)),
drop_axis=1,
c_coords=None,
Expand All @@ -152,18 +152,18 @@ def test_map_squeeze_z(full_sdata):
assert isinstance(se, DataArray)
data = full_sdata[img_layer].data.compute()
res = se.data.compute()
assert np.array_equal(data[:, 0, ...] * fn_kwargs["parameter"], res)
assert np.array_equal(data[:, 0, ...] * func_kwargs["parameter"], res)


def test_map_squeeze_z_fails(full_sdata):
img_layer = "image3d_numpy"
fn_kwargs = {"parameter": 20}
func_kwargs = {"parameter": 20}

with pytest.raises(IndexError):
map_raster(
full_sdata[img_layer].chunk((3, 2, 64, 64)),
func=_multiply_squeeze_z,
fn_kwargs=fn_kwargs,
func_kwargs=func_kwargs,
chunks=((3,), (64,), (64,)),
drop_axis=1,
c_coords=None,
Expand All @@ -176,19 +176,19 @@ def test_invalid_map_raster(sdata_blobs):
map_raster(
sdata_blobs["blobs_points"],
func=_multiply,
fn_kwargs={"parameter": 20},
func_kwargs={"parameter": 20},
c_coords=None,
depth=(0, 60),
)

with pytest.raises(
ValueError,
match=re.escape("Depth (0, 60) is provided for 2 dimensions. Please (only) provide depth for 3 dimensions."),
match=re.escape("Depth (0, 60) is provided for 2 dimensions. Please provide depth for 3 dimensions."),
):
map_raster(
sdata_blobs["blobs_image"],
func=_multiply,
fn_kwargs={"parameter": 20},
func_kwargs={"parameter": 20},
c_coords=None,
depth=(0, 60),
)
Expand All @@ -197,7 +197,7 @@ def test_invalid_map_raster(sdata_blobs):
map_raster(
sdata_blobs["blobs_labels"],
func=_multiply,
fn_kwargs={"parameter": 20},
func_kwargs={"parameter": 20},
c_coords=["c"],
depth=(0, 60, 60),
)

0 comments on commit 77fd273

Please sign in to comment.