From 6b3c66229dfefc279072b4556f416fccb9108dd2 Mon Sep 17 00:00:00 2001 From: Austin Epiphane Yann Tung-Shan Lefebvre Date: Wed, 2 Oct 2024 20:22:10 -0700 Subject: [PATCH] networking docs --- nellie/segmentation/networking.py | 264 ++++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) diff --git a/nellie/segmentation/networking.py b/nellie/segmentation/networking.py index 84aca82..d23f2c1 100644 --- a/nellie/segmentation/networking.py +++ b/nellie/segmentation/networking.py @@ -9,9 +9,105 @@ class Network: + """ + A class for analyzing and skeletonizing network-like structures in 3D or 4D microscopy images, such as cellular branches. + + Attributes + ---------- + im_info : ImInfo + An object containing image metadata and memory-mapped image data. + num_t : int + Number of timepoints in the image. + min_radius_um : float + Minimum radius of detected objects in micrometers. + max_radius_um : float + Maximum radius of detected objects in micrometers. + min_radius_px : float + Minimum radius of detected objects in pixels. + max_radius_px : float + Maximum radius of detected objects in pixels. + scaling : tuple + Scaling factors for Z, Y, and X dimensions. + shape : tuple + Shape of the input image. + im_memmap : np.ndarray or None + Memory-mapped original image data. + im_frangi_memmap : np.ndarray or None + Memory-mapped Frangi-filtered image data. + label_memmap : np.ndarray or None + Memory-mapped label data from instance segmentation. + network_memmap : np.ndarray or None + Memory-mapped output for network analysis. + pixel_class_memmap : np.ndarray or None + Memory-mapped output for pixel classification. + skel_memmap : np.ndarray or None + Memory-mapped output for skeleton images. + skel_relabelled_memmap : np.ndarray or None + Memory-mapped output for relabeled skeletons. + clean_skel : bool + Whether to clean the skeletons by removing noisy parts (default is True). + sigmas : list or None + List of sigma values for multi-scale filtering. + debug : dict or None + Debugging information for tracking network analysis steps. + viewer : object or None + Viewer object for displaying status during processing. + + Methods + ------- + _remove_connected_label_pixels(skel_labels) + Removes skeleton pixels that are connected to multiple labeled regions. + _add_missing_skeleton_labels(skel_frame, label_frame, frangi_frame, thresh) + Adds missing labels to the skeleton where the intensity is highest within a labeled region. + _skeletonize(label_frame, frangi_frame) + Skeletonizes the labeled regions and cleans up the skeleton based on intensity thresholds. + _get_sigma_vec(sigma) + Computes the sigma vector for multi-scale filtering based on image dimensions. + _set_default_sigmas() + Sets the default sigma values for multi-scale filtering. + _relabel_objects(branch_skel_labels, label_frame) + Relabels skeleton pixels by propagating labels to nearby unlabeled pixels. + _local_max_peak(frame, mask) + Detects local maxima in the image using multi-scale Laplacian of Gaussian filtering. + _get_pixel_class(skel) + Classifies skeleton pixels into junctions, branches, and endpoints based on connectivity. + _get_t() + Determines the number of timepoints to process. + _allocate_memory() + Allocates memory for skeleton images, pixel classification, and relabeled skeletons. + _get_branch_skel_labels(pixel_class) + Gets the branch skeleton labels, excluding junctions and background pixels. + _run_frame(t) + Runs skeletonization and network analysis for a single timepoint. + _clean_junctions(pixel_class) + Cleans up junctions by removing closely spaced junction pixels. + _run_networking() + Runs the network analysis process for all timepoints in the image. + run() + Main method to execute the network analysis process over the image data. + """ + def __init__(self, im_info: ImInfo, num_t=None, min_radius_um=0.20, max_radius_um=1, clean_skel=None, viewer=None): + """ + Initializes the Network object with image metadata and network analysis parameters. + + Parameters + ---------- + im_info : ImInfo + An instance of the ImInfo class, containing metadata and paths for the image file. + num_t : int, optional + Number of timepoints to process. If None, defaults to the number of timepoints in the image. + min_radius_um : float, optional + Minimum radius of detected objects in micrometers (default is 0.20). + max_radius_um : float, optional + Maximum radius of detected objects in micrometers (default is 1). + clean_skel : bool, optional + Whether to clean the skeleton by removing noisy parts (default is None, which means True for 3D images). + viewer : object or None, optional + Viewer object for displaying status during processing (default is None). + """ self.im_info = im_info self.num_t = num_t if num_t is None and not self.im_info.no_t: @@ -51,6 +147,21 @@ def __init__(self, im_info: ImInfo, num_t=None, self.viewer = viewer def _remove_connected_label_pixels(self, skel_labels): + """ + Removes skeleton pixels that are connected to multiple labeled regions. + + This method identifies pixels that are connected to more than one labeled region in the neighborhood and removes them. + + Parameters + ---------- + skel_labels : np.ndarray + Skeletonized label data. + + Returns + ------- + np.ndarray + Cleaned skeleton data with conflicting pixels removed. + """ if device_type == 'cuda': skel_labels = skel_labels.get() @@ -99,6 +210,25 @@ def _remove_connected_label_pixels(self, skel_labels): return xp.array(skel_labels) def _add_missing_skeleton_labels(self, skel_frame, label_frame, frangi_frame, thresh): + """ + Adds missing labels to the skeleton where the intensity is highest within a labeled region. + + Parameters + ---------- + skel_frame : np.ndarray + Skeleton data. + label_frame : np.ndarray + Labeled regions in the image. + frangi_frame : np.ndarray + Frangi-filtered image. + thresh : float + Threshold value used during skeleton cleaning. + + Returns + ------- + np.ndarray + Updated skeleton with missing labels added. + """ logger.debug('Adding missing skeleton labels.') gpu_frame = xp.array(label_frame) # identify unique labels and find missing ones @@ -122,6 +252,23 @@ def _add_missing_skeleton_labels(self, skel_frame, label_frame, frangi_frame, th return skel_frame def _skeletonize(self, label_frame, frangi_frame): + """ + Skeletonizes the labeled regions and cleans up the skeleton based on intensity thresholds. + + This method applies skeletonization and optionally cleans the skeleton using intensity information. + + Parameters + ---------- + label_frame : np.ndarray + Labeled regions in the image. + frangi_frame : np.ndarray + Frangi-filtered image used for intensity-based cleaning. + + Returns + ------- + tuple + Skeleton data and the intensity threshold used for cleaning. + """ cpu_frame = np.array(label_frame) gpu_frame = xp.array(label_frame) @@ -153,6 +300,19 @@ def _skeletonize(self, label_frame, frangi_frame): return skel_labels, thresh def _get_sigma_vec(self, sigma): + """ + Computes the sigma vector for multi-scale filtering based on image dimensions. + + Parameters + ---------- + sigma : float + The sigma value to use for filtering. + + Returns + ------- + tuple + Sigma vector for Gaussian filtering in (Z, Y, X). + """ if self.im_info.no_z: sigma_vec = (sigma, sigma) else: @@ -160,6 +320,9 @@ def _get_sigma_vec(self, sigma): return sigma_vec def _set_default_sigmas(self): + """ + Sets the default sigma values for multi-scale filtering based on the minimum and maximum radius in pixels. + """ logger.debug('Setting to sigma values.') min_sigma_step_size = 0.2 num_sigma = 5 @@ -174,6 +337,23 @@ def _set_default_sigmas(self): logger.debug(f'Calculated sigma step size = {sigma_step_size_calculated}. Sigmas = {self.sigmas}') def _relabel_objects(self, branch_skel_labels, label_frame): + """ + Relabels skeleton pixels by propagating labels to nearby unlabeled pixels. + + This method uses a nearest-neighbor approach to propagate labels to unlabeled pixels in the skeleton. + + Parameters + ---------- + branch_skel_labels : np.ndarray + Branch skeleton labels. + label_frame : np.ndarray + Labeled regions in the image. + + Returns + ------- + np.ndarray + Relabeled skeleton. + """ if self.im_info.no_z: structure = xp.ones((3, 3)) else: @@ -241,6 +421,21 @@ def _relabel_objects(self, branch_skel_labels, label_frame): return relabelled_labels def _local_max_peak(self, frame, mask): + """ + Detects local maxima in the image using multi-scale Laplacian of Gaussian filtering. + + Parameters + ---------- + frame : np.ndarray + The input image. + mask : np.ndarray + Binary mask of regions to process. + + Returns + ------- + np.ndarray + Coordinates of detected local maxima. + """ lapofg = xp.empty(((len(self.sigmas),) + frame.shape), dtype=float) for i, s in enumerate(self.sigmas): sigma_vec = self._get_sigma_vec(s) @@ -263,6 +458,19 @@ def _local_max_peak(self, frame, mask): return coords_3d def _get_pixel_class(self, skel): + """ + Classifies skeleton pixels into junctions, branches, and endpoints based on connectivity. + + Parameters + ---------- + skel : np.ndarray + Skeleton data. + + Returns + ------- + np.ndarray + Pixel classification of skeleton points (junctions, branches, or endpoints). + """ skel_mask = xp.array(skel > 0).astype('uint8') if self.im_info.no_z: weights = xp.ones((3, 3)) @@ -273,6 +481,11 @@ def _get_pixel_class(self, skel): return skel_mask_sum def _get_t(self): + """ + Determines the number of timepoints to process. + + If `num_t` is not set and the image contains a temporal dimension, it sets `num_t` to the number of timepoints. + """ if self.num_t is None: if self.im_info.no_t: self.num_t = 1 @@ -282,6 +495,11 @@ def _get_t(self): return def _allocate_memory(self): + """ + Allocates memory for skeleton images, pixel classification, and relabeled skeletons. + + This method creates memory-mapped arrays for the instance label data, skeleton, pixel classification, and relabeled skeletons. + """ logger.debug('Allocating memory for skeletonization.') self.label_memmap = self.im_info.get_memmap(self.im_info.pipeline_paths['im_instance_label']) # , read_type='r+') self.im_memmap = self.im_info.get_memmap(self.im_info.im_path) @@ -307,6 +525,19 @@ def _allocate_memory(self): return_memmap=True) def _get_branch_skel_labels(self, pixel_class): + """ + Gets the branch skeleton labels, excluding junctions and background pixels. + + Parameters + ---------- + pixel_class : np.ndarray + Classified skeleton pixels. + + Returns + ------- + np.ndarray + Branch skeleton labels. + """ # get the labels of the skeleton pixels that are not junctions or background non_junctions = pixel_class > 0 non_junctions = non_junctions * (pixel_class != 4) @@ -318,6 +549,19 @@ def _get_branch_skel_labels(self, pixel_class): return non_junction_labels def _run_frame(self, t): + """ + Runs skeletonization and network analysis for a single timepoint. + + Parameters + ---------- + t : int + Timepoint index. + + Returns + ------- + tuple + Branch skeleton labels, pixel classification, and relabeled skeletons. + """ logger.info(f'Running network analysis, volume {t}/{self.num_t - 1}') label_frame = self.label_memmap[t] frangi_frame = xp.array(self.im_frangi_memmap[t]) @@ -334,6 +578,21 @@ def _run_frame(self, t): return branch_skel_labels, pixel_class, branch_labels def _clean_junctions(self, pixel_class): + """ + Cleans up junctions by removing closely spaced junction pixels. + + This method uses a KD-tree to remove redundant junction points, leaving only the centroid. + + Parameters + ---------- + pixel_class : np.ndarray + Pixel classification of skeleton points. + + Returns + ------- + np.ndarray + Cleaned pixel classification with redundant junctions removed. + """ junctions = pixel_class == 4 junction_labels = skimage.measure.label(junctions) junction_objects = skimage.measure.regionprops(junction_labels) @@ -351,6 +610,11 @@ def _clean_junctions(self, pixel_class): return pixel_class def _run_networking(self): + """ + Runs the network analysis process for all timepoints in the image. + + This method processes each timepoint sequentially and applies network analysis. + """ for t in range(self.num_t): if self.viewer is not None: self.viewer.status = f'Extracting branches. Frame: {t + 1} of {self.num_t}.'