Skip to content

Commit

Permalink
fixed midline alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Jan 11, 2024
1 parent 83546a7 commit c73853b
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 61 deletions.
11 changes: 7 additions & 4 deletions brainglobe_template_builder/napari/midline_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
)

from brainglobe_template_builder.preproc import (
align_to_midline,
apply_transform,
get_alignment_transform,
get_midline_points,
)

Expand Down Expand Up @@ -157,18 +158,20 @@ def _on_estimate_button_click(self):

def _on_align_button_click(self):
"""Align image and add the transformed image to the viewer."""
# Get values from dropdowns
image_name = self.select_image_dropdown.currentText()
points_name = self.select_points_dropdown.currentText()
axis = self.select_axis_dropdown.currentText()

# Call align_to_midline function
aligned_image = align_to_midline(
transform = get_alignment_transform(
self.viewer.layers[image_name].data,
self.viewer.layers[points_name].data,
axis=axis,
)

aligned_image = apply_transform(
self.viewer.layers[image_name].data, transform
)

self.viewer.add_image(aligned_image, name="aligned image")

def _on_dropdown_selection_change(self):
Expand Down
162 changes: 105 additions & 57 deletions brainglobe_template_builder/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
from scipy.ndimage import affine_transform
from scipy.spatial.transform import Rotation
from skimage import filters, measure, morphology


Expand Down Expand Up @@ -156,43 +155,67 @@ def get_midline_points(mask: np.ndarray):

def _fit_plane_to_points(
points: np.ndarray,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray]:
"""Fit a plane to a set of 3D points.
Parameters
----------
points : np.ndarray
An array of shape (n, 3) containing the points.
An array of shape (n_points, 3) containing the points.
Returns
-------
np.ndarray
The normal vector to the plane, with shape (3,).
centroid : np.ndarray
The centroid of the points.
normal_vector : np.ndarray
A vector normal to the fitted plane.
"""

# Ensure points are 3D
if points.shape[1] != 3:
raise ValueError("Points array must have 3 columns (z, y, x)")

centered_points = points - np.mean(points, axis=0)

# Find the centroid of the points
centroid = np.mean(points, axis=0)
# Use SVD to get the normal vector to the plane
_, _, vh = np.linalg.svd(centered_points)
_, _, vh = np.linalg.svd(points - centroid)
normal_vector = vh[-1]

return normal_vector
return centroid, normal_vector


def _rotation_matrix_from_vectors(vec1: np.ndarray, vec2: np.ndarray):
"""Find the rotation matrix that aligns vec1 to vec2. Implementation
adapted from StackOverflow [1]_.
Parameters
----------
vec1 : np.ndarray
The 3D "source" vector
vec2 : np.ndarray
The 3D "target" vector
Returns
-------
A rotation matrix (3x3) that, when applied to vec1, aligns it with vec2.
References
----------
.. [1] https://stackoverflow.com/questions/45142959
"""
a = (vec1 / np.linalg.norm(vec1)).reshape(3)
b = (vec2 / np.linalg.norm(vec2)).reshape(3)
v = np.cross(a, b)
c = np.dot(a, b)
s = np.linalg.norm(v)
kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
return rotation_matrix

def align_to_midline(

def get_alignment_transform(
image: np.ndarray,
points: np.ndarray,
axis: Literal["x", "y", "z"] = "x",
) -> np.ndarray:
"""Transform image such that the midline of the specified axis is aligned
with the plane fitted to the provided points.
This function first fits a plane to the points, then rigidly transforms
the image such that the fitted plane is aligned with the axis midline.
"""Find the transformation matrix that aligns the plane defined by the
given points to the midline of the specified axis.
Parameters
----------
Expand All @@ -202,12 +225,13 @@ def align_to_midline(
An array of shape (n_points, 3) containing points.
axis : str
Axis to align the midline with. One of 'x', 'y', and 'z'.
Defaults to 'x'.
Defaults to 'x'. The axis order is zyx in napari.
Returns
-------
aligned_image : np.ndarray
A 3D array containing the transformed image.
transform: np.ndarray
A 4x4 rigid transformation matrix (3x3 rotation matrix with a 3x1
translation vector appended to the right).
"""

# Check input
Expand All @@ -219,44 +243,68 @@ def align_to_midline(
raise ValueError("Axis must be one of 'x', 'y', or 'z'")

# Fit a plane to the points
normal_vector = _fit_plane_to_points(points)
centroid, normal_vector = _fit_plane_to_points(points)
# invert the normal vector if it points in the opposite direction of the
# specified axis
axis_index = {"z": 0, "y": 1, "x": 2}[axis] # axis order is zyx in napari
axis_vector = np.zeros(3)
axis_vector[axis_index] = 1
if np.dot(normal_vector, axis_vector) < 0:
normal_vector = -normal_vector

# Find rotation to align the fitted plane (i.e. its normal vector)
# with the specified axis (i.e. the axis unit vector)
rotation_matrix = _rotation_matrix_from_vectors(normal_vector, axis_vector)
# Find offset to bring the centroid to the middle of the specified axis
mid_axis = image.shape[axis_index] // 2
offset = mid_axis - centroid[axis_index]

# Construct the transformation matrix by combining rotation and offset
transform = np.zeros((4, 4))
transform[:3, :3] = rotation_matrix
transform[:3, 3] = offset * axis_vector
return transform


def apply_transform(
data: np.ndarray,
transform: np.ndarray,
) -> np.ndarray:
"""Apply a rigid transformation to an image.
# Compute centroid of the midline points
centroid = np.mean(points, axis=0)
Parameters
----------
data : np.ndarray
A 3D image to transform.
transform : np.ndarray
A 4x4 transformation matrix.
# Translation of the centroid to the origin
translation_to_origin = np.eye(4)
translation_to_origin[:3, 3] = -centroid
Returns
-------
np.ndarray
The transformed data.
Notes
-----
This function inverts the affine and flips the offset when passing the data
to `scipy.ndimage.affine_transform`. This is because the transforms are
given in the 'push' (or 'forward') direction, transforming input to output,
whereas `scipy.ndimage.affine_transform` does `pull` (or `backward`)
resampling, transforming the output space to the input.
"""

# Rotation to align normal vector with unit vector along the specified axis
axis_vec = np.zeros(3)
axis_index = {"z": 0, "y": 1, "x": 2}[axis] # axis order is zyx in napari
axis_vec[axis_index] = 1
rotation_to_axis = Rotation.align_vectors(
axis_vec.reshape(1, 3),
normal_vector.reshape(1, 3),
)[0].as_matrix()
rotation_4x4 = np.eye(4)
rotation_4x4[:3, :3] = rotation_to_axis

# Translation back, so that the plane is in the middle of axis
translation_to_mid_axis = np.eye(4)
translation_to_mid_axis[axis_index, 3] = (
image.data.shape[axis_index] // 2 - centroid[axis_index]
)
if data.ndim != 3:
raise ValueError("Data must be 3D")
if transform.shape != (4, 4):
raise ValueError("Transform must be a 4x4 matrix")

# Combine transformations into a single 4x4 matrix
transformation_matrix = (
np.linalg.inv(translation_to_origin)
@ rotation_4x4
@ translation_to_origin
@ translation_to_mid_axis
)
# use larger output shape (to avoid cropping of edges)
output_shape = [int(1.1 * s) for s in data.shape]

# Apply the transformation to the image
aligned_image = affine_transform(
image,
transformation_matrix[:3, :3],
offset=transformation_matrix[:3, 3],
transformed = affine_transform(
data,
np.linalg.inv(transform[:3, :3]),
offset=-transform[:3, 3],
output_shape=output_shape,
)
return aligned_image
return transformed

0 comments on commit c73853b

Please sign in to comment.