diff --git a/brainglobe_template_builder/napari/midline_widget.py b/brainglobe_template_builder/napari/midline_widget.py index 81b3754..b5aa1ef 100644 --- a/brainglobe_template_builder/napari/midline_widget.py +++ b/brainglobe_template_builder/napari/midline_widget.py @@ -8,8 +8,10 @@ QWidget, ) -from brainglobe_template_builder.preproc import align_to_midline -from brainglobe_template_builder.utils import get_midline_points +from brainglobe_template_builder.preproc import ( + align_to_midline, + get_midline_points, +) class FindMidline(QWidget): diff --git a/brainglobe_template_builder/preproc.py b/brainglobe_template_builder/preproc.py index d86ccd3..92afed0 100644 --- a/brainglobe_template_builder/preproc.py +++ b/brainglobe_template_builder/preproc.py @@ -1,15 +1,62 @@ -from typing import Literal +from itertools import product +from typing import Literal, Union import numpy as np from scipy.ndimage import affine_transform from scipy.spatial.transform import Rotation -from skimage import filters, morphology +from skimage import filters, measure, morphology -from brainglobe_template_builder.utils import ( - extract_largest_object, - fit_plane_to_points, - threshold_image, -) + +def _extract_largest_object(binary_image): + """Keep only the largest object in a binary image. + + Parameters + ---------- + binary_image : np.ndarray + A binary image. + + Returns + ------- + np.ndarray + A binary image containing only the largest object. + """ + labeled_image = measure.label(binary_image) + regions = measure.regionprops(labeled_image) + largest_region = max(regions, key=lambda region: region.area) + return labeled_image == largest_region.label + + +def _threshold_image( + image: np.ndarray, + method: Literal["triangle", "otsu", "isodata"] = "triangle", +) -> Union[np.ndarray, None]: + """Threshold an image using the specified method to get a binary mask. + + Parameters + ---------- + image : np.ndarray + Image to threshold. + method : str + Thresholding method to use. One of 'triangle', 'otsu', and 'isodata' + (corresponding to methods from the skimage.filters module). + Defaults to 'triangle'. + + Returns + ------- + np.ndarray + A binary mask. + """ + + method_to_func = { + "triangle": filters.threshold_triangle, + "otsu": filters.threshold_otsu, + "isodata": filters.threshold_isodata, + } + if method in method_to_func.keys(): + thresholded = method_to_func[method](image) + return image > thresholded + else: + raise ValueError(f"Unknown thresholding method {method}") def create_mask( @@ -49,19 +96,14 @@ def create_mask( if image.ndim != 3: raise ValueError("Image must be 3D") - # Apply gaussian filter to image if gauss_sigma > 0: data_smoothed = filters.gaussian(image, sigma=gauss_sigma) else: data_smoothed = image - # Threshold the (smoothed) image - binary = threshold_image(data_smoothed, method=threshold_method) + binary = _threshold_image(data_smoothed, method=threshold_method) + mask = _extract_largest_object(binary) - # Keep only the largest object in the binary image - mask = extract_largest_object(binary) - - # Erode the mask if erosion_size > 0: mask = morphology.binary_erosion( mask, footprint=np.ones((erosion_size,) * image.ndim) @@ -69,6 +111,77 @@ def create_mask( return mask +def get_midline_points(mask: np.ndarray): + """Get a set of 9 points roughly on the x axis midline of a 3D binary mask. + + Parameters + ---------- + mask : np.ndarray + A binary mask of shape (z, y, x). + + Returns + ------- + np.ndarray + An array of shape (9, 3) containing the midline points. + """ + + # Check input + if mask.ndim != 3: + raise ValueError("Mask must be 3D") + + try: + mask = mask.astype(bool) + except ValueError: + raise ValueError("Mask must be binary") + + # Derive mask properties + props = measure.regionprops(measure.label(mask))[0] + # bbox in shape (3, 2): for each dim (row) the min and max (col) + bbox = np.array(props.bbox).reshape(2, 3).T + bbox_ranges = bbox[:, 1] - bbox[:, 0] + # mask centroid in shape (3,) + centroid = np.array(props.centroid) + + # Find slices at 1/4, 2/4, and 3/4 of the z and y dimensions + z_slices = [bbox_ranges[0] / 4 * i for i in [1, 2, 3]] + y_slices = [bbox_ranges[1] / 4 * i for i in [1, 2, 3]] + # Find points at the intersection the centroid's x slice + # with the above y and z slices. + # This produces a set of 9 points roughly on the midline + points = list(product(z_slices, y_slices, [centroid[2]])) + + return np.array(points) + + +def _fit_plane_to_points( + points: np.ndarray, +) -> np.ndarray: + """Fit a plane to a set of 3D points. + + Parameters + ---------- + points : np.ndarray + An array of shape (n, 3) containing the points. + + Returns + ------- + np.ndarray + The normal vector to the plane, with shape (3,). + """ + + # Ensure points are 3D + if points.shape[1] != 3: + raise ValueError("Points array must have 3 columns (z, y, x)") + + centered_points = points - np.mean(points, axis=0) + + # Use SVD to get the normal vector to the plane + _, _, vh = np.linalg.svd(centered_points) + normal_vector = vh[-1] + + return normal_vector + + def align_to_midline( image: np.ndarray, points: np.ndarray, @@ -105,7 +218,7 @@ def align_to_midline( raise ValueError("Axis must be one of 'x', 'y', or 'z'") # Fit a plane to the points - normal_vector = fit_plane_to_points(points) + normal_vector = _fit_plane_to_points(points) # Compute centroid of the midline points centroid = np.mean(points, axis=0) diff --git a/brainglobe_template_builder/utils.py b/brainglobe_template_builder/utils.py deleted file mode 100644 index 319b1a4..0000000 --- a/brainglobe_template_builder/utils.py +++ /dev/null @@ -1,129 +0,0 @@ -from itertools import product -from typing import Literal, Union - -import numpy as np -from skimage import filters, measure - - -def extract_largest_object(binary_image): - """Keep only the largest object in a binary image. - - Parameters - ---------- - binary_image : np.ndarray - A binary image. - - Returns - ------- - np.ndarray - A binary image containing only the largest object. - """ - labeled_image = measure.label(binary_image) - regions = measure.regionprops(labeled_image) - largest_region = max(regions, key=lambda region: region.area) - return labeled_image == largest_region.label - - -def threshold_image( - image: np.ndarray, - method: Literal["triangle", "otsu", "isodata"] = "triangle", -) -> Union[np.ndarray, None]: - """Threshold an image using the specified method to get a binary mask. - - Parameters - ---------- - image : np.ndarray - Image to threshold. - method : str - Thresholding method to use. One of 'triangle', 'otsu', and 'isodata' - (corresponding to methods from the skimage.filters module). - Defaults to 'triangle'. - - Returns - ------- - np.ndarray - A binary mask. - """ - - method_to_func = { - "triangle": filters.threshold_triangle, - "otsu": filters.threshold_otsu, - "isodata": filters.threshold_isodata, - } - if method in method_to_func.keys(): - thresholded = method_to_func[method](image) - return image > thresholded - else: - raise ValueError(f"Unknown thresholding method {method}") - - -def get_midline_points(mask: np.ndarray): - """Get a set of 9 points roughly on the x axis midline of a 3D binary mask. - - Parameters - ---------- - mask : np.ndarray - A binary mask of shape (z, y, x). - - Returns - ------- - np.ndarray - An array of shape (9, 3) containing the midline points. - """ - - # Ensure mask is 3D - if mask.ndim != 3: - raise ValueError("Mask must be 3D") - - # Ensure mask is binary - try: - mask = mask.astype(bool) - except ValueError: - raise ValueError("Mask must be binary") - - # Derive mask properties - props = measure.regionprops(measure.label(mask))[0] - # bbox in shape (3, 2): for each dim (row) the min and max (col) - bbox = np.array(props.bbox).reshape(2, 3).T - bbox_ranges = bbox[:, 1] - bbox[:, 0] - # mask centroid in shape (3,) - centroid = np.array(props.centroid) - - # Find slices at 1/4, 2/4, and 3/4 of the z and y dimensions - z_slices = [bbox_ranges[0] / 4 * i for i in [1, 2, 3]] - y_slices = [bbox_ranges[1] / 4 * i for i in [1, 2, 3]] - # Find points at the intersection the centroid's x slice - # with the above y and z slices. - # This produces a set of 9 points roughly on the midline - points = list(product(z_slices, y_slices, [centroid[2]])) - - return np.array(points) - - -def fit_plane_to_points( - points: np.ndarray, -) -> np.ndarray: - """Fit a plane to a set of 3D points. - - Parameters - ---------- - points : np.ndarray - An array of shape (n, 3) containing the points. - - Returns - ------- - np.ndarray - The normal vector to the plane, with shape (3,). - """ - - # Ensure points are 3D - if points.shape[1] != 3: - raise ValueError("Points array must have 3 columns (z, y, x)") - - centered_points = points - np.mean(points, axis=0) - - # Use SVD to get the normal vector to the plane - _, _, vh = np.linalg.svd(centered_points) - normal_vector = vh[-1] - - return normal_vector