Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] torch DataSet + utils #145

Merged
merged 23 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a3a1a52
add sdata-> data dict transform
kevinyamauchi Feb 20, 2023
b5fa7c6
add initial dataset
kevinyamauchi Feb 20, 2023
479d7e2
fix typos
kevinyamauchi Feb 20, 2023
a95ab3c
Merge branch 'main' into torch-dataloader
kevinyamauchi Feb 20, 2023
42706ff
add shapes to dataset
kevinyamauchi Feb 20, 2023
e0bb5d8
start multislide
kevinyamauchi Feb 22, 2023
3f45b3b
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 3, 2023
9337549
wip, need to merge with rasterize branch
LucaMarconato Mar 3, 2023
3ab9398
Merge branch 'feature/rasterize' into torch-dataloader
LucaMarconato Mar 3, 2023
8902390
wip tiling
LucaMarconato Mar 6, 2023
c6bee89
added __set_item__() and merge branch 'main' into torch-dataloader
LucaMarconato Mar 7, 2023
a307f1b
tiling still wip, but usable
LucaMarconato Mar 8, 2023
354fa3f
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 8, 2023
2125bec
fixed mypy
LucaMarconato Mar 8, 2023
9c2cf75
type fix
LucaMarconato Mar 9, 2023
656616b
fixed bug with xarray coordinates in multiscale, fixed wrong centroids
LucaMarconato Mar 9, 2023
db62b71
Apply suggestions from code review
LucaMarconato Mar 14, 2023
edf71bd
implemented suggestions from code review
LucaMarconato Mar 14, 2023
b37691d
Merge branch 'torch-dataloader' of https://github.com/kevinyamauchi/s…
LucaMarconato Mar 14, 2023
a7228c2
Merge branch 'main' into torch-dataloader
LucaMarconato Mar 14, 2023
7a27ef0
fixed test
LucaMarconato Mar 14, 2023
6f4aa7c
removed numpy=1.22 contraint for mypy
LucaMarconato Mar 14, 2023
41d818a
mypy now using numpy==1.24
LucaMarconato Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions examples/dev-examples/image_tiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
##
# from https://gist.github.com/kevinyamauchi/77f986889b7626db4ab3c1075a3a3e5e

import numpy as np
from matplotlib import pyplot as plt
from skimage import draw

import spatialdata as sd
from spatialdata._dl.datasets import ImageTilesDataset

##
# Centers of the five spots drawn into the example image. They are passed
# unchanged to draw.disk() below, and are flipped with [:, [1, 0]] before
# being used as xy circle centers, so they are presumably in (row, col),
# i.e. (y, x), order -- confirm against skimage.draw.disk.
coordinates = np.array([[10, 10], [20, 20], [50, 30], [90, 70], [20, 80]])

# Radius shared by the drawn disks and by the circle shapes created later.
radius = 5

# One 3-component color per spot (same order as `coordinates`); each
# component is written to one channel of the image (presumably RGB).
colors = np.array(
[
[102, 194, 165],
[252, 141, 98],
[141, 160, 203],
[231, 138, 195],
[166, 216, 84],
]
)

##
# make an image with spots: a black 100x100 canvas with 3 channels, on which
# one filled disk per entry of `coordinates` is painted with the matching color
image = np.zeros((100, 100, 3), dtype=np.uint8)

for spot_color, centroid in zip(colors, coordinates):
    # `shape` clips the disk to the canvas, so a centroid close to the border
    # cannot yield out-of-range (or negative, wrap-around) pixel indices.
    rr, cc = draw.disk(centroid, radius=radius, shape=image.shape[:2])

    # image[rr, cc] selects an (n_pixels, 3) view; the 3-component color is
    # broadcast over all selected pixels, setting every channel at once.
    image[rr, cc] = spot_color

# plt.imshow(image)
# plt.show()

##
# Wrap the numpy array as a SpatialData 2D image element; `dims` declares the
# axis order of `image` (rows, columns, channels).
sd_image = sd.Image2DModel.parse(image, dims=("y", "x", "c"))

