Skip to content

Commit

Permalink
Merge pull request #477 from ANTsX/InitVF
Browse files Browse the repository at this point in the history
ENH:  Add initial velocity field capabilities.
  • Loading branch information
ntustison authored Jun 17, 2023
2 parents 27dc5a2 + ad103e3 commit f9c799b
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions ants/registration/landmark_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def fit_transform_to_paired_points(moving_points,
sigma=0.0,
convergence_threshold=0.0,
number_of_integration_points=2,
number_of_integration_steps=100,
rasterize_points=False,
verbose=False
):
Expand Down Expand Up @@ -106,6 +107,9 @@ def fit_transform_to_paired_points(moving_points,
number_of_integration_points : integer
Time-varying velocity field parameter.
number_of_integration_steps : scalar
Number of steps used for integrating the velocity field.
rasterize_points : boolean
Use nearest neighbor rasterization of points for estimating the update
field (potential speed-up). Default = False.
Expand Down Expand Up @@ -455,15 +459,15 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
t = n / (number_of_integration_points - 1.0)

if n > 0:
integrated_forward_field = integrate_velocity_field(velocity_field, 0.0, t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, 0.0, t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))
else:
updated_fixed_points[:] = fixed_points

if n < number_of_integration_points - 1:
integrated_inverse_field = integrate_velocity_field(velocity_field, 1.0, t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, 1.0, t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(moving_points[j,:]))
Expand Down Expand Up @@ -515,8 +519,8 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if not convergence_value is None and convergence_value < convergence_threshold:
break

forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, 100))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, 100))
forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))

if verbose:
end_total_time = time.time()
Expand All @@ -534,6 +538,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):

def fit_time_varying_transform_to_point_sets(point_sets,
time_points=None,
initial_velocity_field=None,
number_of_integration_points=None,
domain_image=None,
number_of_fitting_levels=4,
Expand All @@ -542,6 +547,7 @@ def fit_time_varying_transform_to_point_sets(point_sets,
displacement_weights=None,
number_of_compositions=10,
composition_step_size=0.5,
number_of_integration_steps=100,
sigma=0.0,
convergence_threshold=0.0,
rasterize_points=False,
Expand All @@ -563,6 +569,10 @@ def fit_time_varying_transform_to_point_sets(point_sets,
Set of scalar values, one for each point-set, designating its time position in the velocity
flow. If not set, it defaults to equal spacing between 0 and 1.
initial_velocity_field : initial ANTs velocity field
Optional velocity field for initializing optimization. Overrides the number of integration
points.
number_of_integration_points : integer
Time-varying velocity field parameter. Needs to be equal to or greater than the number of
point sets. If not specified, it defaults to the number of point sets.
Expand Down Expand Up @@ -591,6 +601,9 @@ def fit_time_varying_transform_to_point_sets(point_sets,
composition_step_size : scalar
Scalar multiplication factor of the weighting of the update field.
number_of_integration_steps : scalar
Number of steps used for integrating the velocity field.
sigma : scalar
Gaussian smoothing standard deviation of the update field (in mm).
Expand Down Expand Up @@ -640,12 +653,6 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if np.any(time_points < 0.0) or np.any(time_points > 1.0):
raise ValueError("time point values should be between 0 and 1.")

if number_of_integration_points is None:
number_of_integration_points = len(time_points)

if number_of_integration_points < number_of_point_sets:
raise ValueError("The number of integration points should be at least as great as the number of point sets.")

if number_of_point_sets < 3:
raise ValueError("Expecting three or greater point sets.")

Expand All @@ -666,7 +673,16 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
updated_fixed_points = np.zeros(point_sets[0].shape)
updated_moving_points = np.zeros(point_sets[0].shape)

velocity_field = create_zero_velocity_field(domain_image, number_of_integration_points)
velocity_field = None
if initial_velocity_field is None:
velocity_field = create_zero_velocity_field(domain_image, number_of_integration_points)
if number_of_integration_points is None:
number_of_integration_points = len(time_points)
if number_of_integration_points < number_of_point_sets:
raise ValueError("The number of integration points should be at least as great as the number of point sets.")
else:
velocity_field = iio2.image_clone(initial_velocity_field)
number_of_integration_points = initial_velocity_field.shape[-1]
velocity_field_array = velocity_field.numpy()

last_update_derivative_field = create_zero_velocity_field(domain_image, number_of_integration_points)
Expand All @@ -693,7 +709,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):

if n > 0 and n < number_of_integration_points - 1 and time_points[t_index-1] == t:
updated_fixed_points[:] = point_sets[t_index-1]
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
Expand All @@ -714,7 +730,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
)

updated_moving_points[:] = point_sets[t_index-1]
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-2], t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-2], t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-2][j,:]))
Expand All @@ -741,15 +757,15 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if t == 0.0 and time_points[t_index-1] == 0.0:
updated_fixed_points[:] = point_sets[0]
else:
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-1], t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-1], t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-1][j,:]))

if t == 1.0 and time_points[t_index] == 1.0:
updated_moving_points[:] = point_sets[-1]
else:
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
Expand Down Expand Up @@ -801,8 +817,8 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if not convergence_value is None and convergence_value < convergence_threshold:
break

forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, 100))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, 100))
forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))

if verbose:
end_total_time = time.time()
Expand Down

0 comments on commit f9c799b

Please sign in to comment.