Skip to content

Commit

Permalink
Merge pull request #743 from ANTsX/LabelReg2
Browse files Browse the repository at this point in the history
ENH:  Add label iamge registration and tests.
  • Loading branch information
ntustison authored Nov 23, 2024
2 parents 96ba4f7 + 76857fc commit 6b220c5
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ants/registration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
from .integrate_velocity_field import integrate_velocity_field
from .invert_displacement_field import invert_displacement_field
from .landmark_transforms import fit_transform_to_paired_points, fit_time_varying_transform_to_point_sets
from .registration import registration, motion_correction
from .simulate_displacement_field import simulate_displacement_field
from .registration import registration, motion_correction, label_image_registration
from .simulate_displacement_field import simulate_displacement_field
326 changes: 324 additions & 2 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
ANTsPy Registration
"""
__all__ = ["registration",
"motion_correction"]
"motion_correction",
"label_image_registration"]

import os
import numpy as np
from tempfile import mktemp
import glob
Expand Down Expand Up @@ -1565,3 +1565,325 @@ def motion_correction(
"motion_parameters": motion_parameters,
"FD": FD,
}

def label_image_registration(fixed_label_images,
moving_label_images,
fixed_intensity_images=None,
moving_intensity_images=None,
fixed_mask=None,
moving_mask=None,
type_of_linear_transform='affine',
type_of_transform='antsRegistrationSyNQuick[so]',
label_image_weighting=1.0,
output_prefix='',
random_seed=None,
verbose=False):

"""
Perform pairwise registration using fixed and moving sets of label
images (and, optionally, sets of corresponding intensity images).
Arguments
---------
fixed_label_images : single or list of ANTsImage
A single (or set of) fixed label image(s).
moving_label_images : single or list of ANTsImage
A single (or set of) moving label image(s).
fixed_intensity_images : single or list of ANTsImage
Optional---a single (or set of) fixed intensity image(s).
moving_intensity_images : single or list of ANTsImage
Optional---a single (or set of) moving intensity image(s).
fixed_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the fixed image.
moving_mask : ANTsImage
Defines region for similarity metric calculation in the space
of the moving image.
type_of_linear_transform : string
Use label images with the centers of mass to a calculate linear
transform of type 'rigid', 'similarity', or 'affine'.
type_of_transform : string
Only works with deformable-only transforms, specifically the family
of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms.
See 'type_of_transform' in ants.registration.
label_image_weighting : float or list of floats
Relative weighting for the label images.
output_prefix : string
Define the output prefix for the filenames of the output transform
files.
verbose : boolean
Print progress to the screen.
Returns
-------
Set of transforms definining the mapping to/from the fixed image domain
to the moving image domain.
Example
-------
>>>
>>>
"""

# Perform validation check on the input

if isinstance(fixed_label_images, ants.ANTsImage):
fixed_label_images = [ants.image_clone(fixed_label_images)]
if isinstance(moving_label_images, ants.ANTsImage):
moving_label_images = [ants.image_clone(moving_label_images)]

if len(fixed_label_images) != len(moving_label_images):
raise ValueError("The number of fixed and moving label images do not match.")

if fixed_intensity_images is not None or moving_intensity_images is not None:
if isinstance(fixed_intensity_images, ants.ANTsImage):
fixed_intensity_images = [ants.image_clone(fixed_intensity_images)]
if isinstance(moving_intensity_images, ants.ANTsImage):
moving_intensity_images = [ants.image_clone(moving_intensity_images)]
if len(fixed_intensity_images) != len(moving_intensity_images):
raise ValueError("The number of fixed and moving intensity images do not match.")

label_image_weights = list()
if isinstance(label_image_weighting, (int, float)):
label_image_weights = [label_image_weighting] * len(fixed_label_images)
else:
label_image_weights = tuple(label_image_weighting)
if len(fixed_label_images) != len(label_image_weights):
raise ValueError("The length of label_image_weights must" +
"match the number of label image pairs.")

image_dimension = fixed_label_images[0].dimension

if output_prefix == "" or output_prefix is None or len(output_prefix) == 0:
output_prefix = mktemp()

allowable_linear_transforms = ['rigid', 'similarity', 'affine']
if not type_of_linear_transform in allowable_linear_transforms:
raise ValueError("Unrecognized linear transform.")

do_deformable = False
if type_of_transform is not None or len(type_of_transform) > 0:
do_deformable = True

common_label_ids = list()
total_number_of_labels = 0
for i in range(len(fixed_label_images)):
fixed_label_geoms = ants.label_geometry_measures(fixed_label_images[i])
fixed_label_ids = np.array(fixed_label_geoms['Label'])
moving_label_geoms = ants.label_geometry_measures(moving_label_images[i])
moving_label_ids = np.array(moving_label_geoms['Label'])
common_label_ids.append(np.intersect1d(moving_label_ids, fixed_label_ids))
total_number_of_labels = len(common_label_ids[i])
if verbose:
print("Common label ids for image pair ", str(i), ": ", common_label_ids[i])
if len(common_label_ids) == 0:
raise ValueError("No common labels for image pair " + str(i))

if verbose:
print("Total number of labels: " + str(total_number_of_labels))

##############################
#
# Linear transform
#
##############################

linear_xfrm = None
if type_of_linear_transform is not None:

if verbose:
print("\n\nComputing linear transform.\n")

if total_number_of_labels < 3:
raise ValueError(" Number of labels must be >= 3.")

fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension))
deformable_multivariate_extras = list()

count = 0
for i in range(len(common_label_ids)):
for j in range(len(common_label_ids[i])):
label = common_label_ids[i][j]
if verbose:
print(" Finding center of mass for label " + str(label))
fixed_single_label_image = ants.threshold_image(fixed_label_images[i], label, label, 1, 0)
fixed_centers_of_mass[count, :] = ants.get_center_of_mass(fixed_single_label_image)
moving_single_label_image = ants.threshold_image(moving_label_images[i], label, label, 1, 0)
moving_centers_of_mass[count, :] = ants.get_center_of_mass(moving_single_label_image)
count += 1
if do_deformable:
deformable_multivariate_extras.append(["MSQ", fixed_single_label_image,
moving_single_label_image, label_image_weighting, 0])

linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass,
fixed_centers_of_mass,
transform_type=type_of_linear_transform,
verbose=verbose)

linear_xfrm_file = output_prefix + "0GenericAffine.mat"
ants.write_transform(linear_xfrm, linear_xfrm_file)

##############################
#
# Deformable transform
#
##############################

if do_deformable:

if verbose:
print("\n\nComputing deformable transform using images.\n")

do_quick = False
do_repro = False

if "Quick" in type_of_transform:
do_quick = True
elif "Repro" in type_of_transform:
do_repro = True
random_seed = str(1)

intensity_metric_parameter = None
spline_distance = 26
if "[" in type_of_transform and "]" in type_of_transform:
subtype_of_transform = type_of_transform.split("[")[1].split("]")[0]
if not ('bo' in subtype_of_transform or 'so' in subtype_of_transform):
raise ValueError("See only 'so' or 'bo' transforms are available.")
if "," in subtype_of_transform:
subtype_of_transform_args = subtype_of_transform.split(",")
subtype_of_transform = subtype_of_transform_args[0]
intensity_metric_parameter = subtype_of_transform_args[1]
if len(subtype_of_transform_args) > 2:
spline_distance = subtype_of_transform_args[2]

syn_stage = list()

intensity_metric = None
if fixed_intensity_images is not None and len(fixed_intensity_images) > 0:
if do_quick:
intensity_metric = "MI"
if intensity_metric_parameter is None:
intensity_metric_parameter = 32
if not do_quick or do_repro:
intensity_metric = "CC"
if intensity_metric_parameter is None:
intensity_metric_parameter = 2
for i in range(1, len(fixed_intensity_images)):
syn_stage.append("--metric")
metric_string = "%s[%s,%s,%s,%s]" % (
intensity_metric,
get_pointer_string(fixed_intensity_images[i]),
get_pointer_string(moving_intensity_images[i]),
1.0, intensity_metric_parameter)
syn_stage.append(metric_string)

for kk in range(len(deformable_multivariate_extras)):
syn_stage.append("--metric")
metricString = "%s[%s,%s,%s,%s]" % (
"MSQ",
get_pointer_string(deformable_multivariate_extras[kk][1]),
get_pointer_string(deformable_multivariate_extras[kk][2]),
1.0, 0.0)
syn_stage.append(metricString)

syn_shrink_factors = "8x4x2x1"
syn_smoothing_sigmas = "3x2x1x0vox"

if do_quick:
syn_convergence = "[100x70x50x0,1e-6,10]"
else:
syn_convergence = "[100x70x50x20,1e-6,10]"

syn_stage.append("--convergence")
syn_stage.append(syn_convergence)
syn_stage.append("--shrink-factors")
syn_stage.append(syn_shrink_factors)
syn_stage.append("--smoothing-sigmas")
syn_stage.append(syn_smoothing_sigmas)

if 'b' in subtype_of_transform:
syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]")
else:
syn_stage.insert(0, "SyN[0.1,3,0]")
syn_stage.insert(0, "--transform")

args = ["-d", str(image_dimension),
"-r", linear_xfrm_file,
"-o", output_prefix]
args.append(syn_stage)

fixed_mask_string = 'NA'
if fixed_mask is not None:
fixed_mask_binary = fixed_mask != 0
fixed_mask_string = get_pointer_string(fixed_mask_binary)

moving_mask_string = 'NA'
if moving_mask is not None:
moving_mask_binary = moving_mask != 0
moving_mask_string = get_pointer_string(moving_mask_binary)

mask_option = "[%s,%s]" % (fixed_mask_string, moving_mask_string)

args.append("-x")
args.append(mask_option)

args = list(itertools.chain.from_iterable(
itertools.repeat(x, 1)
if isinstance(x, str)
else x for x in args))

args.append("--float")
args.append("1")

if random_seed is not None:
args.append("--random-seed")
args.append(random_seed)

if verbose:
args.append("-v")
args.append("1")

processed_args = process_arguments(args)
if verbose:
print("antsRegistration " + ' '.join(processed_args))

libfn = get_lib_fn("antsRegistration")
deformable_registration_exit_error = libfn(processed_args)

if deformable_registration_exit_error != 0:
raise RuntimeError(f"Registration failed with error code {deformable_registration_exit_error}")

all_xfrms = sorted(set(glob.glob(output_prefix + "*" + "[0-9]*")))

find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0]
find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0]

if len(find_inverse_warps) > 0:
fwdtransforms = list(reversed([ff for idx, ff in enumerate(all_xfrms) if idx != find_inverse_warps[0]]))
invtransforms = [ff for idx, ff in enumerate(all_xfrms) if idx != find_forward_warps[0]]
else:
fwdtransforms = list(reversed(all_xfrms))
invtransforms = all_xfrms

if verbose:
print("\n\nResulting transforms:")
print(" fwdtransforms: ", fwdtransforms)
print(" invtransforms: ", invtransforms)

return {
"fwdtransforms": fwdtransforms,
"invtransforms": invtransforms,
}


16 changes: 16 additions & 0 deletions tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def setUp(self):
"QuickRigid",
"DenseRigid",
"BOLDRigid",
"antsRegistrationSyNQuick[b,32,26]",
"antsRegistrationSyNQuick[s]",
"antsRegistrationSyNRepro[s]",
"antsRegistrationSyN[s]"
}

def tearDown(self):
Expand Down Expand Up @@ -451,5 +455,17 @@ def test_motion_correction(self):
fi = ants.image_read(ants.get_ants_data('ch2'))
mytx = ants.motion_correction( fi )

def test_label_image_registration(self):
fi = ants.image_read(ants.get_ants_data('r16'))
mi = ants.image_read(ants.get_ants_data('r64'))
fi = ants.resample_image(fi, (60,60), 1, 0)
mi = ants.resample_image(mi, (60,60), 1, 0)
fi_seg = ants.threshold_image(fi, "Kmeans", 3)-1
mi_seg = ants.threshold_image(mi, "Kmeans", 3)-1
mytx = ants.label_image_registration([fi_seg],
[mi_seg],
fixed_intensity_images=fi,
moving_intensity_images=mi)

if __name__ == "__main__":
run_tests()

0 comments on commit 6b220c5

Please sign in to comment.