Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…and automatically download test data

Signed-off-by: Felix Schnabel <[email protected]>
  • Loading branch information
Shadow-Devil committed Feb 7, 2023
1 parent b080bc3 commit 5af9094
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 21 deletions.
77 changes: 76 additions & 1 deletion monai/data/itk_torch_affine_matrix_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from monai.config import NdarrayOrTensor
from monai.data import ITKReader
from monai.data.meta_tensor import MetaTensor
from monai.networks.blocks import Warp
from monai.transforms import Affine, EnsureChannelFirst
from monai.utils import convert_to_dst_type, optional_import

Expand Down Expand Up @@ -272,4 +273,78 @@ def monai_affine_resample(metatensor: MetaTensor, affine_matrix: NdarrayOrTensor
affine = Affine(affine=affine_matrix, padding_mode="zeros", mode="bilinear", dtype=torch.float64, image_only=True)
output_tensor = cast(MetaTensor, affine(metatensor))

return output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1)).array
return cast(MetaTensor, output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1))).array


def monai_to_itk_ddf(image, ddf):
"""
converting the dense displacement field from the MONAI space to the ITK
Args:
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)
Returns:
displacement_field: itk image of the corresponding displacement field
"""
# 3, D, H, W -> D, H, W, 3
ndim = image.ndim
ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0]))
# x, y, z -> z, x, y
ddf = ddf[..., ::-1]

# Correct for spacing
spacing = np.asarray(image.GetSpacing(), dtype=np.float64)
ddf *= np.array(spacing, ndmin=ndim + 1)

# Correct for direction
direction = np.asarray(image.GetDirection(), dtype=np.float64)
ddf = np.einsum("ij,...j->...i", direction, ddf, dtype=np.float64).astype(np.float32)

# initialise displacement field -
vector_component_type = itk.F
vector_pixel_type = itk.Vector[vector_component_type, ndim]
displacement_field_type = itk.Image[vector_pixel_type, ndim]
displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type)

# Set image metadata
displacement_field.SetSpacing(image.GetSpacing())
displacement_field.SetOrigin(image.GetOrigin())
displacement_field.SetDirection(image.GetDirection())

return displacement_field


def itk_warp(image, ddf):
"""
warping with python itk
Args:
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)
Returns:
warped_image: numpy array of shape (H, W) or (D, H, W)
"""
# MONAI -> ITK ddf
displacement_field = monai_to_itk_ddf(image, ddf)

# Resample using the ddf
interpolator = itk.LinearInterpolateImageFunction.New(image)
warped_image = itk.warp_image_filter(
image, interpolator=interpolator, displacement_field=displacement_field, output_parameters_from_image=image
)

return np.asarray(warped_image)


def monai_wrap(image_tensor, ddf_tensor):
"""
warping with MONAI
Args:
image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W)
ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W)
Returns:
warped_image: numpy array of shape (H, W) or (D, H, W)
"""
warp = Warp(mode="bilinear", padding_mode="zeros")
warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64))

return warped_image.to(torch.float32).squeeze().numpy()
113 changes: 93 additions & 20 deletions tests/test_itk_torch_affine_matrix_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,42 @@
import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.data import ITKReader
from monai.data.itk_torch_affine_matrix_bridge import (
create_itk_affine_from_parameters,
itk_image_to_metatensor,
itk_affine_resample,
itk_image_to_metatensor,
itk_to_monai_affine,
itk_warp,
monai_affine_resample,
monai_to_itk_affine,
monai_wrap,
)
from monai.utils import optional_import
from tests.utils import skip_if_downloading_fails, testing_data_config

itk, has_itk = optional_import("itk")

TESTS = ["CT_2D_head_fixed.mha", "CT_2D_head_moving.mha", "copd1_highres_INSP_STD_COPD_img.nii.gz"]
# Download URL:
# SHA-521: 60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeec
# b343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9
# TEST_CASE_3D_1 = [
# "copd1_highres_INSP_STD_COPD_img.nii.gz" # https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download
# "copd1_highres_INSP_STD_COPD_img.nii.gz" # https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download
# ]

key = "copd1_highres_INSP_STD_COPD_img"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", f"{key}.nii.gz")

