Skip to content

Commit

Permalink
networking docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aelefebv committed Oct 3, 2024
1 parent e000a75 commit 6b3c662
Showing 1 changed file with 264 additions and 0 deletions.
264 changes: 264 additions & 0 deletions nellie/segmentation/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,105 @@


class Network:
"""
A class for analyzing and skeletonizing network-like structures in 3D or 4D microscopy images, such as cellular branches.
Attributes
----------
im_info : ImInfo
An object containing image metadata and memory-mapped image data.
num_t : int
Number of timepoints in the image.
min_radius_um : float
Minimum radius of detected objects in micrometers.
max_radius_um : float
Maximum radius of detected objects in micrometers.
min_radius_px : float
Minimum radius of detected objects in pixels.
max_radius_px : float
Maximum radius of detected objects in pixels.
scaling : tuple
Scaling factors for Z, Y, and X dimensions.
shape : tuple
Shape of the input image.
im_memmap : np.ndarray or None
Memory-mapped original image data.
im_frangi_memmap : np.ndarray or None
Memory-mapped Frangi-filtered image data.
label_memmap : np.ndarray or None
Memory-mapped label data from instance segmentation.
network_memmap : np.ndarray or None
Memory-mapped output for network analysis.
pixel_class_memmap : np.ndarray or None
Memory-mapped output for pixel classification.
skel_memmap : np.ndarray or None
Memory-mapped output for skeleton images.
skel_relabelled_memmap : np.ndarray or None
Memory-mapped output for relabeled skeletons.
clean_skel : bool
Whether to clean the skeletons by removing noisy parts (default is True).
sigmas : list or None
List of sigma values for multi-scale filtering.
debug : dict or None
Debugging information for tracking network analysis steps.
viewer : object or None
Viewer object for displaying status during processing.
Methods
-------
_remove_connected_label_pixels(skel_labels)
Removes skeleton pixels that are connected to multiple labeled regions.
_add_missing_skeleton_labels(skel_frame, label_frame, frangi_frame, thresh)
Adds missing labels to the skeleton where the intensity is highest within a labeled region.
_skeletonize(label_frame, frangi_frame)
Skeletonizes the labeled regions and cleans up the skeleton based on intensity thresholds.
_get_sigma_vec(sigma)
Computes the sigma vector for multi-scale filtering based on image dimensions.
_set_default_sigmas()
Sets the default sigma values for multi-scale filtering.
_relabel_objects(branch_skel_labels, label_frame)
Relabels skeleton pixels by propagating labels to nearby unlabeled pixels.
_local_max_peak(frame, mask)
Detects local maxima in the image using multi-scale Laplacian of Gaussian filtering.
_get_pixel_class(skel)
Classifies skeleton pixels into junctions, branches, and endpoints based on connectivity.
_get_t()
Determines the number of timepoints to process.
_allocate_memory()
Allocates memory for skeleton images, pixel classification, and relabeled skeletons.
_get_branch_skel_labels(pixel_class)
Gets the branch skeleton labels, excluding junctions and background pixels.
_run_frame(t)
Runs skeletonization and network analysis for a single timepoint.
_clean_junctions(pixel_class)
Cleans up junctions by removing closely spaced junction pixels.
_run_networking()
Runs the network analysis process for all timepoints in the image.
run()
Main method to execute the network analysis process over the image data.
"""

