Skip to content

Commit

Permalink
started drafting midline alignment widget
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 31, 2023
1 parent 4c4779d commit 2a2e685
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 2 deletions.
1 change: 1 addition & 0 deletions brainglobe_template_builder/napari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brainglobe_template_builder.napari._widget import (
mask_widget,
points_widget,
align_widget,
)

__all__ = (
Expand Down
65 changes: 63 additions & 2 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

import numpy as np
from magicgui import magic_factory
from napari.layers import Image, Labels
from napari.layers import Image, Labels, Points
from napari.types import LayerDataTuple
from napari_plugin_engine import napari_hook_implementation

from brainglobe_template_builder.utils import (
align_vectors,
extract_largest_object,
fit_plane_to_points,
get_midline_points,
threshold_image,
)
Expand Down Expand Up @@ -100,7 +102,7 @@ def mask_widget(
)

# return the mask as a napari Labels layer
return (mask, {"name": f"Mask_{image.name}", "opacity": 0.5}, "labels")
return (mask, {"name": "mask", "opacity": 0.5}, "labels")


@magic_factory(
Expand Down Expand Up @@ -145,6 +147,65 @@ def points_widget(
return (points, point_attrs, "points")


@magic_factory(
call_button="Align midline",
image={"label": "Image"},
points={"label": "Midline points"},
)
def align_widget(image: Image, points: Points) -> LayerDataTuple:
"""Align image to midline points.
Parameters
----------
image : Image
A napari image layer to align.
points : Points
A napari points layer containing the midline points.
Returns
-------
napari.types.LayerDataTuple
A napari Image layer containing the aligned image.
"""

from scipy.ndimage import affine_transform

points_data = points.data
centroid = np.mean(points_data, axis=0)
normal_vector = fit_plane_to_points(points_data)

# 1. Translate so centroid is at origin
translate_to_origin = np.eye(4)
translate_to_origin[:3, 3] = -centroid

# 2. Rotate so normal vector aligns with x-axis
x_axis = np.array([0, 0, 1])
rotation_matrix = align_vectors(normal_vector, x_axis)
rotation = np.eye(4)
rotation[:3, :3] = rotation_matrix

# 3. Translate so plane is at middle slice along x-axis
translate_to_mid_x = np.eye(4)
mid_x = image.data.shape[2] // 2
translate_to_mid_x[2, 3] = mid_x - centroid[2]

# Combine transformations
transform = (
np.linalg.inv(translate_to_origin)
@ rotation
@ translate_to_origin
@ translate_to_mid_x
)
affine_matrix, offset = transform[:3, :3], transform[:3, 3]

# Apply the transformation to the image
transformed_image = affine_transform(
image.data, affine_matrix, offset=offset
)

return (transformed_image, {"name": "aligned image"}, "image")


@napari_hook_implementation
def napari_experimental_provide_dock_widget():
return [mask_widget, points_widget]
5 changes: 5 additions & 0 deletions brainglobe_template_builder/napari/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ contributions:
- id: brainglobe-template-builder.make_points_widget
python_name: brainglobe_template_builder.napari._widget:points_widget
title: Annotate points
- id: brainglobe-template-builder.make_align_widget
python_name: brainglobe_template_builder.napari._widget:align_widget
title: Align midline
readers:
- command: brainglobe-template-builder.get_reader
accepts_directories: false
Expand All @@ -22,3 +25,5 @@ contributions:
display_name: Generate Mask
- command: brainglobe-template-builder.make_points_widget
display_name: Annotate Points
- command: brainglobe-template-builder.make_align_widget
display_name: Align Midline
55 changes: 55 additions & 0 deletions brainglobe_template_builder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,58 @@ def get_midline_points(mask: np.ndarray):
points = list(product(z_slices, y_slices, [centroid[2]]))

return np.array(points)


def fit_plane_to_points(
points: np.ndarray,
) -> tuple[float, float, float, float]:
"""Fit a plane to a set of 3D points.
Parameters
----------
points : np.ndarray
An array of shape (n, 3) containing the points.
Returns
-------
np.ndarray
The normal vector to the plane, with shape (3,).
"""

# Ensure points are 3D
if points.shape[1] != 3:
raise ValueError("Points array must have 3 columns (z, y, x)")

centered_points = points - np.mean(points, axis=0)

# Use SVD to get the normal vector to the plane
_, _, vh = np.linalg.svd(centered_points)
normal_vector = vh[-1]

return normal_vector


def align_vectors(v1, v2):
"""Align two vectors using Rodrigues' rotation formula.
Parameters
----------
v1 : np.ndarray
The first vector.
v2 : np.ndarray
The second vector.
"""
v1 = v1 / np.linalg.norm(v1)
v2 = v2 / np.linalg.norm(v2)
cross_prod = np.cross(v1, v2)
dot_prod = np.dot(v1, v2)
s = np.linalg.norm(cross_prod)
K = np.array(
[
[0, -cross_prod[2], cross_prod[1]],
[cross_prod[2], 0, -cross_prod[0]],
[-cross_prod[1], cross_prod[0], 0],
]
)
rotation = np.eye(3) + K + K @ K * ((1 - dot_prod) / (s**2))
return rotation

0 comments on commit 2a2e685

Please sign in to comment.