# circles coordinates are xy, so we flip them here (column 1 first, then
# column 0). NOTE(review): geometry=0 presumably selects the circle geometry
# type -- confirm against the ShapesModel.parse documentation.
circles = sd.ShapesModel.parse(coordinates[:, [1, 0]], radius=radius, geometry=0)
# Bundle the image (key "image") and the circles (key "spots") in one object.
sdata = sd.SpatialData(images={"image": sd_image}, shapes={"spots": circles})
# Bare expression: has no effect when run as a script; presumably kept so the
# object is displayed when this `##` cell is executed interactively.
sdata

##
# Build the tiles dataset: the "spots" shapes element is paired with the
# "image" image element.
# NOTE(review): presumably each tile spans `tile_dim_in_units` units of the
# target coordinate system around a spot and is rendered at
# `tile_dim_in_pixels` x `tile_dim_in_pixels` pixels -- confirm against the
# ImageTilesDataset documentation.
ds = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"spots": "image"},
    tile_dim_in_units=10,
    tile_dim_in_pixels=32,
    target_coordinate_system="global",
)

print(f"this dataset has {len(ds)} items")

##
# we can use the __getitem__ interface to get one of the sample crops
print(ds[0])


##
# now we plot all of the crops
def plot_sdata_dataset(ds: ImageTilesDataset) -> None:
    """Show every item of ``ds`` side by side in a single-row figure.

    Each item yielded by ``ds`` is unpacked as ``(image, region, index)``.
    The image is reordered to ("y", "x", "c") before being passed to
    ``imshow``, and the subplot title is set to ``"{region}, {index}"``.

    Parameters
    ----------
    ds
        The dataset to plot; it must support ``len()`` and iteration.
    """
    n_samples = len(ds)
    # squeeze=False keeps `axs` a 2D (1, n_samples) array even when the
    # dataset has a single item; with the default squeeze=True a lone Axes
    # object is returned and indexing it fails.
    _, axs = plt.subplots(1, n_samples, squeeze=False)

    for i, (image, region, index) in enumerate(ds):
        ax = axs[0, i]
        ax.imshow(image.transpose("y", "x", "c"))
        ax.set_title(f"{region}, {index}")
    plt.show()


# Plot all the crops of the dataset built above, one subplot per item.
plot_sdata_dataset(ds)

# TODO: code to be restored when the transforms will use the bounding box query
# ##
# # we can also use transforms to automatically extract the relevant data
# # into a datadictionary
#
# # map the SpatialData path to a data dict key
# data_mapping = {"images/image": "image"}
#
# # make the transform
# ds_transform = ImageTilesDataset(
# sdata=sdata,
# spots_element_key="spots",
# transform=SpatialDataToDataDict(data_mapping=data_mapping),
# )
#
# print(f"this dataset as {len(ds_transform)} items")
#
# ##
# # now the samples are a dictionary with key "image" and the item is the
# # image array
# # this is useful because it is the expected format for many of the
# #
# ds_transform[0]
#
# ##
# # plot of each sample in the dataset
# plot_sdata_dataset(ds_transform)
71 changes: 50 additions & 21 deletions examples/dev-examples/spatial_query_and_rasterization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from multiscale_spatial_image import MultiscaleSpatialImage
from spatial_image import SpatialImage

from spatialdata import Labels2DModel
Expand All @@ -9,7 +8,7 @@
remove_transformation,
set_transformation,
)
from spatialdata._core.transformations import Affine
from spatialdata._core.transformations import Affine, Scale, Sequence


