Skip to content

Commit

Permalink
ENH: Expose number of int. steps.
Browse files Browse the repository at this point in the history
  • Loading branch information
ntustison committed Jun 17, 2023
1 parent 447666b commit ad103e3
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions ants/registration/landmark_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def fit_transform_to_paired_points(moving_points,
sigma=0.0,
convergence_threshold=0.0,
number_of_integration_points=2,
number_of_integration_steps=100,
rasterize_points=False,
verbose=False
):
Expand Down Expand Up @@ -106,6 +107,9 @@ def fit_transform_to_paired_points(moving_points,
number_of_integration_points : integer
Time-varying velocity field parameter.
number_of_integration_steps : scalar
Number of steps used for integrating the velocity field.
rasterize_points : boolean
Use nearest neighbor rasterization of points for estimating the update
field (potential speed-up). Default = False.
Expand Down Expand Up @@ -455,15 +459,15 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
t = n / (number_of_integration_points - 1.0)

if n > 0:
integrated_forward_field = integrate_velocity_field(velocity_field, 0.0, t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, 0.0, t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))
else:
updated_fixed_points[:] = fixed_points

if n < number_of_integration_points - 1:
integrated_inverse_field = integrate_velocity_field(velocity_field, 1.0, t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, 1.0, t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(moving_points[j,:]))
Expand Down Expand Up @@ -515,8 +519,8 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if not convergence_value is None and convergence_value < convergence_threshold:
break

forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, 100))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, 100))
forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))

if verbose:
end_total_time = time.time()
Expand All @@ -543,6 +547,7 @@ def fit_time_varying_transform_to_point_sets(point_sets,
displacement_weights=None,
number_of_compositions=10,
composition_step_size=0.5,
number_of_integration_steps=100,
sigma=0.0,
convergence_threshold=0.0,
rasterize_points=False,
Expand Down Expand Up @@ -596,6 +601,9 @@ def fit_time_varying_transform_to_point_sets(point_sets,
composition_step_size : scalar
Scalar multiplication factor of the weighting of the update field.
number_of_integration_steps : scalar
Number of steps used for integrating the velocity field.
sigma : scalar
Gaussian smoothing standard deviation of the update field (in mm).
Expand Down Expand Up @@ -701,7 +709,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):

if n > 0 and n < number_of_integration_points - 1 and time_points[t_index-1] == t:
updated_fixed_points[:] = point_sets[t_index-1]
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
Expand All @@ -722,7 +730,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
)

updated_moving_points[:] = point_sets[t_index-1]
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-2], t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-2], t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-2][j,:]))
Expand All @@ -749,15 +757,15 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if t == 0.0 and time_points[t_index-1] == 0.0:
updated_fixed_points[:] = point_sets[0]
else:
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-1], t, 100)
integrated_forward_field = integrate_velocity_field(velocity_field, time_points[t_index-1], t, number_of_integration_steps)
integrated_forward_field_xfrm = txio.transform_from_displacement_field(integrated_forward_field)
for j in range(updated_fixed_points.shape[0]):
updated_fixed_points[j,:] = integrated_forward_field_xfrm.apply_to_point(tuple(point_sets[t_index-1][j,:]))

if t == 1.0 and time_points[t_index] == 1.0:
updated_moving_points[:] = point_sets[-1]
else:
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, 100)
integrated_inverse_field = integrate_velocity_field(velocity_field, time_points[t_index], t, number_of_integration_steps)
integrated_inverse_field_xfrm = txio.transform_from_displacement_field(integrated_inverse_field)
for j in range(updated_moving_points.shape[0]):
updated_moving_points[j,:] = integrated_inverse_field_xfrm.apply_to_point(tuple(point_sets[t_index][j,:]))
Expand Down Expand Up @@ -809,8 +817,8 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
if not convergence_value is None and convergence_value < convergence_threshold:
break

forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, 100))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, 100))
forward_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 0.0, 1.0, number_of_integration_steps))
inverse_xfrm = txio.transform_from_displacement_field(integrate_velocity_field(velocity_field, 1.0, 0.0, number_of_integration_steps))

if verbose:
end_total_time = time.time()
Expand Down

0 comments on commit ad103e3

Please sign in to comment.