diff --git a/monai/data/itk_torch_affine_matrix_bridge.py b/monai/data/itk_torch_affine_matrix_bridge.py index 65853ffde39..fe5723b67b7 100644 --- a/monai/data/itk_torch_affine_matrix_bridge.py +++ b/monai/data/itk_torch_affine_matrix_bridge.py @@ -19,6 +19,7 @@ from monai.config import NdarrayOrTensor from monai.data import ITKReader from monai.data.meta_tensor import MetaTensor +from monai.networks.blocks import Warp from monai.transforms import Affine, EnsureChannelFirst from monai.utils import convert_to_dst_type, optional_import @@ -272,4 +273,78 @@ def monai_affine_resample(metatensor: MetaTensor, affine_matrix: NdarrayOrTensor affine = Affine(affine=affine_matrix, padding_mode="zeros", mode="bilinear", dtype=torch.float64, image_only=True) output_tensor = cast(MetaTensor, affine(metatensor)) - return output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1)).array + return cast(MetaTensor, output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1))).array + + +def monai_to_itk_ddf(image, ddf): + """ + converting the dense displacement field from the MONAI space to the ITK + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + displacement_field: itk image of the corresponding displacement field + + """ + # 3, D, H, W -> D, H, W, 3 + ndim = image.ndim + ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0])) + # x, y, z -> z, x, y + ddf = ddf[..., ::-1] + + # Correct for spacing + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + ddf *= np.array(spacing, ndmin=ndim + 1) + + # Correct for direction + direction = np.asarray(image.GetDirection(), dtype=np.float64) + ddf = np.einsum("ij,...j->...i", direction, ddf, dtype=np.float64).astype(np.float32) + + # initialise displacement field - + vector_component_type = itk.F + vector_pixel_type = itk.Vector[vector_component_type, ndim] + displacement_field_type = itk.Image[vector_pixel_type, ndim] + displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type) + + # Set image metadata + displacement_field.SetSpacing(image.GetSpacing()) + displacement_field.SetOrigin(image.GetOrigin()) + displacement_field.SetDirection(image.GetDirection()) + + return displacement_field + + +def itk_warp(image, ddf): + """ + warping with python itk + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + # MONAI -> ITK ddf + displacement_field = monai_to_itk_ddf(image, ddf) + + # Resample using the ddf + interpolator = itk.LinearInterpolateImageFunction.New(image) + warped_image = itk.warp_image_filter( + image, interpolator=interpolator, displacement_field=displacement_field, output_parameters_from_image=image + ) + + return np.asarray(warped_image) + + +def monai_wrap(image_tensor, ddf_tensor): + """ + warping with MONAI + Args: + image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W) + ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + warp = Warp(mode="bilinear", padding_mode="zeros") + warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64)) + + return warped_image.to(torch.float32).squeeze().numpy() diff --git a/tests/test_itk_torch_affine_matrix_bridge.py b/tests/test_itk_torch_affine_matrix_bridge.py index f2cad928fad..5127732f5c6 100644 --- a/tests/test_itk_torch_affine_matrix_bridge.py +++ b/tests/test_itk_torch_affine_matrix_bridge.py @@ -18,27 +18,42 @@ import torch from parameterized import parameterized +from monai.apps import download_url from monai.data import ITKReader from monai.data.itk_torch_affine_matrix_bridge import ( create_itk_affine_from_parameters, - itk_image_to_metatensor, itk_affine_resample, + itk_image_to_metatensor, itk_to_monai_affine, + itk_warp, monai_affine_resample, monai_to_itk_affine, + monai_wrap, ) from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config itk, has_itk = optional_import("itk") TESTS = ["CT_2D_head_fixed.mha", "CT_2D_head_moving.mha", "copd1_highres_INSP_STD_COPD_img.nii.gz"] -# Download URL: -# SHA-521: 60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeec -# b343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9 -# TEST_CASE_3D_1 = [ -# "copd1_highres_INSP_STD_COPD_img.nii.gz" # https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download -# "copd1_highres_INSP_STD_COPD_img.nii.gz" # https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download -# ] + +key = "copd1_highres_INSP_STD_COPD_img" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", f"{key}.nii.gz") + +def remove_border(image): + """ + MONAI seems to have different behavior in the borders of the image than ITK. + This helper function sets the border of the ITK image as 0 (padding but keeping + the same image size) in order to allow numerical comparison between the + result from resampling with ITK/Elastix and resampling with MONAI. + To use: image[:] = remove_border(image) + Args: + image: The ITK image to be padded. + + Returns: + The padded array of data. + """ + return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim == 3 else image[1:-1, 1:-1], pad_width=1) @unittest.skipUnless(has_itk, "Requires `itk` package.") @@ -48,17 +63,16 @@ def setUp(self): self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") self.reader = ITKReader() - # for filepath in TEST_CASES: - # if not os.path.exists(os.path.join(self.data_dir, filepath)): - # with skip_if_downloading_fails(): - # data_spec = testing_data_config("images", filepath) - # download_and_extract( - # data_spec["url"], - # filepath, - # self.data_dir, - # #hash_val=data_spec["hash_val"], - # #hash_type=data_spec["hash_type"], - # ) + for k, n in ((key, FILE_PATH),): + if not os.path.exists(n): + with skip_if_downloading_fails(): + data_spec = testing_data_config("images", f"{k}") + download_url( + data_spec["url"], + n, + hash_val=data_spec["hash_val"], + hash_type=data_spec["hash_type"], + ) @parameterized.expand(TESTS) def test_setting_affine_parameters(self, filepath): @@ -160,7 +174,6 @@ def test_arbitary_center_of_rotation(self, filepath): @parameterized.expand(TESTS) def test_monai_to_itk(self, filepath): - print("\nTEST: MONAI affine matrix -> ITK matrix + translation vector -> transform") # Read image image = self.reader.read(os.path.join(self.data_dir, filepath)) @@ -244,6 +257,66 @@ def test_cyclic_conversion(self, filepath): np.testing.assert_allclose(matrix, matrix_result) np.testing.assert_allclose(translation, translation_result) + @parameterized.expand([(2,), (3,)]) + def test_random_array(self, ndim): + # Create image/array with random size and pixel intensities + s = torch.randint(low=2, high=20, size=(ndim,)) + img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32) + + # Pad at the edges because ITK and MONAI have different behavior there + # during resampling + img = torch.nn.functional.pad(img, pad=ndim * (1, 1)) + ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5 + + # Warp with MONAI + img_resampled = monai_wrap(img, ddf) + + # Create ITK image + itk_img = itk.GetImageFromArray(img.squeeze().numpy()) + + # Set random spacing + spacing = 3 * np.random.rand(ndim) + itk_img.SetSpacing(spacing) + + # Set random direction + direction = 5 * np.random.rand(ndim, ndim) - 5 + direction = itk.matrix_from_array(direction) + itk_img.SetDirection(direction) + + # Set random origin + origin = 100 * np.random.rand(ndim) - 100 + itk_img.SetOrigin(origin) + + # Warp with ITK + itk_img_resampled = itk_warp(itk_img, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3) + diff_output = img_resampled - itk_img_resampled + print(f"[Min, Max] diff: [{diff_output.min()}, {diff_output.max()}]") + + @parameterized.expand(TESTS) + def test_real_data(self, filepath): + # Read image + image = itk.imread(os.path.join(self.data_dir, filepath), itk.F) + image[:] = remove_border(image) + ndim = image.ndim + + # Random ddf + ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10 + + # Warp with MONAI + image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0) + img_resampled = monai_wrap(image_tensor, ddf) + + # Warp with ITK + itk_img_resampled = itk_warp(image, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3) + diff_output = img_resampled - itk_img_resampled + print(f"[Min, Max] diff: [{diff_output.min()}, {diff_output.max()}]") + if __name__ == "__main__": unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 788d6644397..f1559d54399 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -54,6 +54,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz", "hash_type": "sha256", "hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4" + }, + "copd1_highres_INSP_STD_COPD_img": { + "url": "https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download", + "hash_type": "sha512", + "hash_val": "60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9" } }, "videos": {