Skip to content

Commit

Permalink
refactored thresholding and largest object extraction into separate f…
Browse files Browse the repository at this point in the history
…unctions
  • Loading branch information
niksirbi committed Oct 31, 2023
1 parent 571be9c commit a6a4cd5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 20 deletions.
36 changes: 16 additions & 20 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,22 @@
from napari.types import LayerDataTuple
from napari_plugin_engine import napari_hook_implementation

from brainglobe_template_builder.utils import (
extract_largest_object,
threshold_image,
)


@magic_factory(
call_button="generate mask",
gauss_sigma={"widget_type": "FloatSlider", "max": 10, "min": 0, "step": 1},
threshold_method={"choices": ["triangle", "otsu"]},
erosion_size={"widget_type": "Slider", "max": 10, "min": 0, "step": 1},
gauss_sigma={"widget_type": "SpinBox", "max": 20, "min": 0},
threshold_method={"choices": ["triangle", "otsu", "isodata"]},
erosion_size={"widget_type": "SpinBox", "max": 20, "min": 0},
)
def mask_widget(
image: Image,
gauss_sigma: float = 3,
threshold_method: Literal["triangle", "otsu"] = "triangle",
threshold_method: Literal["triangle", "otsu", "isodata"] = "triangle",
erosion_size: int = 5,
) -> Union[LayerDataTuple, None]:
"""Threshold image and create a mask for the largest object.
Expand All @@ -41,7 +46,9 @@ def mask_widget(
Standard deviation for Gaussian kernel (in pixels) to smooth image
before thresholding. Set to 0 to skip smoothing.
threshold_method : str
Thresholding method to use. Options are 'triangle' and 'otsu'.
Thresholding method to use. One of 'triangle', 'otsu', and 'isodata'
(corresponding to methods from the skimage.filters module).
Defaults to 'triangle'.
erosion_size : int
Size of the erosion footprint (in pixels) to apply to the mask.
Set to 0 to skip erosion.
Expand All @@ -60,7 +67,7 @@ def mask_widget(
print("Please select an image layer")
return None

from skimage import filters, measure, morphology
from skimage import filters, morphology

# Apply gaussian filter to image
if gauss_sigma > 0:
Expand All @@ -69,21 +76,10 @@ def mask_widget(
data_smoothed = image.data

# Threshold the (smoothed) image
if threshold_method == "triangle":
thresholded = filters.threshold_triangle(data_smoothed)
elif threshold_method == "otsu":
thresholded = filters.threshold_otsu(data_smoothed)
else:
raise ValueError(f"Unknown thresholding method {threshold_method}")

binary = data_smoothed > thresholded
binary = threshold_image(data_smoothed, method=threshold_method)

# Keep only the largest object
labeled_image = measure.label(binary)
regions = measure.regionprops(labeled_image)
largest_region = max(regions, key=lambda region: region.area)
# Create a binary mask for the largest object
mask = labeled_image == largest_region.label
# Keep only the largest object in the binary image
mask = extract_largest_object(binary)

# Erode the mask
if erosion_size > 0:
Expand Down
56 changes: 56 additions & 0 deletions brainglobe_template_builder/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Literal, Union

import numpy as np
from skimage import filters, measure


def extract_largest_object(binary_image):
"""Keep only the largest object in a binary image.
Parameters
----------
binary_image : np.ndarray
A binary image.
Returns
-------
np.ndarray
A binary image containing only the largest object.
"""
labeled_image = measure.label(binary_image)
regions = measure.regionprops(labeled_image)
largest_region = max(regions, key=lambda region: region.area)
return labeled_image == largest_region.label


def threshold_image(
image: np.ndarray,
method: Literal["triangle", "otsu", "isodata"] = "triangle",
) -> Union[np.ndarray, None]:
"""Threshold an image using the specified method to get a binary mask.
Parameters
----------
image : np.ndarray
Image to threshold.
method : str
Thresholding method to use. One of 'triangle', 'otsu', and 'isodata'
(corresponding to methods from the skimage.filters module).
Defaults to 'triangle'.
Returns
-------
np.ndarray
A binary mask.
"""

method_to_func = {
"triangle": filters.threshold_triangle,
"otsu": filters.threshold_otsu,
"isodata": filters.threshold_isodata,
}
if method in method_to_func.keys():
thresholded = method_to_func[method](image)
return image > thresholded
else:
raise ValueError(f"Unknown thresholding method {method}")

0 comments on commit a6a4cd5

Please sign in to comment.