[python] Add MultiscaleImage level SpatialData exporter (#3342)
Add support for exporting a single resolution level of `MultiscaleImage` to a SpatialData Image2DModel or Image3DModel.
jp-dark authored Nov 20, 2024
1 parent b6e3660 commit b6af5d7
Showing 2 changed files with 209 additions and 2 deletions.
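As a rough usage sketch of the new exporter (the URI is illustrative; the call mirrors the function signature added in outgest.py and the test added below):

import somacore
import tiledbsoma as soma
from tiledbsoma.experimental import outgest

# Open an existing MultiscaleImage (URI is illustrative).
image = soma.MultiscaleImage.open("path/to/multiscale_image")

# Export resolution level 0 as a SpatialData-compatible xarray DataArray.
image2d = outgest.to_spatial_data_image(
    image,
    level=0,
    scene_id="scene0",
    scene_dim_map={"x_scene": "x", "y_scene": "y"},
    transform=somacore.IdentityTransform(
        ("x_scene", "y_scene"), ("x_image", "y_image")
    ),
)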
76 changes: 74 additions & 2 deletions apis/python/src/tiledbsoma/experimental/outgest.py
@@ -2,14 +2,16 @@
# Copyright (c) 2024 TileDB, Inc
#
# Licensed under the MIT License.
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union

import geopandas as gpd
import somacore
import spatialdata as sd
import xarray as xr

from .. import PointCloudDataFrame
from .. import MultiscaleImage, PointCloudDataFrame
from .._constants import SOMA_JOINID
from ._xarray_backend import dense_nd_array_to_data_array


def _convert_axis_names(
@@ -144,3 +146,73 @@ def to_spatial_data_shapes(
df = gpd.GeoDataFrame(data, geometry=geometry)
df.attrs["transform"] = transforms
return df


def to_spatial_data_image(
image: MultiscaleImage,
level: Optional[Union[str, int]] = None,
*,
scene_id: str,
scene_dim_map: Dict[str, str],
transform: somacore.CoordinateTransform,
) -> xr.DataArray:
"""Export a level of a :class:`MultiscaleImage` to a
:class:`spatialdata.Image2DModel` or :class:`spatialdata.Image3DModel`.
"""
if not image.has_channel_axis:
raise NotImplementedError(
"Support for exporting a MultiscaleImage to without a channel axis to "
"SpatialData is not yet implemented."
)

# Convert from SOMA axis names to SpatialData axis names.
orig_axis_names = image.coordinate_space.axis_names
if len(orig_axis_names) not in {2, 3}:
raise NotImplementedError(
f"Support for converting a '{len(orig_axis_names)}'D is not yet implemented."
)
new_axis_names, image_dim_map = _convert_axis_names(
orig_axis_names, image.data_axis_order
)

# Get the URI of the requested level.
if level is None:
if image.level_count != 1:
raise ValueError(
"The level must be specified for a multiscale image with more than one "
"resolution level."
)
level = 0
level_uri = image.level_uri(level)

# Get the transformation from the image level to the scene:
# If the result is a single scale transform (or identity transform), output a
# single transformation. Otherwise, convert to a SpatialData sequence of
# transformations.
inv_transform = transform.inverse_transform()
scale_transform = image.get_transform_from_level(level)
if isinstance(transform, somacore.ScaleTransform) or isinstance(
scale_transform, somacore.IdentityTransform
):
# inv_transform @ scale_transform -> applies scale_transform first
sd_transform = _transform_to_spatial_data(
inv_transform @ scale_transform, image_dim_map, scene_dim_map
)
else:
sd_transform1 = _transform_to_spatial_data(
scale_transform, image_dim_map, image_dim_map
)
sd_transform2 = _transform_to_spatial_data(
inv_transform, image_dim_map, scene_dim_map
)
# Sequence([sd_transform1, sd_transform2]) -> applies sd_transform1 first
sd_transform = sd.transformations.Sequence([sd_transform1, sd_transform2])
transformations = {scene_id: sd_transform}

# Return array accessor as a dask array.
return dense_nd_array_to_data_array(
level_uri,
dim_names=new_axis_names,
attrs={"transform": transformations},
context=image.context,
)
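A worked sketch of the arithmetic behind this composition, using values from the tests below (plain Python, no library calls; the scale factors are assumptions drawn from the test fixture):

# Level 2 of the test fixture is (3, 8, 8) versus a (3, 32, 32) base, so the
# level-to-base scale is 32 / 8 = 4 per spatial axis.
level_scale = [4.0, 4.0]
# A scene-to-image ScaleTransform of [0.25, 0.5] has inverse [4, 2].
inv_scene_scale = [4.0, 2.0]
# inv_transform @ scale_transform applies the level scale first, so the
# combined factor is the elementwise product: [16, 8], matching the expected
# sd.transformations.Scale([16, 8], ("x", "y")) in the tests.
combined = [a * b for a, b in zip(inv_scene_scale, level_scale)]
assert combined == [16.0, 8.0]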
135 changes: 135 additions & 0 deletions apis/python/tests/test_export_multiscale_image.py
@@ -0,0 +1,135 @@
from urllib.parse import urljoin

import numpy as np
import pyarrow as pa
import pytest
import somacore

import tiledbsoma as soma

soma_outgest = pytest.importorskip("tiledbsoma.experimental.outgest")
sd = pytest.importorskip("spatialdata")
xr = pytest.importorskip("xarray")


@pytest.fixture(scope="module")
def sample_2d_data():
return [
np.random.randint(0, 255, size=(3, 32, 32), dtype=np.uint8),
np.random.randint(0, 255, size=(3, 16, 16), dtype=np.uint8),
np.random.randint(0, 255, size=(3, 8, 8), dtype=np.uint8),
]


@pytest.fixture(scope="module")
def sample_multiscale_image_2d(tmp_path_factory, sample_2d_data):
# Create the multiscale image.
baseuri = tmp_path_factory.mktemp("export_multiscale_image").as_uri()
image_uri = urljoin(baseuri, "default")
with soma.MultiscaleImage.create(
image_uri,
type=pa.uint8(),
coordinate_space=("x_image", "y_image"),
level_shape=(3, 32, 32),
) as image:
coords = (slice(None), slice(None), slice(None))
# Create levels.
l0 = image["level0"]
l0.write(coords, pa.Tensor.from_numpy(sample_2d_data[0]))

# Create a medium-sized downsample and write to it.
l1 = image.add_new_level("level1", shape=(3, 16, 16))
l1.write(coords, pa.Tensor.from_numpy(sample_2d_data[1]))

# Create very small downsample and write to it.
l2 = image.add_new_level("level2", shape=(3, 8, 8))
l2.write(coords, pa.Tensor.from_numpy(sample_2d_data[2]))
image2d = soma.MultiscaleImage.open(image_uri)
return image2d


@pytest.mark.parametrize(
"level,transform,expected_transformation",
[
(
0,
somacore.IdentityTransform(("x_scene", "y_scene"), ("x_image", "y_image")),
sd.transformations.Identity(),
),
(
2,
somacore.IdentityTransform(("x_scene", "y_scene"), ("x_image", "y_image")),
sd.transformations.Scale([4, 4], ("x", "y")),
),
(
0,
somacore.ScaleTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [0.25, 0.5]
),
sd.transformations.Scale([4, 2], ("x", "y")),
),
(
2,
somacore.ScaleTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [0.25, 0.5]
),
sd.transformations.Scale([16, 8], ("x", "y")),
),
(
0,
somacore.AffineTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [[1, 0, 1], [0, 1, 2]]
),
sd.transformations.Affine(
np.array([[1, 0, -1], [0, 1, -2], [0, 0, 1]]),
("x", "y"),
("x", "y"),
),
),
(
2,
somacore.AffineTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [[1, 0, 1], [0, 1, 2]]
),
sd.transformations.Sequence(
[
sd.transformations.Scale([4, 4], ("x", "y")),
sd.transformations.Affine(
np.array([[1, 0, -1], [0, 1, -2], [0, 0, 1]]),
("x", "y"),
("x", "y"),
),
]
),
),
],
)
def test_export_image_level_to_spatial_data(
sample_multiscale_image_2d,
sample_2d_data,
level,
transform,
expected_transformation,
):
image2d = soma_outgest.to_spatial_data_image(
sample_multiscale_image_2d,
level=level,
scene_id="scene0",
scene_dim_map={"x_scene": "x", "y_scene": "y"},
transform=transform,
)

assert isinstance(image2d, xr.DataArray)

# Validate the model.
schema = sd.models.get_model(image2d)
assert schema == sd.models.Image2DModel

# Check the correct data exists.
result = image2d.data.compute()
np.testing.assert_equal(result, sample_2d_data[level])

# Check the metadata.
metadata = dict(image2d.attrs)
assert len(metadata) == 1
assert metadata["transform"] == {"scene0": expected_transformation}
