Skip to content

Commit

Permalink
also the mask is transformed during alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Jan 12, 2024
1 parent f3c1b27 commit 74d25ac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
20 changes: 15 additions & 5 deletions brainglobe_template_builder/napari/align_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,28 @@ def _on_estimate_button_click(self):
show_info("Please move all 9 estimated points exactly to the midplane")

def _on_align_button_click(self):
"""Align image and add the transformed image to the viewer."""
"""Align image and mask to midplane and add them to the viewer."""
image_name = self.select_image_dropdown.currentText()
image_data = self.viewer.layers[image_name].data
mask_name = self.select_mask_dropdown.currentText()
mask_data = self.viewer.layers[mask_name].data
points_name = self.select_points_dropdown.currentText()
points_data = self.viewer.layers[points_name].data
axis = self.select_axis_dropdown.currentText()

aligner = MidplaneAligner(
self.viewer.layers[image_name].data,
self.viewer.layers[points_name].data,
image_data,
points_data,
symmetry_axis=axis,
)
aligned_image = aligner.transform_image()
self.viewer.add_image(aligned_image, name="aligned image")
aligned_image = aligner.transform_image(image_data)
self.viewer.add_image(aligned_image, name="aligned_image")
aligned_mask = aligner.transform_image(mask_data)
self.viewer.add_labels(aligned_mask, name="aligned_mask", opacity=0.5)
# Hide original image, mask, and points layers
self.viewer.layers[image_name].visible = False
self.viewer.layers[mask_name].visible = False
self.viewer.layers[points_name].visible = False

def _on_dropdown_selection_change(self):
# Enable estimate button if mask dropdown has a value
Expand Down
20 changes: 14 additions & 6 deletions brainglobe_template_builder/preproc/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _fit_plane_to_points(self):
self.normal_vector = normal_vector
self.symmetry_axis_vector = symmetry_axis_vector

def get_transform(self):
def _compute_transform(self):
"""Find the transformation matrix that aligns the plane defined by the
points to the midplane of the image along the symmetry axis.
"""
Expand All @@ -179,11 +179,19 @@ def get_transform(self):
self.transform = (
translation_to_mid_axis @ rotation @ translation_to_origin
)
return self.transform

def transform_image(self):
"""Transform the image using the transformation matrix."""
def transform_image(self, image: np.ndarray = None):
"""Transform the image using the transformation matrix.
Parameters
----------
image : np.ndarray
The image to transform. If None, the image passed to the
constructor is used.
"""
if not hasattr(self, "transform"):
self.get_transform()
self.transformed_image = apply_transform(self.image, self.transform)
self._compute_transform()
if image is None:
image = self.image
self.transformed_image = apply_transform(image, self.transform)
return self.transformed_image

0 comments on commit 74d25ac

Please sign in to comment.