Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] torch DataSet + utils #145

Merged
merged 23 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a3a1a52
add sdata-> data dict transform
kevinyamauchi Feb 20, 2023
b5fa7c6
add initial dataset
kevinyamauchi Feb 20, 2023
479d7e2
fix typos
kevinyamauchi Feb 20, 2023
a95ab3c
Merge branch 'main' into torch-dataloader
kevinyamauchi Feb 20, 2023
42706ff
add shapes to dataset
kevinyamauchi Feb 20, 2023
e0bb5d8
start multislide
kevinyamauchi Feb 22, 2023
3f45b3b
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 3, 2023
9337549
wip, need to merge with rasterize branch
LucaMarconato Mar 3, 2023
3ab9398
Merge branch 'feature/rasterize' into torch-dataloader
LucaMarconato Mar 3, 2023
8902390
wip tiling
LucaMarconato Mar 6, 2023
c6bee89
added __set_item__() and merge branch 'main' into torch-dataloader
LucaMarconato Mar 7, 2023
a307f1b
tiling still wip, but usable
LucaMarconato Mar 8, 2023
354fa3f
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 8, 2023
2125bec
fixed mypy
LucaMarconato Mar 8, 2023
9c2cf75
type fix
LucaMarconato Mar 9, 2023
656616b
fixed bug with xarray coordinates in multiscale, fixed wrong centroids
LucaMarconato Mar 9, 2023
db62b71
Apply suggestions from code review
LucaMarconato Mar 14, 2023
edf71bd
implemented suggestions from code review
LucaMarconato Mar 14, 2023
b37691d
Merge branch 'torch-dataloader' of https://github.com/kevinyamauchi/s…
LucaMarconato Mar 14, 2023
a7228c2
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 14, 2023
7a27ef0
fixed test
LucaMarconato Mar 14, 2023
6f4aa7c
removed numpy==1.22 constraint for mypy
LucaMarconato Mar 14, 2023
41d818a
mypy now using numpy==1.24
LucaMarconato Mar 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ default_stages:
- push
minimum_pre_commit_version: 2.16.0
ci:
skip: [mypy]
skip: []
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
Expand Down
3 changes: 0 additions & 3 deletions examples/dev-examples/README.md

This file was deleted.

115 changes: 0 additions & 115 deletions examples/dev-examples/spatial_query_and_rasterization.py

This file was deleted.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ test = [
"pytest",
"pytest-cov",
]
optional = [
"torch"
]

[tool.coverage.run]
source = ["spatialdata"]
Expand Down
8 changes: 8 additions & 0 deletions spatialdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.metadata import version
from typing import Union

__version__ = version("spatialdata")

Expand Down Expand Up @@ -40,3 +41,10 @@
TableModel,
)
from spatialdata._io.read import read_zarr

