Skip to content

Commit

Permalink
implemented widget for estimating midline points
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 31, 2023
1 parent f9a747e commit d4f88a0
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 120 deletions.
112 changes: 0 additions & 112 deletions brainglobe_template_builder/find_midline.py

This file was deleted.

20 changes: 20 additions & 0 deletions brainglobe_template_builder/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import imio
import numpy as np
import pandas as pd


def load_image_to_napari(tiff_path: Path):
Expand All @@ -27,6 +28,25 @@ def load_image_to_napari(tiff_path: Path):
return image


def save_3d_points_to_csv(points: np.ndarray, file_path: Path):
"""
Save 3D points to a csv file
"""

if points.shape[1] != 3:
raise ValueError(
f"Points must be of shape (n, 3). Got shape {points.shape}"
)
if file_path.suffix != ".csv":
raise ValueError(
f"File extension {file_path.suffix} is not valid. "
f"Expected file path to end in .csv"
)

points_df = pd.DataFrame(points, columns=["z", "y", "x"])
points_df.to_csv(file_path, index=False)


def save_nii(stack: np.ndarray, pix_sizes: list, dest_path: Path):
"""
Save 3D image stack to dest_path as a nifti image.
Expand Down
6 changes: 5 additions & 1 deletion brainglobe_template_builder/napari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
__version__ = "unknown"

from brainglobe_template_builder.napari._reader import napari_get_reader
from brainglobe_template_builder.napari._widget import mask_widget
from brainglobe_template_builder.napari._widget import (
mask_widget,
points_widget,
)

__all__ = (
"napari_get_reader",
"mask_widget",
"points_widget",
)
110 changes: 104 additions & 6 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,31 @@

import numpy as np
from magicgui import magic_factory
from napari.layers import Image
from magicgui.widgets import ComboBox, Container
from napari import Viewer
from napari.layers import Image, Labels, Points
from napari.types import LayerDataTuple
from napari_plugin_engine import napari_hook_implementation

from brainglobe_template_builder.utils import (
extract_largest_object,
get_midline_points,
threshold_image,
)

# 9 colors taken from ColorBrewer2.org Set3 palette
POINTS_COLOR_CYCLE = [
"#8dd3c7",
"#ffffb3",
"#bebada",
"#fb8072",
"#80b1d3",
"#fdb462",
"#b3de69",
"#fccde5",
"#d9d9d9",
]


@magic_factory(
call_button="generate mask",
Expand Down Expand Up @@ -55,10 +71,8 @@ def mask_widget(
Returns
-------
layers : list[LayerDataTuple]
A list of napari layers to add to the viewer.
The first layer is the mask, and the second layer is the smoothed
image (if smoothing was applied).
napari.types.LayerDataTuple
A napari Labels layer containing the mask.
"""

if image is not None:
Expand Down Expand Up @@ -91,6 +105,90 @@ def mask_widget(
return (mask, {"name": f"Mask_{image.name}", "opacity": 0.5}, "labels")


def create_point_label_menu(
points_layer: Points, point_labels: list[str]
) -> Container:
"""Create a point label menu widget for a napari points layer.
Parameters:
-----------
points_layer : napari.layers.Points
a napari points layer
point_labels : list[str]
a list of point labels
Returns:
--------
label_menu : Container
the magicgui Container with a dropdown menu widget
"""
# Create the label selection menu
label_menu = ComboBox(label="point_label", choices=point_labels)
label_widget = Container(widgets=[label_menu])

def update_label_menu(event):
"""Update the label menu when the point selection changes"""
new_label = str(points_layer.current_properties["label"][0])
if new_label != label_menu.value:
label_menu.value = new_label

points_layer.events.current_properties.connect(update_label_menu)

def label_changed(event):
"""Update the Points layer when the label menu selection changes"""
selected_label = event.value
current_properties = points_layer.current_properties
current_properties["label"] = np.asarray([selected_label])
points_layer.current_properties = current_properties

label_menu.changed.connect(label_changed)

return label_widget


@magic_factory(
call_button="Estimate midline points",
)
def points_widget(
mask: Labels,
) -> LayerDataTuple:
"""Create a points layer with 9 midline points.
Parameters
----------
mask : Labels
A napari labels layer to use as a reference for the points.
Returns
-------
napari.types.LayerDataTuple
A napari Points layer containing the estimated midline points.
"""

# Estimate 9 midline points
points = get_midline_points(mask.data)

point_labels = np.arange(1, points.shape[0] + 1)

point_attrs = {
"properties": {"label": point_labels},
"edge_color": "label",
"edge_color_cycle": POINTS_COLOR_CYCLE,
"symbol": "o",
"face_color": "transparent",
"edge_width": 0.3,
"size": 8,
"ndim": mask.ndim,
"name": "midline points",
}

return (points, point_attrs, "points")


@napari_hook_implementation
def napari_experimental_provide_dock_widget():
return mask_widget
return [mask_widget, points_widget]


viewer = Viewer()
viewer.add_points()
7 changes: 6 additions & 1 deletion brainglobe_template_builder/napari/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ contributions:
title: Open data with brainglobe-template-builder
- id: brainglobe-template-builder.make_mask_widget
python_name: brainglobe_template_builder.napari._widget:mask_widget
title: Generate organ mask
title: Generate mask
- id: brainglobe-template-builder.make_points_widget
python_name: brainglobe_template_builder.napari._widget:points_widget
title: Annotate points
readers:
- command: brainglobe-template-builder.get_reader
accepts_directories: false
Expand All @@ -17,3 +20,5 @@ contributions:
widgets:
- command: brainglobe-template-builder.make_mask_widget
display_name: Generate Mask
- command: brainglobe-template-builder.make_points_widget
display_name: Annotate Points
44 changes: 44 additions & 0 deletions brainglobe_template_builder/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import product
from typing import Literal, Union

import numpy as np
Expand Down Expand Up @@ -54,3 +55,46 @@ def threshold_image(
return image > thresholded
else:
raise ValueError(f"Unknown thresholding method {method}")


def get_midline_points(mask: np.ndarray):
"""Get a set of 9 points roughly on the x axis midline of a 3D binary mask.
Parameters
----------
mask : np.ndarray
A binary mask of shape (z, y, x).
Returns
-------
np.ndarray
An array of shape (9, 3) containing the midline points.
"""

# Ensure mask is 3D
if mask.ndim != 3:
raise ValueError("Mask must be 3D")

# Ensure mask is binary
try:
mask = mask.astype(bool)
except ValueError:
raise ValueError("Mask must be binary")

# Derive mask properties
props = measure.regionprops(measure.label(mask))[0]
# bbox in shape (3, 2): for each dim (row) the min and max (col)
bbox = np.array(props.bbox).reshape(2, 3).T
bbox_ranges = bbox[:, 1] - bbox[:, 0]
# mask centroid in shape (3,)
centroid = np.array(props.centroid)

# Find slices at 1/4, 2/4, and 3/4 of the z and y dimensions
z_slices = [bbox_ranges[0] / 4 * i for i in [1, 2, 3]]
y_slices = [bbox_ranges[1] / 4 * i for i in [1, 2, 3]]
# Find points at the intersection the centroid's x slice
# with the above y and z slices.
# This produces a set of 9 points roughly on the midline
points = list(product(z_slices, y_slices, [centroid[2]]))

return np.array(points)

0 comments on commit d4f88a0

Please sign in to comment.