def _visualize_crop_affine_labels_2d() -> None:
Expand All @@ -31,34 +30,41 @@ def _visualize_crop_affine_labels_2d() -> None:
requested crop, as explained above)
5) then enable "3 cropped rotated processed", this shows the data that we wanted to query in the first place,
in the target coordinate system ("rotated"). This is probably the data you care about if for instance you want to
use tiles for deep learning. Note that for obtaning this answer there is also a better function (not available at
the time of this writing): rasterize(), which is faster and more accurate, so it should be used instead. The
function rasterize() transforms all the coordinates of the data into the target coordinate system, and it returns
only SpatialImage objects. So it has different use cases than the bounding box query.
6) finally switch to the "global" coordinate_system. This is, for how we constructed the example, showing the
use tiles for deep learning.
6) Note that for obtaining the previous answer there is also a better function, rasterize().
This is what "4 rasterized" shows; rasterize() is faster and more accurate, so it should be used instead. The function
rasterize() transforms all the coordinates of the data into the target coordinate system, and it returns only
SpatialImage objects. So it has different use cases than the bounding box query. BUG: Note that it is not pixel
perfect. I think this is due to the difference between considering the origin of a pixel its center or its corner.
7) finally switch to the "global" coordinate_system. This is, for how we constructed the example, showing the
original image as it would appear in its intrinsic coordinate system (since the transformation that maps the
original image to "global" is an identity). It then shows where the data shown at point 5) is located in the
original image.
"""
##
# in this test let's try some affine transformations, we could do that also for the other tests
# image = scipy.misc.face()[:100, :100, :].copy()
image = np.random.randint(low=10, high=100, size=(100, 100))
multiscale_image = np.repeat(np.repeat(image, 4, axis=0), 4, axis=1)

# y: [50, 99], x: [0, 49] has value 2
image[50:, :50] = 2
# labels_element = Image2DModel.parse(image, dims=('y', 'x', 'c'))
labels_element = Labels2DModel.parse(image)
affine = Affine(
np.array(
[
[np.cos(np.pi / 6), np.sin(-np.pi / 6), 0],
[np.sin(np.pi / 6), np.cos(np.pi / 6), 0],
[0, 0, 1],
]
),
input_axes=("x", "y"),
output_axes=("x", "y"),
)
set_transformation(
labels_element,
Affine(
np.array(
[
[np.cos(np.pi / 6), np.sin(-np.pi / 6), 20],
[np.sin(np.pi / 6), np.cos(np.pi / 6), 0],
[0, 0, 1],
]
),
input_axes=("x", "y"),
output_axes=("x", "y"),
),
affine,
"rotated",
)

Expand Down Expand Up @@ -91,9 +97,7 @@ def _visualize_crop_affine_labels_2d() -> None:
if labels_result_rotated is not None:
d["2 cropped_rotated"] = labels_result_rotated

assert isinstance(labels_result_rotated, SpatialImage) or isinstance(
labels_result_rotated, MultiscaleSpatialImage
)
assert isinstance(labels_result_rotated, SpatialImage)
transform = labels_result_rotated.attrs["transform"]["rotated"]
transform_rotated_processed = transform.transform(labels_result_rotated, maintain_positioning=True)
transform_rotated_processed_recropped = bounding_box_query(
Expand All @@ -106,7 +110,32 @@ def _visualize_crop_affine_labels_2d() -> None:
d["3 cropped_rotated_processed_recropped"] = transform_rotated_processed_recropped
remove_transformation(labels_result_rotated, "global")

multiscale_image[200:, :200] = 2
# multiscale_labels = Labels2DModel.parse(multiscale_image)
multiscale_labels = Labels2DModel.parse(multiscale_image, scale_factors=[2, 2, 2, 2])
sequence = Sequence([Scale([0.25, 0.25], axes=("x", "y")), affine])
set_transformation(multiscale_labels, sequence, "rotated")

from spatialdata._core._rasterize import rasterize

rasterized = rasterize(
multiscale_labels,
axes=("y", "x"),
min_coordinate=np.array([25, 25]),
max_coordinate=np.array([75, 100]),
target_coordinate_system="rotated",
target_width=300,
)
d["4 rasterized"] = rasterized

sdata = SpatialData(labels=d)

# to see only what matters when debugging https://github.com/scverse/spatialdata/issues/165
del sdata.labels["1 cropped_global"]
del sdata.labels["2 cropped_rotated"]
del sdata.labels["3 cropped_rotated_processed_recropped"]
del sdata.labels["0 original"].attrs["transform"]["global"]

Interactive(sdata)
##

Expand Down
Loading