def remove_border(image):
"""
MONAI seems to have different behavior in the borders of the image than ITK.
This helper function sets the border of the ITK image as 0 (padding but keeping
the same image size) in order to allow numerical comparison between the
result from resampling with ITK/Elastix and resampling with MONAI.
To use: image[:] = remove_border(image)
Args:
image: The ITK image to be padded.
Returns:
The padded array of data.
"""
return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim == 3 else image[1:-1, 1:-1], pad_width=1)


@unittest.skipUnless(has_itk, "Requires `itk` package.")
Expand All @@ -48,17 +63,16 @@ def setUp(self):
self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
self.reader = ITKReader()

# for filepath in TEST_CASES:
# if not os.path.exists(os.path.join(self.data_dir, filepath)):
# with skip_if_downloading_fails():
# data_spec = testing_data_config("images", filepath)
# download_and_extract(
# data_spec["url"],
# filepath,
# self.data_dir,
# #hash_val=data_spec["hash_val"],
# #hash_type=data_spec["hash_type"],
# )
for k, n in ((key, FILE_PATH),):
if not os.path.exists(n):
with skip_if_downloading_fails():
data_spec = testing_data_config("images", f"{k}")
download_url(
data_spec["url"],
n,
hash_val=data_spec["hash_val"],
hash_type=data_spec["hash_type"],
)

@parameterized.expand(TESTS)
def test_setting_affine_parameters(self, filepath):
Expand Down Expand Up @@ -160,7 +174,6 @@ def test_arbitary_center_of_rotation(self, filepath):

@parameterized.expand(TESTS)
def test_monai_to_itk(self, filepath):
print("\nTEST: MONAI affine matrix -> ITK matrix + translation vector -> transform")
# Read image
image = self.reader.read(os.path.join(self.data_dir, filepath))

Expand Down Expand Up @@ -244,6 +257,66 @@ def test_cyclic_conversion(self, filepath):
np.testing.assert_allclose(matrix, matrix_result)
np.testing.assert_allclose(translation, translation_result)

@parameterized.expand([(2,), (3,)])
def test_random_array(self, ndim):
# Create image/array with random size and pixel intensities
s = torch.randint(low=2, high=20, size=(ndim,))
img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32)

# Pad at the edges because ITK and MONAI have different behavior there
# during resampling
img = torch.nn.functional.pad(img, pad=ndim * (1, 1))
ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5

# Warp with MONAI
img_resampled = monai_wrap(img, ddf)

# Create ITK image
itk_img = itk.GetImageFromArray(img.squeeze().numpy())

# Set random spacing
spacing = 3 * np.random.rand(ndim)
itk_img.SetSpacing(spacing)

# Set random direction
direction = 5 * np.random.rand(ndim, ndim) - 5
direction = itk.matrix_from_array(direction)
itk_img.SetDirection(direction)

# Set random origin
origin = 100 * np.random.rand(ndim) - 100
itk_img.SetOrigin(origin)

# Warp with ITK
itk_img_resampled = itk_warp(itk_img, ddf.squeeze().numpy())

# Compare
np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3)
diff_output = img_resampled - itk_img_resampled
print(f"[Min, Max] diff: [{diff_output.min()}, {diff_output.max()}]")

@parameterized.expand(TESTS)
def test_real_data(self, filepath):
# Read image
image = itk.imread(os.path.join(self.data_dir, filepath), itk.F)
image[:] = remove_border(image)
ndim = image.ndim

# Random ddf
ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10

# Warp with MONAI
image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
img_resampled = monai_wrap(image_tensor, ddf)

# Warp with ITK
itk_img_resampled = itk_warp(image, ddf.squeeze().numpy())

# Compare
np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3)
diff_output = img_resampled - itk_img_resampled
print(f"[Min, Max] diff: [{diff_output.min()}, {diff_output.max()}]")


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz",
"hash_type": "sha256",
"hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4"
},
"copd1_highres_INSP_STD_COPD_img": {
"url": "https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download",
"hash_type": "sha512",
"hash_val": "60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9"
}
},
"videos": {
Expand Down

0 comments on commit 5af9094

Please sign in to comment.