Skip to content

Commit

Permalink
return layers instead of LayerDataTuples
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Nov 17, 2023
1 parent c24ec82 commit 5016bbf
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
Replace code below according to your needs.
"""
from typing import Literal, Union
from typing import Literal

import numpy as np
from magicgui import magic_factory
from napari.layers import Image, Labels, Points
from napari.types import LayerDataTuple
from napari_plugin_engine import napari_hook_implementation

from brainglobe_template_builder.utils import (
Expand Down Expand Up @@ -46,7 +45,7 @@ def mask_widget(
gauss_sigma: float = 3,
threshold_method: Literal["triangle", "otsu", "isodata"] = "triangle",
erosion_size: int = 5,
) -> Union[LayerDataTuple, None]:
) -> Labels:
"""Threshold image and create a mask for the largest object.
The mask is generated by applying a Gaussian filter to the image,
Expand All @@ -70,8 +69,8 @@ def mask_widget(
Returns
-------
napari.types.LayerDataTuple
A napari Labels layer containing the mask.
napari.layers.Labels
A napari labels layer containing the mask.
"""

if image is not None:
Expand Down Expand Up @@ -99,17 +98,15 @@ def mask_widget(
mask = morphology.binary_erosion(
mask, footprint=np.ones((erosion_size,) * image.ndim)
)

# return the mask as a napari Labels layer
return (mask, {"name": "mask", "opacity": 0.5}, "labels")
return Labels(mask, opacity=0.5, name="mask")


@magic_factory(
call_button="Estimate midline points",
)
def points_widget(
mask: Labels,
) -> LayerDataTuple:
) -> Points:
"""Create a points layer with 9 midline points.
Parameters
Expand All @@ -119,8 +116,8 @@ def points_widget(
Returns
-------
napari.types.LayerDataTuple
A napari Points layer containing the estimated midline points.
napari.layers.Points
A napari points layer containing the midline points.
"""

# Estimate 9 midline points
Expand All @@ -143,7 +140,7 @@ def points_widget(
# Make mask layer invisible
mask.visible = False

return (points, point_attrs, "points")
return Points(points, **point_attrs)


@magic_factory(
Expand All @@ -156,7 +153,7 @@ def transform_widget(
image: Image,
points: Points,
axis: Literal["x", "y", "z"] = "x",
) -> LayerDataTuple:
) -> Image:
"""Transform image to align points with midline of the specified axis.
It first fits a plane to the points, then rigidly transforms the image
Expand All @@ -174,8 +171,8 @@ def transform_widget(
Returns
-------
napari.types.LayerDataTuple
A napari Image layer containing the transformed image.
napari.layers.Image
A napari image layer containing the transformed image.
"""

from scipy.ndimage import affine_transform
Expand Down Expand Up @@ -224,7 +221,7 @@ def transform_widget(
transformation_matrix[:3, :3],
offset=transformation_matrix[:3, 3],
)
return (transformed_image, {"name": "transformed image"}, "image")
return Image(transformed_image, name="aligned image")


@napari_hook_implementation
Expand Down

0 comments on commit 5016bbf

Please sign in to comment.