-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Displacement field bridge ITK-MONAI #8
Draft
ntatsisk
wants to merge
3
commits into
InsightSoftwareConsortium:master
Choose a base branch
from
ntatsisk:ddf_bridge
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import itk | ||
import torch | ||
import numpy as np | ||
from monai.data import ITKReader | ||
from monai.data.meta_tensor import MetaTensor | ||
from monai.transforms import EnsureChannelFirst | ||
from monai.utils import convert_to_dst_type | ||
|
||
|
||
def metatensor_to_array(metatensor): | ||
metatensor = metatensor.squeeze() | ||
metatensor = metatensor.permute(*torch.arange(metatensor.ndim - 1, -1, -1)) | ||
|
||
return metatensor.get_array() | ||
|
||
|
||
def image_to_metatensor(image): | ||
""" | ||
Converts an ITK image to a MetaTensor object. | ||
|
||
Args: | ||
image: The ITK image to be converted. | ||
|
||
Returns: | ||
A MetaTensor object containing the array data and metadata. | ||
""" | ||
reader = ITKReader(affine_lps_to_ras=False) | ||
image_array, meta_data = reader.get_data(image) | ||
image_array = convert_to_dst_type(image_array, dst=image_array, dtype=itk.D)[0] | ||
metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data) | ||
metatensor = EnsureChannelFirst()(metatensor) | ||
|
||
return metatensor | ||
|
||
|
||
|
||
def remove_border(image): | ||
""" | ||
MONAI seems to have different behavior in the borders of the image than ITK. | ||
This helper function sets the border of the ITK image as 0 (padding but keeping | ||
the same image size) in order to allow numerical comparison between the | ||
result from resampling with ITK/Elastix and resampling with MONAI. | ||
To use: image[:] = remove_border(image) | ||
Args: | ||
image: The ITK image to be padded. | ||
|
||
Returns: | ||
The padded array of data. | ||
""" | ||
return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim==3 else image[1:-1, 1:-1], | ||
pad_width=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import itk | ||
import torch | ||
import monai | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
def monai_to_itk_ddf(image, ddf): | ||
""" | ||
converting the dense displacement field from the MONAI space to the ITK | ||
Args: | ||
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) | ||
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) | ||
Returns: | ||
displacement_field: itk image of the corresponding displacement field | ||
|
||
""" | ||
# 3, D, H, W -> D, H, W, 3 | ||
ndim = image.ndim | ||
ddf = ddf.transpose(tuple(list(range(1, ndim+1)) + [0])) | ||
# x, y, z -> z, x, y | ||
ddf = ddf[..., ::-1] | ||
|
||
# Correct for spacing | ||
spacing = np.asarray(image.GetSpacing(), dtype=np.float64) | ||
ddf *= np.array(spacing, ndmin=ndim+1) | ||
|
||
# Correct for direction | ||
direction = np.asarray(image.GetDirection(), dtype=np.float64) | ||
ddf = np.einsum('ij,...j->...i', direction, ddf, dtype=np.float64).astype(np.float32) | ||
|
||
# initialise displacement field - | ||
vector_component_type = itk.F | ||
vector_pixel_type = itk.Vector[vector_component_type, ndim] | ||
displacement_field_type = itk.Image[vector_pixel_type, ndim] | ||
displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type) | ||
|
||
# Set image metadata | ||
displacement_field.SetSpacing(image.GetSpacing()) | ||
displacement_field.SetOrigin(image.GetOrigin()) | ||
displacement_field.SetDirection(image.GetDirection()) | ||
|
||
return displacement_field | ||
|
||
|
||
def itk_warp(image, ddf): | ||
""" | ||
warping with python itk | ||
Args: | ||
image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) | ||
ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) | ||
Returns: | ||
warped_image: numpy array of shape (H, W) or (D, H, W) | ||
""" | ||
# MONAI->ITK ddf | ||
displacement_field = monai_to_itk_ddf(image, ddf) | ||
|
||
# Resample using the ddf | ||
interpolator = itk.LinearInterpolateImageFunction.New(image) | ||
warped_image = itk.warp_image_filter(image, | ||
interpolator=interpolator, | ||
displacement_field=displacement_field, | ||
output_parameters_from_image=image) | ||
|
||
return np.asarray(warped_image) | ||
|
||
|
||
def monai_wrap(image_tensor, ddf_tensor): | ||
""" | ||
warping with MONAI | ||
Args: | ||
image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W) | ||
ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W) | ||
Returns: | ||
warped_image: numpy array of shape (H, W) or (D, H, W) | ||
""" | ||
warp = monai.networks.blocks.Warp(mode='bilinear', padding_mode="zeros") | ||
warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64)) | ||
|
||
return warped_image.to(torch.float32).squeeze().numpy() | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from test_cases import * | ||
import test_utils | ||
|
||
test_utils.download_test_data() | ||
|
||
filepath_2D = str(test_utils.TEST_DATA_DIR / 'CT_2D_head_fixed.mha') | ||
filepath_3D = str(test_utils.TEST_DATA_DIR / 'copd1_highres_INSP_STD_COPD_img.nii.gz') | ||
|
||
# 2D cases | ||
test_random_array(ndim=2) | ||
test_real_data(filepath=filepath_2D) | ||
|
||
# 3D cases | ||
test_random_array(ndim=3) | ||
test_real_data(filepath=filepath_3D) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from itk_torch_ddf_bridge import * | ||
from bridge_utils import remove_border | ||
|
||
def test_random_array(ndim): | ||
print("\nTest: Random array with random spacing, direction and origin, ndim={}".format(ndim)) | ||
|
||
# Create image/array with random size and pixel intensities | ||
s = torch.randint(low=2, high=20, size=(ndim,)) | ||
img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32) | ||
|
||
# Pad at the edges because ITK and MONAI have different behavior there | ||
# during resampling | ||
img = torch.nn.functional.pad(img, pad=ndim*(1, 1)) | ||
ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5 | ||
|
||
# Warp with MONAI | ||
img_resampled = monai_wrap(img, ddf) | ||
|
||
# Create ITK image | ||
itk_img = itk.GetImageFromArray(img.squeeze().numpy()) | ||
|
||
# Set random spacing | ||
spacing = 3 * np.random.rand(ndim) | ||
itk_img.SetSpacing(spacing) | ||
|
||
# Set random direction | ||
direction = 5 * np.random.rand(ndim, ndim) - 5 | ||
direction = itk.matrix_from_array(direction) | ||
itk_img.SetDirection(direction) | ||
|
||
# Set random origin | ||
origin = 100 * np.random.rand(ndim) - 100 | ||
itk_img.SetOrigin(origin) | ||
|
||
# Warp with ITK | ||
itk_img_resampled = itk_warp(itk_img, ddf.squeeze().numpy()) | ||
|
||
# Compare | ||
print("All close: ", np.allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3)) | ||
diff = img_resampled - itk_img_resampled | ||
print(diff.min(), diff.max()) | ||
|
||
|
||
def test_real_data(filepath): | ||
print("\nTEST: Real data with random deformation field") | ||
# Read image | ||
image = itk.imread(filepath, itk.F) | ||
image[:] = remove_border(image) | ||
ndim = image.ndim | ||
|
||
# Random ddf | ||
ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10 | ||
|
||
# Warp with MONAI | ||
image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0) | ||
img_resampled = monai_wrap(image_tensor, ddf) | ||
|
||
# Warp with ITK | ||
itk_img_resampled = itk_warp(image, ddf.squeeze().numpy()) | ||
|
||
# Compare | ||
print("All close: ", np.allclose(img_resampled, itk_img_resampled)) | ||
diff = img_resampled - itk_img_resampled | ||
print(diff.min(), diff.max()) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pathlib | ||
import subprocess | ||
# import sys | ||
|
||
TEST_DATA_DIR = pathlib.Path(__file__).parent.parent / "test_files" | ||
|
||
def download_test_data(): | ||
subprocess.run( | ||
[ | ||
"girder-client", | ||
"--api-url", | ||
"https://data.kitware.com/api/v1", | ||
"localsync", | ||
"62a0efe5bddec9d0c4175c1f", | ||
str(TEST_DATA_DIR), | ||
], | ||
#stdout=sys.stdout, | ||
) |
Binary file not shown.
Binary file not shown.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo wrap -> warp ?