def __init__(self, im_info: ImInfo, num_t=None,
min_radius_um=0.20, max_radius_um=1, clean_skel=None,
viewer=None):
"""
Initializes the Network object with image metadata and network analysis parameters.
Parameters
----------
im_info : ImInfo
An instance of the ImInfo class, containing metadata and paths for the image file.
num_t : int, optional
Number of timepoints to process. If None, defaults to the number of timepoints in the image.
min_radius_um : float, optional
Minimum radius of detected objects in micrometers (default is 0.20).
max_radius_um : float, optional
Maximum radius of detected objects in micrometers (default is 1).
clean_skel : bool, optional
Whether to clean the skeleton by removing noisy parts (default is None, which means True for 3D images).
viewer : object or None, optional
Viewer object for displaying status during processing (default is None).
"""
self.im_info = im_info
self.num_t = num_t
if num_t is None and not self.im_info.no_t:
Expand Down Expand Up @@ -51,6 +147,21 @@ def __init__(self, im_info: ImInfo, num_t=None,
self.viewer = viewer

def _remove_connected_label_pixels(self, skel_labels):
"""
Removes skeleton pixels that are connected to multiple labeled regions.
This method identifies pixels that are connected to more than one labeled region in the neighborhood and removes them.
Parameters
----------
skel_labels : np.ndarray
Skeletonized label data.
Returns
-------
np.ndarray
Cleaned skeleton data with conflicting pixels removed.
"""
if device_type == 'cuda':
skel_labels = skel_labels.get()

Expand Down Expand Up @@ -99,6 +210,25 @@ def _remove_connected_label_pixels(self, skel_labels):
return xp.array(skel_labels)

def _add_missing_skeleton_labels(self, skel_frame, label_frame, frangi_frame, thresh):
"""
Adds missing labels to the skeleton where the intensity is highest within a labeled region.
Parameters
----------
skel_frame : np.ndarray
Skeleton data.
label_frame : np.ndarray
Labeled regions in the image.
frangi_frame : np.ndarray
Frangi-filtered image.
thresh : float
Threshold value used during skeleton cleaning.
Returns
-------
np.ndarray
Updated skeleton with missing labels added.
"""
logger.debug('Adding missing skeleton labels.')
gpu_frame = xp.array(label_frame)
# identify unique labels and find missing ones
Expand All @@ -122,6 +252,23 @@ def _add_missing_skeleton_labels(self, skel_frame, label_frame, frangi_frame, th
return skel_frame

def _skeletonize(self, label_frame, frangi_frame):
"""
Skeletonizes the labeled regions and cleans up the skeleton based on intensity thresholds.
This method applies skeletonization and optionally cleans the skeleton using intensity information.
Parameters
----------
label_frame : np.ndarray
Labeled regions in the image.
frangi_frame : np.ndarray
Frangi-filtered image used for intensity-based cleaning.
Returns
-------
tuple
Skeleton data and the intensity threshold used for cleaning.
"""
cpu_frame = np.array(label_frame)
gpu_frame = xp.array(label_frame)

Expand Down Expand Up @@ -153,13 +300,29 @@ def _skeletonize(self, label_frame, frangi_frame):
return skel_labels, thresh

def _get_sigma_vec(self, sigma):
"""
Computes the sigma vector for multi-scale filtering based on image dimensions.
Parameters
----------
sigma : float
The sigma value to use for filtering.
Returns
-------
tuple
Sigma vector for Gaussian filtering in (Z, Y, X).
"""
if self.im_info.no_z:
sigma_vec = (sigma, sigma)
else:
sigma_vec = (sigma / self.z_ratio, sigma, sigma)
return sigma_vec

def _set_default_sigmas(self):
"""
Sets the default sigma values for multi-scale filtering based on the minimum and maximum radius in pixels.
"""
logger.debug('Setting to sigma values.')
min_sigma_step_size = 0.2
num_sigma = 5
Expand All @@ -174,6 +337,23 @@ def _set_default_sigmas(self):
logger.debug(f'Calculated sigma step size = {sigma_step_size_calculated}. Sigmas = {self.sigmas}')

def _relabel_objects(self, branch_skel_labels, label_frame):
"""
Relabels skeleton pixels by propagating labels to nearby unlabeled pixels.
This method uses a nearest-neighbor approach to propagate labels to unlabeled pixels in the skeleton.
Parameters
----------
branch_skel_labels : np.ndarray
Branch skeleton labels.
label_frame : np.ndarray
Labeled regions in the image.
Returns
-------
np.ndarray
Relabeled skeleton.
"""
if self.im_info.no_z:
structure = xp.ones((3, 3))
else:
Expand Down Expand Up @@ -241,6 +421,21 @@ def _relabel_objects(self, branch_skel_labels, label_frame):
return relabelled_labels

def _local_max_peak(self, frame, mask):
"""
Detects local maxima in the image using multi-scale Laplacian of Gaussian filtering.
Parameters
----------
frame : np.ndarray
The input image.
mask : np.ndarray
Binary mask of regions to process.
Returns
-------
np.ndarray
Coordinates of detected local maxima.
"""
lapofg = xp.empty(((len(self.sigmas),) + frame.shape), dtype=float)
for i, s in enumerate(self.sigmas):
sigma_vec = self._get_sigma_vec(s)
Expand All @@ -263,6 +458,19 @@ def _local_max_peak(self, frame, mask):
return coords_3d

def _get_pixel_class(self, skel):
"""
Classifies skeleton pixels into junctions, branches, and endpoints based on connectivity.
Parameters
----------
skel : np.ndarray
Skeleton data.
Returns
-------
np.ndarray
Pixel classification of skeleton points (junctions, branches, or endpoints).
"""
skel_mask = xp.array(skel > 0).astype('uint8')
if self.im_info.no_z:
weights = xp.ones((3, 3))
Expand All @@ -273,6 +481,11 @@ def _get_pixel_class(self, skel):
return skel_mask_sum

def _get_t(self):
"""
Determines the number of timepoints to process.
If `num_t` is not set and the image contains a temporal dimension, it sets `num_t` to the number of timepoints.
"""
if self.num_t is None:
if self.im_info.no_t:
self.num_t = 1
Expand All @@ -282,6 +495,11 @@ def _get_t(self):
return

def _allocate_memory(self):
"""
Allocates memory for skeleton images, pixel classification, and relabeled skeletons.
This method creates memory-mapped arrays for the instance label data, skeleton, pixel classification, and relabeled skeletons.
"""
logger.debug('Allocating memory for skeletonization.')
self.label_memmap = self.im_info.get_memmap(self.im_info.pipeline_paths['im_instance_label']) # , read_type='r+')
self.im_memmap = self.im_info.get_memmap(self.im_info.im_path)
Expand All @@ -307,6 +525,19 @@ def _allocate_memory(self):
return_memmap=True)

def _get_branch_skel_labels(self, pixel_class):
"""
Gets the branch skeleton labels, excluding junctions and background pixels.
Parameters
----------
pixel_class : np.ndarray
Classified skeleton pixels.
Returns
-------
np.ndarray
Branch skeleton labels.
"""
# get the labels of the skeleton pixels that are not junctions or background
non_junctions = pixel_class > 0
non_junctions = non_junctions * (pixel_class != 4)
Expand All @@ -318,6 +549,19 @@ def _get_branch_skel_labels(self, pixel_class):
return non_junction_labels

def _run_frame(self, t):
"""
Runs skeletonization and network analysis for a single timepoint.
Parameters
----------
t : int
Timepoint index.
Returns
-------
tuple
Branch skeleton labels, pixel classification, and relabeled skeletons.
"""
logger.info(f'Running network analysis, volume {t}/{self.num_t - 1}')
label_frame = self.label_memmap[t]
frangi_frame = xp.array(self.im_frangi_memmap[t])
Expand All @@ -334,6 +578,21 @@ def _run_frame(self, t):
return branch_skel_labels, pixel_class, branch_labels

def _clean_junctions(self, pixel_class):
"""
Cleans up junctions by removing closely spaced junction pixels.
This method uses a KD-tree to remove redundant junction points, leaving only the centroid.
Parameters
----------
pixel_class : np.ndarray
Pixel classification of skeleton points.
Returns
-------
np.ndarray
Cleaned pixel classification with redundant junctions removed.
"""
junctions = pixel_class == 4
junction_labels = skimage.measure.label(junctions)
junction_objects = skimage.measure.regionprops(junction_labels)
Expand All @@ -351,6 +610,11 @@ def _clean_junctions(self, pixel_class):
return pixel_class

def _run_networking(self):
"""
Runs the network analysis process for all timepoints in the image.
This method processes each timepoint sequentially and applies network analysis.
"""
for t in range(self.num_t):
if self.viewer is not None:
self.viewer.status = f'Extracting branches. Frame: {t + 1} of {self.num_t}.'
Expand Down

0 comments on commit 6b3c662

Please sign in to comment.