try:
from spatialdata._dataloader.datasets import ImageTilesDataset
except ImportError as e:
_error: Union[str, None] = str(e)
else:
_error = None
108 changes: 76 additions & 32 deletions spatialdata/_core/_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _get_xarray_data_to_rasterize(
min_coordinate: Union[list[Number], ArrayLike],
max_coordinate: Union[list[Number], ArrayLike],
target_sizes: dict[str, Optional[float]],
corrected_affine: Affine,
target_coordinate_system: str,
) -> tuple[DataArray, Optional[Scale]]:
"""
Returns the DataArray to rasterize along with its eventual scale factor (if from a pyramid level) from either a
Expand Down Expand Up @@ -289,6 +289,11 @@ def _get_xarray_data_to_rasterize(
xdata = next(iter(v))
assert set(get_spatial_axes(tuple(xdata.sizes.keys()))) == set(axes)

corrected_affine, _ = _get_corrected_affine_matrix(
data=SpatialImage(xdata),
axes=axes,
target_coordinate_system=target_coordinate_system,
)
m = corrected_affine.inverse().matrix # type: ignore[attr-defined]
m_linear = m[:-1, :-1]
m_translation = m[:-1, -1]
Expand All @@ -300,7 +305,7 @@ def _get_xarray_data_to_rasterize(
assert tuple(bb_corners.axis.data.tolist()) == axes
bb_in_xdata = bb_corners.data @ m_linear + m_translation
bb_in_xdata_sizes = {
ax: bb_in_xdata[axes.index(ax)].max() - bb_in_xdata[axes.index(ax)].min() for ax in axes
ax: bb_in_xdata[:, axes.index(ax)].max() - bb_in_xdata[:, axes.index(ax)].min() for ax in axes
}
for ax in axes:
# TLDR; the sqrt selects a pyramid level in which the requested bounding box is a bit larger than the
Expand All @@ -311,9 +316,10 @@ def _get_xarray_data_to_rasterize(
# inverse-transformed bounding box. The sqrt comes from the ratio of the side of a square,
# and the maximum diagonal of a square containing the original square, if the original square is
# rotated.
if bb_in_xdata_sizes[ax] * np.sqrt(len(axes)) < target_sizes[ax]:
if bb_in_xdata_sizes[ax] < target_sizes[ax] * np.sqrt(len(axes)):
break
else:
# when this code is reached, latest_scale is selected
break
assert latest_scale is not None
xdata = next(iter(data[latest_scale].values()))
Expand All @@ -327,6 +333,45 @@ def _get_xarray_data_to_rasterize(
return xdata, pyramid_scale


def _get_corrected_affine_matrix(
    data: Union[SpatialImage, MultiscaleSpatialImage],
    axes: tuple[str, ...],
    target_coordinate_system: str,
) -> tuple[Affine, tuple[str, ...]]:
    """
    Get the affine matrix that maps the intrinsic coordinates of the data to the target_coordinate_system,
    with in addition:

    - restricting the domain to the axes specified in axes (i.e. the axes for which the bounding box is specified), in
      particular axes never contains c;
    - restricting the codomain to the spatial axes of the target coordinate system (i.e. excluding c).

    We do this because:

    - we don't need to consider c;
    - when we create the target rasterized object, we need to have axes in the order that is required by the schema.

    Parameters
    ----------
    data
        The element whose transformation to ``target_coordinate_system`` is looked up.
    axes
        The spatial axes of the bounding box (never contains ``"c"``).
    target_coordinate_system
        The name of the coordinate system to map into.

    Returns
    -------
    A tuple of the corrected affine (restricted to the spatial axes) and the full, canonically ordered
    target axes (including ``"c"`` when present in the transformation's output axes).
    """
    transformation = get_transformation(data, target_coordinate_system)
    assert isinstance(transformation, BaseTransformation)
    affine = _get_affine_for_element(data, transformation)
    target_axes_unordered = affine.output_axes
    assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}]
    # order the output axes canonically: c first (if present), then z, y, x
    target_axes: tuple[str, ...]
    if "z" in target_axes_unordered:
        target_axes = ("c", "z", "y", "x") if "c" in target_axes_unordered else ("z", "y", "x")
    else:
        target_axes = ("c", "y", "x") if "c" in target_axes_unordered else ("y", "x")
    target_spatial_axes = get_spatial_axes(target_axes)
    # the bounding-box axes must match the spatial output axes of the transformation
    # (was asserted twice verbatim in the original; the duplicate is removed)
    assert len(target_spatial_axes) == len(axes)
    corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes)
    return corrected_affine, target_axes


@rasterize.register(SpatialImage)
@rasterize.register(MultiscaleSpatialImage)
def _(
Expand Down Expand Up @@ -359,29 +404,6 @@ def _(
"z": target_depth,
}

# get inverse transformation
transformation = get_transformation(data, target_coordinate_system)
dims = get_dims(data)
assert isinstance(transformation, BaseTransformation)
affine = _get_affine_for_element(data, transformation)
target_axes_unordered = affine.output_axes
assert set(target_axes_unordered) in [{"x", "y", "z"}, {"x", "y"}, {"c", "x", "y", "z"}, {"c", "x", "y"}]
target_axes: tuple[str, ...]
if "z" in target_axes_unordered:
if "c" in target_axes_unordered:
target_axes = ("c", "z", "y", "x")
else:
target_axes = ("z", "y", "x")
else:
if "c" in target_axes_unordered:
target_axes = ("c", "y", "x")
else:
target_axes = ("y", "x")
target_spatial_axes = get_spatial_axes(target_axes)
assert len(target_spatial_axes) == len(min_coordinate)
assert len(target_spatial_axes) == len(max_coordinate)
corrected_affine = affine.to_affine(input_axes=axes, output_axes=target_spatial_axes)

bb_sizes = {ax: max_coordinate[axes.index(ax)] - min_coordinate[axes.index(ax)] for ax in axes}
scale_vector = [bb_sizes[ax] / target_sizes[ax] for ax in axes]
scale = Scale(scale_vector, axes=axes)
Expand All @@ -395,25 +417,33 @@ def _(
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
target_sizes=target_sizes,
corrected_affine=corrected_affine,
target_coordinate_system=target_coordinate_system,
)

if pyramid_scale is not None:
extra = [pyramid_scale.inverse()]
else:
extra = []

# get inverse transformation
corrected_affine, target_axes = _get_corrected_affine_matrix(
data=data,
axes=axes,
target_coordinate_system=target_coordinate_system,
)

half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x"))
sequence = Sequence(
[
half_pixel_offset.inverse(),
# half_pixel_offset.inverse(),
scale,
translation,
corrected_affine.inverse(),
half_pixel_offset,
# half_pixel_offset,
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
]
+ extra
)
dims = get_dims(data)
matrix = sequence.to_affine_matrix(input_axes=target_axes, output_axes=dims)

# get output shape
Expand All @@ -437,8 +467,6 @@ def _(
else:
raise ValueError(f"Unsupported schema {schema}")

# TODO: adjust matrix
# TODO: add c
# resample the image
transformed_dask = dask_image.ndinterp.affine_transform(
xdata.data,
Expand All @@ -447,12 +475,28 @@ def _(
# output_chunks=xdata.data.chunks,
**kwargs,
)
# ##
# # debug code
# crop = xdata.sel(
# {
# "x": slice(min_coordinate[axes.index("x")], max_coordinate[axes.index("x")]),
# "y": slice(min_coordinate[axes.index("y")], max_coordinate[axes.index("y")]),
# }
# )
# import matplotlib.pyplot as plt
# plt.figure(figsize=(20, 10))
# plt.subplot(1, 2, 1)
# plt.imshow(crop.transpose("y", "x", "c").data)
# plt.subplot(1, 2, 2)
# plt.imshow(DataArray(transformed_dask, dims=xdata.dims).transpose("y", "x", "c").data)
# plt.show()
# ##
assert isinstance(transformed_dask, DaskArray)
transformed_data = schema.parse(transformed_dask, dims=xdata.dims) # type: ignore[call-arg,arg-type]
if target_coordinate_system != "global":
remove_transformation(transformed_data, "global")

sequence = Sequence([half_pixel_offset.inverse(), scale, translation])
sequence = Sequence([half_pixel_offset.inverse(), scale, translation, half_pixel_offset])
set_transformation(transformed_data, sequence, target_coordinate_system)

transformed_data = compute_coordinates(transformed_data)
Expand Down
4 changes: 2 additions & 2 deletions spatialdata/_core/_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _(
)
new_elements[element_type] = queried_elements

if filter_table:
if filter_table and sdata.table is not None:
to_keep = np.array([False] * len(sdata.table))
region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]
instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY]
Expand Down Expand Up @@ -333,7 +333,7 @@ def _(
@bounding_box_query.register(SpatialImage)
@bounding_box_query.register(MultiscaleSpatialImage)
def _(
image: SpatialImage,
image: Union[SpatialImage, MultiscaleSpatialImage],
axes: tuple[str, ...],
min_coordinate: Union[list[Number], ArrayLike],
max_coordinate: Union[list[Number], ArrayLike],
Expand Down
Loading