Skip to content

Commit

Permalink
Merge pull request #145 from kevinyamauchi/torch-dataloader
Browse files Browse the repository at this point in the history
[WIP] torch DataSet + utils
  • Loading branch information
LucaMarconato authored Mar 14, 2023
2 parents 9975a5c + 41d818a commit ff458c8
Show file tree
Hide file tree
Showing 17 changed files with 412 additions and 191 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ default_stages:
- push
minimum_pre_commit_version: 2.16.0
ci:
skip: [mypy]
skip: []
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
Expand All @@ -28,7 +28,7 @@ repos:
rev: v1.1.1
hooks:
- id: mypy
additional_dependencies: [numpy==1.22.0, types-requests]
additional_dependencies: [numpy==1.24, types-requests]
exclude: tests/|docs/|temp/|spatialdata/_core/reader.py
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
Expand Down
3 changes: 0 additions & 3 deletions examples/dev-examples/README.md

This file was deleted.

115 changes: 0 additions & 115 deletions examples/dev-examples/spatial_query_and_rasterization.py

This file was deleted.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ test = [
"pytest",
"pytest-cov",
]
optional = [
"torch"
]

[tool.coverage.run]
source = ["spatialdata"]
Expand Down
8 changes: 8 additions & 0 deletions spatialdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.metadata import version
from typing import Union

__version__ = version("spatialdata")

Expand Down Expand Up @@ -40,3 +41,10 @@
TableModel,
)
from spatialdata._io.read import read_zarr

