diff --git a/brainglobe_template_builder/find_midline.py b/brainglobe_template_builder/find_midline.py deleted file mode 100644 index 1a70d04..0000000 --- a/brainglobe_template_builder/find_midline.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -This module is for finding the midline plane of a 3D image stack. - -The image is loaded into napari, and the user is asked to annotate -at least 3 points on the midline. The midline plane is then fitted -to these points using a least squares regression. Finally, the -transformation matrix needed to rotate the image so that the midline -plane is centered is calculated and saved. -""" - -# %% -from itertools import product -from pathlib import Path - -import imio -import napari -import numpy as np -import pandas as pd -import skimage - -# %% -# load tiff image into napari -data_dir = Path( - "/Users/nsirmpilatze/Data/BlackCap/ants/template_v2/results_tiff" -) -image_path = data_dir / "template.tif" -image = imio.load_any(image_path.as_posix()) - -viewer = napari.view_image(image, name="image") - -# %% -# Segment the brain from the background via thresholding -# and add brain mask to napari viewer - -# first apply gaussian filter to image -image_smoothed = skimage.filters.gaussian(image, sigma=3) -viewer.add_image( - image_smoothed, - name="smoothed_image", - visible=False, -) - -# %% -# then apply thresholding to smoothed image -thresholded = skimage.filters.threshold_triangle(image_smoothed) -binary_smoothed = image_smoothed > thresholded -# Remove small objects -min_object_size = 500 # Define a minimum object size to keep (in pixels) -binary_smoothed = skimage.morphology.remove_small_objects( - binary_smoothed, min_size=min_object_size -) -viewer.add_image( - binary_smoothed, - name="initial_binary", - colormap="green", - blending="additive", - opacity=0.5, -) - -# %% -# Erode the binary mask to remove the edges of the brain -# and the resulting mask to napari viewer as a label layer -eroded = skimage.morphology.binary_erosion( - binary_smoothed, footprint=np.ones((5, 5, 5)) -) -viewer.add_labels( - eroded, - name="brain_mask", - opacity=0.5, -) -viewer.layers["initial_binary"].visible = False -viewer.layers["brain_mask"].selected_label = 1 - - -# %% -# Define initial set of 9 midline points - -# Find slices at 1/4, 2/4, and 3/4 of each dimension -slices = [int(dim / 4) * i for dim in image.shape for i in [1, 2, 3]] -z_slices, y_slices, x_slices = slices[0:3], slices[3:6], slices[6:9] -# Find points at the intersection the middle x slice with the y and z slices -grid_points = np.array(list(product(z_slices, y_slices, [x_slices[1]]))) -print(grid_points) - -# %% -# Add points to napari viewer -points_layer = viewer.add_points( - grid_points, - name="midline", - face_color="#ffaa00", - size=5, - opacity=0.5, - edge_color="#ff0000", - edge_width=1.5, - edge_width_is_relative=False, -) -# %% -# Go to the first z slice - -# activate point layer -viewer.layers.selection.active_layer = points_layer -# change selection mode to select -viewer.layers.selection.mode = "select" -# go to the first z slice -viewer.dims.set_point(0, z_slices[0]) - -# %% -# save points to csv -midline_points = points_layer.data -midline_points = pd.DataFrame(midline_points, columns=["z", "y", "x"]) -points_path = data_dir / "midline_points.csv" -midline_points.to_csv(points_path, index=False) diff --git a/brainglobe_template_builder/io.py b/brainglobe_template_builder/io.py index 7397b81..7a06b17 100644 --- a/brainglobe_template_builder/io.py +++ b/brainglobe_template_builder/io.py @@ -2,6 +2,7 @@ import imio import numpy as np +import pandas as pd def load_image_to_napari(tiff_path: Path): @@ -27,6 +28,25 @@ def load_image_to_napari(tiff_path: Path): return image +def save_3d_points_to_csv(points: np.ndarray, file_path: Path): + """ + Save 3D points to a csv file + """ + + if points.shape[1] != 3: + raise ValueError( + f"Points must be of shape (n, 3). Got shape {points.shape}" + ) + if file_path.suffix != ".csv": + raise ValueError( + f"File extension {file_path.suffix} is not valid. " + f"Expected file path to end in .csv" + ) + + points_df = pd.DataFrame(points, columns=["z", "y", "x"]) + points_df.to_csv(file_path, index=False) + + def save_nii(stack: np.ndarray, pix_sizes: list, dest_path: Path): """ Save 3D image stack to dest_path as a nifti image. diff --git a/brainglobe_template_builder/napari/__init__.py b/brainglobe_template_builder/napari/__init__.py index 3bd5bae..9774bc0 100644 --- a/brainglobe_template_builder/napari/__init__.py +++ b/brainglobe_template_builder/napari/__init__.py @@ -7,9 +7,13 @@ __version__ = "unknown" from brainglobe_template_builder.napari._reader import napari_get_reader -from brainglobe_template_builder.napari._widget import mask_widget +from brainglobe_template_builder.napari._widget import ( + mask_widget, + points_widget, +) __all__ = ( "napari_get_reader", "mask_widget", + "points_widget", ) diff --git a/brainglobe_template_builder/napari/_widget.py b/brainglobe_template_builder/napari/_widget.py index 1279fb7..9ebae3e 100644 --- a/brainglobe_template_builder/napari/_widget.py +++ b/brainglobe_template_builder/napari/_widget.py @@ -10,15 +10,31 @@ import numpy as np from magicgui import magic_factory -from napari.layers import Image +from magicgui.widgets import ComboBox, Container +from napari import Viewer +from napari.layers import Image, Labels, Points from napari.types import LayerDataTuple from napari_plugin_engine import napari_hook_implementation from brainglobe_template_builder.utils import ( extract_largest_object, + get_midline_points, threshold_image, ) +# 9 colors taken from ColorBrewer2.org Set3 palette +POINTS_COLOR_CYCLE = [ + "#8dd3c7", + "#ffffb3", + "#bebada", + "#fb8072", + "#80b1d3", + "#fdb462", + "#b3de69", + "#fccde5", + "#d9d9d9", +] + @magic_factory( call_button="generate mask", @@ -55,10 +71,8 @@ def mask_widget( Returns ------- - layers : list[LayerDataTuple] - A list of napari layers to add to the viewer. - The first layer is the mask, and the second layer is the smoothed - image (if smoothing was applied). + napari.types.LayerDataTuple + A napari Labels layer containing the mask. """ if image is not None: @@ -91,6 +105,90 @@ def mask_widget( return (mask, {"name": f"Mask_{image.name}", "opacity": 0.5}, "labels") +def create_point_label_menu( + points_layer: Points, point_labels: list[str] +) -> Container: + """Create a point label menu widget for a napari points layer. + + Parameters: + ----------- + points_layer : napari.layers.Points + a napari points layer + point_labels : list[str] + a list of point labels + + Returns: + -------- + label_menu : Container + the magicgui Container with a dropdown menu widget + """ + # Create the label selection menu + label_menu = ComboBox(label="point_label", choices=point_labels) + label_widget = Container(widgets=[label_menu]) + + def update_label_menu(event): + """Update the label menu when the point selection changes""" + new_label = str(points_layer.current_properties["label"][0]) + if new_label != label_menu.value: + label_menu.value = new_label + + points_layer.events.current_properties.connect(update_label_menu) + + def label_changed(event): + """Update the Points layer when the label menu selection changes""" + selected_label = event.value + current_properties = points_layer.current_properties + current_properties["label"] = np.asarray([selected_label]) + points_layer.current_properties = current_properties + + label_menu.changed.connect(label_changed) + + return label_widget + + +@magic_factory( + call_button="Estimate midline points", +) +def points_widget( + mask: Labels, +) -> LayerDataTuple: + """Create a points layer with 9 midline points. + + Parameters + ---------- + mask : Labels + A napari labels layer to use as a reference for the points. + + Returns + ------- + napari.types.LayerDataTuple + A napari Points layer containing the estimated midline points. + """ + + # Estimate 9 midline points + points = get_midline_points(mask.data) + + point_labels = np.arange(1, points.shape[0] + 1) + + point_attrs = { + "properties": {"label": point_labels}, + "edge_color": "label", + "edge_color_cycle": POINTS_COLOR_CYCLE, + "symbol": "o", + "face_color": "transparent", + "edge_width": 0.3, + "size": 8, + "ndim": mask.ndim, + "name": "midline points", + } + + return (points, point_attrs, "points") + + @napari_hook_implementation def napari_experimental_provide_dock_widget(): - return mask_widget + return [mask_widget, points_widget] + + +viewer = Viewer() +viewer.add_points() diff --git a/brainglobe_template_builder/napari/napari.yaml b/brainglobe_template_builder/napari/napari.yaml index a993f6e..db86ce4 100644 --- a/brainglobe_template_builder/napari/napari.yaml +++ b/brainglobe_template_builder/napari/napari.yaml @@ -7,7 +7,10 @@ contributions: title: Open data with brainglobe-template-builder - id: brainglobe-template-builder.make_mask_widget python_name: brainglobe_template_builder.napari._widget:mask_widget - title: Generate organ mask + title: Generate mask + - id: brainglobe-template-builder.make_points_widget + python_name: brainglobe_template_builder.napari._widget:points_widget + title: Annotate points readers: - command: brainglobe-template-builder.get_reader accepts_directories: false @@ -17,3 +20,5 @@ contributions: widgets: - command: brainglobe-template-builder.make_mask_widget display_name: Generate Mask + - command: brainglobe-template-builder.make_points_widget + display_name: Annotate Points diff --git a/brainglobe_template_builder/utils.py b/brainglobe_template_builder/utils.py index d6653fc..65ac624 100644 --- a/brainglobe_template_builder/utils.py +++ b/brainglobe_template_builder/utils.py @@ -1,3 +1,4 @@ +from itertools import product from typing import Literal, Union import numpy as np @@ -54,3 +55,46 @@ def threshold_image( return image > thresholded else: raise ValueError(f"Unknown thresholding method {method}") + + +def get_midline_points(mask: np.ndarray): + """Get a set of 9 points roughly on the x axis midline of a 3D binary mask. + + Parameters + ---------- + mask : np.ndarray + A binary mask of shape (z, y, x). + + Returns + ------- + np.ndarray + An array of shape (9, 3) containing the midline points. + """ + + # Ensure mask is 3D + if mask.ndim != 3: + raise ValueError("Mask must be 3D") + + # Ensure mask is binary + try: + mask = mask.astype(bool) + except ValueError: + raise ValueError("Mask must be binary") + + # Derive mask properties + props = measure.regionprops(measure.label(mask))[0] + # bbox in shape (3, 2): for each dim (row) the min and max (col) + bbox = np.array(props.bbox).reshape(2, 3).T + bbox_ranges = bbox[:, 1] - bbox[:, 0] + # mask centroid in shape (3,) + centroid = np.array(props.centroid) + + # Find slices at 1/4, 2/4, and 3/4 of the z and y dimensions + z_slices = [bbox_ranges[0] / 4 * i for i in [1, 2, 3]] + y_slices = [bbox_ranges[1] / 4 * i for i in [1, 2, 3]] + # Find points at the intersection the centroid's x slice + # with the above y and z slices. + # This produces a set of 9 points roughly on the midline + points = list(product(z_slices, y_slices, [centroid[2]])) + + return np.array(points)