Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Displacement field bridge ITK-MONAI #8

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/bridge_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import itk
import torch
import numpy as np
from monai.data import ITKReader
from monai.data.meta_tensor import MetaTensor
from monai.transforms import EnsureChannelFirst
from monai.utils import convert_to_dst_type


def metatensor_to_array(metatensor):
metatensor = metatensor.squeeze()
metatensor = metatensor.permute(*torch.arange(metatensor.ndim - 1, -1, -1))

return metatensor.get_array()


def image_to_metatensor(image):
"""
Converts an ITK image to a MetaTensor object.

Args:
image: The ITK image to be converted.

Returns:
A MetaTensor object containing the array data and metadata.
"""
reader = ITKReader(affine_lps_to_ras=False)
image_array, meta_data = reader.get_data(image)
image_array = convert_to_dst_type(image_array, dst=image_array, dtype=itk.D)[0]
metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data)
metatensor = EnsureChannelFirst()(metatensor)

return metatensor



def remove_border(image):
"""
MONAI seems to have different behavior in the borders of the image than ITK.
This helper function sets the border of the ITK image as 0 (padding but keeping
the same image size) in order to allow numerical comparison between the
result from resampling with ITK/Elastix and resampling with MONAI.
To use: image[:] = remove_border(image)
Args:
image: The ITK image to be padded.

Returns:
The padded array of data.
"""
return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim==3 else image[1:-1, 1:-1],
pad_width=1)
82 changes: 82 additions & 0 deletions src/itk_torch_ddf_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import itk
import torch
import monai
import numpy as np
import matplotlib.pyplot as plt

def monai_to_itk_ddf(image, ddf):
"""
converting the dense displacement field from the MONAI space to the ITK
Args:
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)
Returns:
displacement_field: itk image of the corresponding displacement field

"""
# 3, D, H, W -> D, H, W, 3
ndim = image.ndim
ddf = ddf.transpose(tuple(list(range(1, ndim+1)) + [0]))
# x, y, z -> z, x, y
ddf = ddf[..., ::-1]

# Correct for spacing
spacing = np.asarray(image.GetSpacing(), dtype=np.float64)
ddf *= np.array(spacing, ndmin=ndim+1)

# Correct for direction
direction = np.asarray(image.GetDirection(), dtype=np.float64)
ddf = np.einsum('ij,...j->...i', direction, ddf, dtype=np.float64).astype(np.float32)

# initialise displacement field -
vector_component_type = itk.F
vector_pixel_type = itk.Vector[vector_component_type, ndim]
displacement_field_type = itk.Image[vector_pixel_type, ndim]
displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type)

# Set image metadata
displacement_field.SetSpacing(image.GetSpacing())
displacement_field.SetOrigin(image.GetOrigin())
displacement_field.SetDirection(image.GetDirection())

return displacement_field


def itk_warp(image, ddf):
"""
warping with python itk
Args:
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)
Returns:
warped_image: numpy array of shape (H, W) or (D, H, W)
"""
# MONAI->ITK ddf
displacement_field = monai_to_itk_ddf(image, ddf)

# Resample using the ddf
interpolator = itk.LinearInterpolateImageFunction.New(image)
warped_image = itk.warp_image_filter(image,
interpolator=interpolator,
displacement_field=displacement_field,
output_parameters_from_image=image)

return np.asarray(warped_image)


def monai_wrap(image_tensor, ddf_tensor):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo wrap -> warp ?

"""
warping with MONAI
Args:
image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W)
ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W)
Returns:
warped_image: numpy array of shape (H, W) or (D, H, W)
"""
warp = monai.networks.blocks.Warp(mode='bilinear', padding_mode="zeros")
warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64))

return warped_image.to(torch.float32).squeeze().numpy()



17 changes: 17 additions & 0 deletions src/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from test_cases import *
import test_utils

test_utils.download_test_data()

filepath_2D = str(test_utils.TEST_DATA_DIR / 'CT_2D_head_fixed.mha')
filepath_3D = str(test_utils.TEST_DATA_DIR / 'copd1_highres_INSP_STD_COPD_img.nii.gz')

# 2D cases
test_random_array(ndim=2)
test_real_data(filepath=filepath_2D)

# 3D cases
test_random_array(ndim=3)
test_real_data(filepath=filepath_3D)


66 changes: 66 additions & 0 deletions src/test_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from itk_torch_ddf_bridge import *
from bridge_utils import remove_border

def test_random_array(ndim):
print("\nTest: Random array with random spacing, direction and origin, ndim={}".format(ndim))

# Create image/array with random size and pixel intensities
s = torch.randint(low=2, high=20, size=(ndim,))
img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32)

# Pad at the edges because ITK and MONAI have different behavior there
# during resampling
img = torch.nn.functional.pad(img, pad=ndim*(1, 1))
ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5

# Warp with MONAI
img_resampled = monai_wrap(img, ddf)

# Create ITK image
itk_img = itk.GetImageFromArray(img.squeeze().numpy())

# Set random spacing
spacing = 3 * np.random.rand(ndim)
itk_img.SetSpacing(spacing)

# Set random direction
direction = 5 * np.random.rand(ndim, ndim) - 5
direction = itk.matrix_from_array(direction)
itk_img.SetDirection(direction)

# Set random origin
origin = 100 * np.random.rand(ndim) - 100
itk_img.SetOrigin(origin)

# Warp with ITK
itk_img_resampled = itk_warp(itk_img, ddf.squeeze().numpy())

# Compare
print("All close: ", np.allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3))
diff = img_resampled - itk_img_resampled
print(diff.min(), diff.max())


def test_real_data(filepath):
print("\nTEST: Real data with random deformation field")
# Read image
image = itk.imread(filepath, itk.F)
image[:] = remove_border(image)
ndim = image.ndim

# Random ddf
ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10

# Warp with MONAI
image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
img_resampled = monai_wrap(image_tensor, ddf)

# Warp with ITK
itk_img_resampled = itk_warp(image, ddf.squeeze().numpy())

# Compare
print("All close: ", np.allclose(img_resampled, itk_img_resampled))
diff = img_resampled - itk_img_resampled
print(diff.min(), diff.max())


18 changes: 18 additions & 0 deletions src/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pathlib
import subprocess
# import sys

TEST_DATA_DIR = pathlib.Path(__file__).parent.parent / "test_files"

def download_test_data():
subprocess.run(
[
"girder-client",
"--api-url",
"https://data.kitware.com/api/v1",
"localsync",
"62a0efe5bddec9d0c4175c1f",
str(TEST_DATA_DIR),
],
#stdout=sys.stdout,
)
Binary file added test_files/CT_2D_head_fixed.mha
Binary file not shown.
Binary file added test_files/CT_2D_head_moving.mha
Binary file not shown.