diff --git a/brainglobe_template_builder/napari/align_widget.py b/brainglobe_template_builder/napari/align_widget.py index b58fd48..3d367fd 100644 --- a/brainglobe_template_builder/napari/align_widget.py +++ b/brainglobe_template_builder/napari/align_widget.py @@ -95,6 +95,14 @@ def _create_align_group(self): self.align_image_button.clicked.connect(self._on_align_button_click) self.align_groupbox.layout().addRow(self.align_image_button) + # Add button to split image down the midplane + self.split_image_button = QPushButton( + "Split and refect image", parent=self.align_groupbox + ) + self.split_image_button.setEnabled(False) + self.split_image_button.clicked.connect(self._on_split_button_click) + self.align_groupbox.layout().addRow(self.split_image_button) + def _get_layers_by_type(self, layer_type: Layer) -> list: """Return a list of napari layers of a given type.""" return [ @@ -164,10 +172,32 @@ def _on_align_button_click(self): self.viewer.add_image(aligned_image, name="aligned_image") aligned_mask = aligner.transform_image(mask_data) self.viewer.add_labels(aligned_mask, name="aligned_mask", opacity=0.5) + split_mask = aligner.split_mask(aligned_mask) + self.viewer.add_labels(split_mask, name="split_mask", opacity=0.5) # Hide original image, mask, and points layers self.viewer.layers[image_name].visible = False self.viewer.layers[mask_name].visible = False self.viewer.layers[points_name].visible = False + # Enable the split button + self.split_image_button.setEnabled(True) + # Make aligner object accessible to other methods + self.aligner = aligner + + def _on_split_button_click(self): + """Split the aligned image and its mask down the midplane and reflect + each half.""" + aligned_image = self.viewer.layers["aligned_image"].data + split_mask = self.viewer.layers["split_mask"].data + hemi1_image, hemi2_image = self.aligner.split_and_reflect_image( + aligned_image + ) + hemi1_mask, hemi2_mask = self.aligner.split_and_reflect_image( + split_mask + ) + self.viewer.add_image(hemi1_image, name="hemi1_sym_image") + self.viewer.add_labels(hemi1_mask, name="hemi1_sym_mask", opacity=0.5) + self.viewer.add_image(hemi2_image, name="hemi2_sym_image") + self.viewer.add_labels(hemi2_mask, name="hemi2_sym_mask", opacity=0.5) def _on_dropdown_selection_change(self): # Enable estimate button if mask dropdown has a value diff --git a/brainglobe_template_builder/preproc/alignment.py b/brainglobe_template_builder/preproc/alignment.py index 8adcc12..b8d3cc7 100644 --- a/brainglobe_template_builder/preproc/alignment.py +++ b/brainglobe_template_builder/preproc/alignment.py @@ -195,3 +195,60 @@ def transform_image(self, image: np.ndarray = None): image = self.image self.transformed_image = apply_transform(image, self.transform) return self.transformed_image + + def split_mask(self, mask: np.ndarray) -> np.ndarray: + """Label each half of the mask along the symmetry axis with different + integer values, to help diagnose issues with the splitting process. + + Parameters + ---------- + mask : np.ndarray + The mask to split. Must contain only 1 label + + Returns + ------- + np.ndarray + An array of the same shape as the input mask, with each half + labelled with a different integer value. + """ + # Ensure mask is binary + mask = mask.astype(bool) + axi = self.symmetry_axis_idx + axis_len = mask.shape[axi] + half_len = axis_len // 2 + split_mask = np.zeros_like(mask, dtype=int) + + # Create slicing objects for each half + slicer_half1 = [slice(None)] * mask.ndim + slicer_half1[axi] = slice(0, half_len) + slicer_half2 = [slice(None)] * mask.ndim + slicer_half2[axi] = slice(half_len, axis_len) + + # Apply new labels only to the regions labeled in the input mask + split_mask[tuple(slicer_half1)] = mask[tuple(slicer_half1)] * 2 + split_mask[tuple(slicer_half2)] = mask[tuple(slicer_half2)] * 3 + + return split_mask + + def split_and_reflect_image( + self, image: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: + """Split the transformed image into two halves along the symmetry + axis and reflect each half to produce two full images. + + Parameters + ---------- + image : np.ndarray + The image to split and reflect. + """ + axi = self.symmetry_axis_idx # axis index + axis_len = image.shape[axi] + half_len = axis_len // 2 + # take first half_len slices along the symmetry axis (first hald) + hemi1 = image.take(range(half_len), axis=axi) + # take last half_len slices along the symmetry axis (second half) + hemi2 = image.take(range(half_len, axis_len), axis=axi) + # reflect each half to produce two full images + full1 = np.concatenate([hemi1, np.flip(hemi1, axis=axi)], axis=axi) + full2 = np.concatenate([np.flip(hemi2, axis=axi), hemi2], axis=axi) + return full1, full2