try:
from spatialdata._dataloader.datasets import ImageTilesDataset
except ImportError as e:
_error: Union[str, None] = str(e)
else:
_error = None
108 changes: 76 additions & 32 deletions spatialdata/_core/_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _get_xarray_data_to_rasterize(
min_coordinate: Union[list[Number], ArrayLike],
max_coordinate: Union[list[Number], ArrayLike],
target_sizes: dict[str, Optional[float]],
corrected_affine: Affine,
target_coordinate_system: str,
) -> tuple[DataArray, Optional[Scale]]:
"""
Returns the DataArray to rasterize along with its eventual scale factor (if from a pyramid level) from either a
Expand Down Expand Up @@ -289,6 +289,11 @@ def _get_xarray_data_to_rasterize(
xdata = next(iter(v))
assert set(get_spatial_axes(tuple(xdata.sizes.keys()))) == set(axes)

corrected_affine, _ = _get_corrected_affine_matrix(
data=SpatialImage(xdata),
axes=axes,
target_coordinate_system=target_coordinate_system,
)
m = corrected_affine.inverse().matrix # type: ignore[attr-defined]
m_linear = m[:-1, :-1]
m_translation = m[:-1, -1]
Expand All @@ -300,7 +305,7 @@ def _get_xarray_data_to_rasterize(
assert tuple(bb_corners.axis.data.tolist()) == axes
bb_in_xdata = bb_corners.data @ m_linear + m_translation
bb_in_xdata_sizes = {
ax: bb_in_xdata[axes.index(ax)].max() - bb_in_xdata[axes.index(ax)].min() for ax in axes
ax: bb_in_xdata[:, axes.index(ax)].max() - bb_in_xdata[:, axes.index(ax)].min() for ax in axes
}
for ax in axes:
# TLDR; the sqrt selects a pyramid level in which the requested bounding box is a bit larger than the
Expand All @@ -311,9 +316,10 @@ def _get_xarray_data_to_rasterize(
# inverse-transformed bounding box. The sqrt comes from the ratio of the side of a square,
# and the maximum diagonal of a square containing the original square, if the original square is
# rotated.
if bb_in_xdata_sizes[ax] * np.sqrt(len(axes)) < target_sizes[ax]:
if bb_in_xdata_sizes[ax] < target_sizes[ax] * np.sqrt(len(axes)):
break
else:
# when this code is reached, latest_scale is selected
break
assert latest_scale is not None
xdata = next(iter(data[latest_scale].values()))
Expand All @@ -327,6 +333,45 @@ def _get_xarray_data_to_rasterize(
return xdata, pyramid_scale


def _get_corrected_affine_matrix(
    data: Union[SpatialImage, MultiscaleSpatialImage],
    axes: tuple[str, ...],
    target_coordinate_system: str,
) -> tuple[Affine, tuple[str, ...]]:
    """
    Get the affine matrix that maps the intrinsic coordinates of the data to the target_coordinate_system,
    with in addition:

    - restricting the domain to the axes specified in axes (i.e. the axes for which the bounding box is specified), in
      particular axes never contains c;
    - restricting the codomain to the spatial axes of the target coordinate system (i.e. excluding c).

    We do this because:

    - we don't need to consider c
    - when we create the target rasterized object, we need to have axes in the order that is required by the schema

    Returns
    -------
    A tuple of the corrected affine (mapping the requested spatial ``axes`` to the target spatial axes) and the full,
    canonically ordered output axes of the transformation (possibly including ``c``).
    """
    transformation = get_transformation(data, target_coordinate_system)
    assert isinstance(transformation, BaseTransformation)
    affine = _get_affine_for_element(data, transformation)
    target_axes_unordered = affine.output_axes
    # only 2D/3D spatial data, optionally with a channel axis, is supported
    assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}]
    # impose the canonical (c)(z)yx ordering expected by the schemas
    target_axes: tuple[str, ...]
    if "z" in target_axes_unordered:
        target_axes = ("c", "z", "y", "x") if "c" in target_axes_unordered else ("z", "y", "x")
    else:
        target_axes = ("c", "y", "x") if "c" in target_axes_unordered else ("y", "x")
    target_spatial_axes = get_spatial_axes(target_axes)
    # the bounding-box axes and the spatial output axes must have the same dimensionality
    # (was asserted twice in the original; one assertion suffices)
    assert len(target_spatial_axes) == len(axes)
    corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes)
    return corrected_affine, target_axes


@rasterize.register(SpatialImage)
@rasterize.register(MultiscaleSpatialImage)
def _(
Expand Down Expand Up @@ -359,29 +404,6 @@ def _(
"z": target_depth,
}

# get inverse transformation
transformation = get_transformation(data, target_coordinate_system)
dims = get_dims(data)
assert isinstance(transformation, BaseTransformation)
affine = _get_affine_for_element(data, transformation)
target_axes_unordered = affine.output_axes
assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}]
target_axes: tuple[str, ...]
if "z" in target_axes_unordered:
if "c" in target_axes_unordered:
target_axes = ("c", "z", "y", "x")
else:
target_axes = ("z", "y", "x")
else:
if "c" in target_axes_unordered:
target_axes = ("c", "y", "x")
else:
target_axes = ("y", "x")
target_spatial_axes = get_spatial_axes(target_axes)
assert len(target_spatial_axes) == len(min_coordinate)
assert len(target_spatial_axes) == len(max_coordinate)
corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes)

bb_sizes = {ax: max_coordinate[axes.index(ax)] - min_coordinate[axes.index(ax)] for ax in axes}
scale_vector = [bb_sizes[ax] / target_sizes[ax] for ax in axes]
scale = Scale(scale_vector, axes=axes)
Expand All @@ -395,25 +417,33 @@ def _(
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
target_sizes=target_sizes,
corrected_affine=corrected_affine,
target_coordinate_system=target_coordinate_system,
)

if pyramid_scale is not None:
extra = [pyramid_scale.inverse()]
else:
extra = []

# get inverse transformation
corrected_affine, target_axes = _get_corrected_affine_matrix(
data=data,
axes=axes,
target_coordinate_system=target_coordinate_system,
)

half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x"))
sequence = Sequence(
[
half_pixel_offset.inverse(),
# half_pixel_offset.inverse(),
scale,
translation,
corrected_affine.inverse(),
half_pixel_offset,
# half_pixel_offset,
]
+ extra
)
dims = get_dims(data)
matrix = sequence.to_affine_matrix(input_axes=target_axes, output_axes=dims)

# get output shape
Expand All @@ -437,8 +467,6 @@ def _(
else:
raise ValueError(f"Unsupported schema {schema}")

# TODO: adjust matrix
# TODO: add c
# resample the image
transformed_dask = dask_image.ndinterp.affine_transform(
xdata.data,
Expand All @@ -447,12 +475,28 @@ def _(
# output_chunks=xdata.data.chunks,
**kwargs,
)
# ##
# # debug code
# crop = xdata.sel(
# {
# "x": slice(min_coordinate[axes.index("x")], max_coordinate[axes.index("x")]),
# "y": slice(min_coordinate[axes.index("y")], max_coordinate[axes.index("y")]),
# }
# )
# import matplotlib.pyplot as plt
# plt.figure(figsize=(20, 10))
# plt.subplot(1, 2, 1)
# plt.imshow(crop.transpose("y", "x", "c").data)
# plt.subplot(1, 2, 2)
# plt.imshow(DataArray(transformed_dask, dims=xdata.dims).transpose("y", "x", "c").data)
# plt.show()
# ##
assert isinstance(transformed_dask, DaskArray)
transformed_data = schema.parse(transformed_dask, dims=xdata.dims) # type: ignore[call-arg,arg-type]
if target_coordinate_system != "global":
remove_transformation(transformed_data, "global")

sequence = Sequence([half_pixel_offset.inverse(), scale, translation])
sequence = Sequence([half_pixel_offset.inverse(), scale, translation, half_pixel_offset])
set_transformation(transformed_data, sequence, target_coordinate_system)

transformed_data = compute_coordinates(transformed_data)
Expand Down
4 changes: 2 additions & 2 deletions spatialdata/_core/_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _(
)
new_elements[element_type] = queried_elements

if filter_table:
if filter_table and sdata.table is not None:
to_keep = np.array([False] * len(sdata.table))
region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]
instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY]
Expand Down Expand Up @@ -333,7 +333,7 @@ def _(
@bounding_box_query.register(SpatialImage)
@bounding_box_query.register(MultiscaleSpatialImage)
def _(
image: SpatialImage,
image: Union[SpatialImage, MultiscaleSpatialImage],
axes: tuple[str, ...],
min_coordinate: Union[list[Number], ArrayLike],
max_coordinate: Union[list[Number], ArrayLike],
Expand Down
Loading

0 comments on commit ff458c8

Please sign in to comment.