diff --git a/ants/registration/__init__.py b/ants/registration/__init__.py index e1ef0012..b44ef04c 100644 --- a/ants/registration/__init__.py +++ b/ants/registration/__init__.py @@ -11,5 +11,5 @@ from .integrate_velocity_field import integrate_velocity_field from .invert_displacement_field import invert_displacement_field from .landmark_transforms import fit_transform_to_paired_points, fit_time_varying_transform_to_point_sets -from .registration import registration, motion_correction -from .simulate_displacement_field import simulate_displacement_field \ No newline at end of file +from .registration import registration, motion_correction, label_image_registration +from .simulate_displacement_field import simulate_displacement_field diff --git a/ants/registration/registration.py b/ants/registration/registration.py index 9ee0fece..876997ca 100644 --- a/ants/registration/registration.py +++ b/ants/registration/registration.py @@ -2,9 +2,9 @@ ANTsPy Registration """ __all__ = ["registration", - "motion_correction"] + "motion_correction", + "label_image_registration"] -import os import numpy as np from tempfile import mktemp import glob @@ -1565,3 +1565,325 @@ def motion_correction( "motion_parameters": motion_parameters, "FD": FD, } + +def label_image_registration(fixed_label_images, + moving_label_images, + fixed_intensity_images=None, + moving_intensity_images=None, + fixed_mask=None, + moving_mask=None, + type_of_linear_transform='affine', + type_of_transform='antsRegistrationSyNQuick[so]', + label_image_weighting=1.0, + output_prefix='', + random_seed=None, + verbose=False): + + """ + Perform pairwise registration using fixed and moving sets of label + images (and, optionally, sets of corresponding intensity images). + + Arguments + --------- + fixed_label_images : single or list of ANTsImage + A single (or set of) fixed label image(s). + + moving_label_images : single or list of ANTsImage + A single (or set of) moving label image(s). + + fixed_intensity_images : single or list of ANTsImage + Optional---a single (or set of) fixed intensity image(s). + + moving_intensity_images : single or list of ANTsImage + Optional---a single (or set of) moving intensity image(s). + + fixed_mask : ANTsImage + Defines region for similarity metric calculation in the space + of the fixed image. + + moving_mask : ANTsImage + Defines region for similarity metric calculation in the space + of the moving image. + + type_of_linear_transform : string + Use label images with the centers of mass to a calculate linear + transform of type 'rigid', 'similarity', or 'affine'. + + type_of_transform : string + Only works with deformable-only transforms, specifically the family + of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms. + See 'type_of_transform' in ants.registration. + + label_image_weighting : float or list of floats + Relative weighting for the label images. + + output_prefix : string + Define the output prefix for the filenames of the output transform + files. + + verbose : boolean + Print progress to the screen. + + Returns + ------- + Set of transforms definining the mapping to/from the fixed image domain + to the moving image domain. + + Example + ------- + >>> + >>> + """ + + # Perform validation check on the input + + if isinstance(fixed_label_images, ants.ANTsImage): + fixed_label_images = [ants.image_clone(fixed_label_images)] + if isinstance(moving_label_images, ants.ANTsImage): + moving_label_images = [ants.image_clone(moving_label_images)] + + if len(fixed_label_images) != len(moving_label_images): + raise ValueError("The number of fixed and moving label images do not match.") + + if fixed_intensity_images is not None or moving_intensity_images is not None: + if isinstance(fixed_intensity_images, ants.ANTsImage): + fixed_intensity_images = [ants.image_clone(fixed_intensity_images)] + if isinstance(moving_intensity_images, ants.ANTsImage): + moving_intensity_images = [ants.image_clone(moving_intensity_images)] + if len(fixed_intensity_images) != len(moving_intensity_images): + raise ValueError("The number of fixed and moving intensity images do not match.") + + label_image_weights = list() + if isinstance(label_image_weighting, (int, float)): + label_image_weights = [label_image_weighting] * len(fixed_label_images) + else: + label_image_weights = tuple(label_image_weighting) + if len(fixed_label_images) != len(label_image_weights): + raise ValueError("The length of label_image_weights must" + + "match the number of label image pairs.") + + image_dimension = fixed_label_images[0].dimension + + if output_prefix == "" or output_prefix is None or len(output_prefix) == 0: + output_prefix = mktemp() + + allowable_linear_transforms = ['rigid', 'similarity', 'affine'] + if not type_of_linear_transform in allowable_linear_transforms: + raise ValueError("Unrecognized linear transform.") + + do_deformable = False + if type_of_transform is not None or len(type_of_transform) > 0: + do_deformable = True + + common_label_ids = list() + total_number_of_labels = 0 + for i in range(len(fixed_label_images)): + fixed_label_geoms = ants.label_geometry_measures(fixed_label_images[i]) + fixed_label_ids = np.array(fixed_label_geoms['Label']) + moving_label_geoms = ants.label_geometry_measures(moving_label_images[i]) + moving_label_ids = np.array(moving_label_geoms['Label']) + common_label_ids.append(np.intersect1d(moving_label_ids, fixed_label_ids)) + total_number_of_labels = len(common_label_ids[i]) + if verbose: + print("Common label ids for image pair ", str(i), ": ", common_label_ids[i]) + if len(common_label_ids) == 0: + raise ValueError("No common labels for image pair " + str(i)) + + if verbose: + print("Total number of labels: " + str(total_number_of_labels)) + + ############################## + # + # Linear transform + # + ############################## + + linear_xfrm = None + if type_of_linear_transform is not None: + + if verbose: + print("\n\nComputing linear transform.\n") + + if total_number_of_labels < 3: + raise ValueError(" Number of labels must be >= 3.") + + fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) + moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) + deformable_multivariate_extras = list() + + count = 0 + for i in range(len(common_label_ids)): + for j in range(len(common_label_ids[i])): + label = common_label_ids[i][j] + if verbose: + print(" Finding center of mass for label " + str(label)) + fixed_single_label_image = ants.threshold_image(fixed_label_images[i], label, label, 1, 0) + fixed_centers_of_mass[count, :] = ants.get_center_of_mass(fixed_single_label_image) + moving_single_label_image = ants.threshold_image(moving_label_images[i], label, label, 1, 0) + moving_centers_of_mass[count, :] = ants.get_center_of_mass(moving_single_label_image) + count += 1 + if do_deformable: + deformable_multivariate_extras.append(["MSQ", fixed_single_label_image, + moving_single_label_image, label_image_weighting, 0]) + + linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass, + fixed_centers_of_mass, + transform_type=type_of_linear_transform, + verbose=verbose) + + linear_xfrm_file = output_prefix + "0GenericAffine.mat" + ants.write_transform(linear_xfrm, linear_xfrm_file) + + ############################## + # + # Deformable transform + # + ############################## + + if do_deformable: + + if verbose: + print("\n\nComputing deformable transform using images.\n") + + do_quick = False + do_repro = False + + if "Quick" in type_of_transform: + do_quick = True + elif "Repro" in type_of_transform: + do_repro = True + random_seed = str(1) + + intensity_metric_parameter = None + spline_distance = 26 + if "[" in type_of_transform and "]" in type_of_transform: + subtype_of_transform = type_of_transform.split("[")[1].split("]")[0] + if not ('bo' in subtype_of_transform or 'so' in subtype_of_transform): + raise ValueError("See only 'so' or 'bo' transforms are available.") + if "," in subtype_of_transform: + subtype_of_transform_args = subtype_of_transform.split(",") + subtype_of_transform = subtype_of_transform_args[0] + intensity_metric_parameter = subtype_of_transform_args[1] + if len(subtype_of_transform_args) > 2: + spline_distance = subtype_of_transform_args[2] + + syn_stage = list() + + intensity_metric = None + if fixed_intensity_images is not None and len(fixed_intensity_images) > 0: + if do_quick: + intensity_metric = "MI" + if intensity_metric_parameter is None: + intensity_metric_parameter = 32 + if not do_quick or do_repro: + intensity_metric = "CC" + if intensity_metric_parameter is None: + intensity_metric_parameter = 2 + for i in range(1, len(fixed_intensity_images)): + syn_stage.append("--metric") + metric_string = "%s[%s,%s,%s,%s]" % ( + intensity_metric, + get_pointer_string(fixed_intensity_images[i]), + get_pointer_string(moving_intensity_images[i]), + 1.0, intensity_metric_parameter) + syn_stage.append(metric_string) + + for kk in range(len(deformable_multivariate_extras)): + syn_stage.append("--metric") + metricString = "%s[%s,%s,%s,%s]" % ( + "MSQ", + get_pointer_string(deformable_multivariate_extras[kk][1]), + get_pointer_string(deformable_multivariate_extras[kk][2]), + 1.0, 0.0) + syn_stage.append(metricString) + + syn_shrink_factors = "8x4x2x1" + syn_smoothing_sigmas = "3x2x1x0vox" + + if do_quick: + syn_convergence = "[100x70x50x0,1e-6,10]" + else: + syn_convergence = "[100x70x50x20,1e-6,10]" + + syn_stage.append("--convergence") + syn_stage.append(syn_convergence) + syn_stage.append("--shrink-factors") + syn_stage.append(syn_shrink_factors) + syn_stage.append("--smoothing-sigmas") + syn_stage.append(syn_smoothing_sigmas) + + if 'b' in subtype_of_transform: + syn_stage.insert(0, "BSplineSyN[0.1," + str(spline_distance) + ",0,3]") + else: + syn_stage.insert(0, "SyN[0.1,3,0]") + syn_stage.insert(0, "--transform") + + args = ["-d", str(image_dimension), + "-r", linear_xfrm_file, + "-o", output_prefix] + args.append(syn_stage) + + fixed_mask_string = 'NA' + if fixed_mask is not None: + fixed_mask_binary = fixed_mask != 0 + fixed_mask_string = get_pointer_string(fixed_mask_binary) + + moving_mask_string = 'NA' + if moving_mask is not None: + moving_mask_binary = moving_mask != 0 + moving_mask_string = get_pointer_string(moving_mask_binary) + + mask_option = "[%s,%s]" % (fixed_mask_string, moving_mask_string) + + args.append("-x") + args.append(mask_option) + + args = list(itertools.chain.from_iterable( + itertools.repeat(x, 1) + if isinstance(x, str) + else x for x in args)) + + args.append("--float") + args.append("1") + + if random_seed is not None: + args.append("--random-seed") + args.append(random_seed) + + if verbose: + args.append("-v") + args.append("1") + + processed_args = process_arguments(args) + if verbose: + print("antsRegistration " + ' '.join(processed_args)) + + libfn = get_lib_fn("antsRegistration") + deformable_registration_exit_error = libfn(processed_args) + + if deformable_registration_exit_error != 0: + raise RuntimeError(f"Registration failed with error code {deformable_registration_exit_error}") + + all_xfrms = sorted(set(glob.glob(output_prefix + "*" + "[0-9]*"))) + + find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0] + find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0] + + if len(find_inverse_warps) > 0: + fwdtransforms = list(reversed([ff for idx, ff in enumerate(all_xfrms) if idx != find_inverse_warps[0]])) + invtransforms = [ff for idx, ff in enumerate(all_xfrms) if idx != find_forward_warps[0]] + else: + fwdtransforms = list(reversed(all_xfrms)) + invtransforms = all_xfrms + + if verbose: + print("\n\nResulting transforms:") + print(" fwdtransforms: ", fwdtransforms) + print(" invtransforms: ", invtransforms) + + return { + "fwdtransforms": fwdtransforms, + "invtransforms": invtransforms, + } + + diff --git a/tests/test_registration.py b/tests/test_registration.py index bc50a9cb..e222a3ab 100644 --- a/tests/test_registration.py +++ b/tests/test_registration.py @@ -161,6 +161,10 @@ def setUp(self): "QuickRigid", "DenseRigid", "BOLDRigid", + "antsRegistrationSyNQuick[b,32,26]", + "antsRegistrationSyNQuick[s]", + "antsRegistrationSyNRepro[s]", + "antsRegistrationSyN[s]" } def tearDown(self): @@ -451,5 +455,17 @@ def test_motion_correction(self): fi = ants.image_read(ants.get_ants_data('ch2')) mytx = ants.motion_correction( fi ) + def test_label_image_registration(self): + fi = ants.image_read(ants.get_ants_data('r16')) + mi = ants.image_read(ants.get_ants_data('r64')) + fi = ants.resample_image(fi, (60,60), 1, 0) + mi = ants.resample_image(mi, (60,60), 1, 0) + fi_seg = ants.threshold_image(fi, "Kmeans", 3)-1 + mi_seg = ants.threshold_image(mi, "Kmeans", 3)-1 + mytx = ants.label_image_registration([fi_seg], + [mi_seg], + fixed_intensity_images=fi, + moving_intensity_images=mi) + if __name__ == "__main__": run_tests()