diff --git a/nellie/tracking/voxel_reassignment.py b/nellie/tracking/voxel_reassignment.py index 2438e25..18eac7c 100644 --- a/nellie/tracking/voxel_reassignment.py +++ b/nellie/tracking/voxel_reassignment.py @@ -9,8 +9,73 @@ class VoxelReassigner: + """ + A class for voxel reassignment across time points using forward and backward flow interpolation. + + Attributes + ---------- + im_info : ImInfo + An object containing image metadata and memory-mapped image data. + num_t : int + Number of timepoints in the image. + flow_interpolator_fw : FlowInterpolator + Flow interpolator for forward timepoint matching. + flow_interpolator_bw : FlowInterpolator + Flow interpolator for backward timepoint matching. + running_matches : list + List of running matches for voxel reassignment between timepoints. + voxel_matches_path : str or None + Path to save the voxel matches array. + branch_label_memmap : np.ndarray or None + Memory-mapped data for relabeled branches. + obj_label_memmap : np.ndarray or None + Memory-mapped data for object labels. + reassigned_branch_memmap : np.ndarray or None + Memory-mapped data for reassigned branches. + reassigned_obj_memmap : np.ndarray or None + Memory-mapped data for reassigned object labels. + viewer : Any + Optional viewer (e.g., for visualization purposes). + + Methods + ------- + _match_forward(flow_interpolator, vox_prev, vox_next, t) + Matches voxels forward using flow interpolation. + _match_backward(flow_interpolator, vox_next, vox_prev, t) + Matches voxels backward using flow interpolation. + _match_voxels_to_centroids(coords_real, coords_interpx) + Matches voxels to centroids using nearest neighbor search. + _assign_unique_matches(vox_prev_matches, vox_next_matches, distances) + Assigns unique matches between timepoint voxels based on minimum distances. + _distance_threshold(vox_prev_matched, vox_next_matched) + Filters voxel matches by applying a distance threshold. + match_voxels(vox_prev, vox_next, t) + Matches voxels between two consecutive timepoints using forward and backward interpolation. + _get_t() + Gets the number of timepoints in the dataset. + _allocate_memory() + Allocates memory for voxel reassignment data, including memory-mapped arrays. + _run_frame(t, all_mask_coords, reassigned_memmap) + Runs the voxel reassignment process for a single timepoint. + _run_reassignment(label_type) + Runs the voxel reassignment process for all frames, for either branch or object labels. + run() + Main method to execute voxel reassignment for both branch and object labels. + """ def __init__(self, im_info: ImInfo, num_t=None, viewer=None): + """ + Initializes the VoxelReassigner class with image metadata and timepoints. + + Parameters + ---------- + im_info : ImInfo + Image metadata and memory-mapped data. + num_t : int, optional + Number of timepoints in the dataset. If None, it is inferred from the image metadata (default is None). + viewer : Any, optional + Optional viewer for visualization purposes (default is None). + """ self.im_info = im_info if self.im_info.no_t: @@ -35,6 +100,25 @@ def __init__(self, im_info: ImInfo, num_t=None, self.viewer = viewer def _match_forward(self, flow_interpolator, vox_prev, vox_next, t): + """ + Matches voxels forward in time using flow interpolation. + + Parameters + ---------- + flow_interpolator : FlowInterpolator + Flow interpolator for forward voxel matching. + vox_prev : np.ndarray + Voxels from the previous timepoint. + vox_next : np.ndarray + Voxels from the next timepoint. + t : int + Current timepoint index. + + Returns + ------- + tuple + Arrays of matched voxels from the previous and next timepoints and valid distances between them. + """ vectors_interpx_prev = flow_interpolator.interpolate_coord(vox_prev, t) if vectors_interpx_prev is None: return [], [], [] @@ -60,6 +144,25 @@ def _match_forward(self, flow_interpolator, vox_prev, vox_next, t): return vox_prev_matched_valid, vox_next_matched_valid, distances_valid def _match_backward(self, flow_interpolator, vox_next, vox_prev, t): + """ + Matches voxels backward in time using flow interpolation. + + Parameters + ---------- + flow_interpolator : FlowInterpolator + Flow interpolator for backward voxel matching. + vox_next : np.ndarray + Voxels from the next timepoint. + vox_prev : np.ndarray + Voxels from the previous timepoint. + t : int + Current timepoint index. + + Returns + ------- + tuple + Arrays of matched voxels from the previous and next timepoints and valid distances between them. + """ # interpolate flow vectors to all voxels in t1 from centroids derived from t0 centroids + t0 flow vectors vectors_interpx_prev = flow_interpolator.interpolate_coord(vox_next, t) if vectors_interpx_prev is None: @@ -85,6 +188,21 @@ def _match_backward(self, flow_interpolator, vox_next, vox_prev, t): return vox_prev_matched_valid, vox_next_matched_valid, distances_valid def _match_voxels_to_centroids(self, coords_real, coords_interpx): + """ + Matches real voxel coordinates to interpolated centroids using nearest neighbor search. + + Parameters + ---------- + coords_real : np.ndarray + Real voxel coordinates. + coords_interpx : np.ndarray + Interpolated centroid coordinates. + + Returns + ------- + tuple + Arrays of distances and indices of matched centroids. + """ coords_interpx = np.array(coords_interpx) * self.flow_interpolator_fw.scaling coords_real = np.array(coords_real) * self.flow_interpolator_fw.scaling tree = cKDTree(coords_real) @@ -92,6 +210,23 @@ def _match_voxels_to_centroids(self, coords_real, coords_interpx): return dist, idx def _assign_unique_matches(self, vox_prev_matches, vox_next_matches, distances): + """ + Assigns unique voxel matches based on the minimum distance criteria. + + Parameters + ---------- + vox_prev_matches : np.ndarray + Array of matched voxels from the previous timepoint. + vox_next_matches : np.ndarray + Array of matched voxels from the next timepoint. + distances : np.ndarray + Array of distances between matched voxels. + + Returns + ------- + tuple + Arrays of uniquely matched voxels for the previous and next timepoints. + """ # create a dict where the key is a voxel in t1, and the value is a list of distances and t0 voxels matched to it vox_next_dict = {} for match_idx, match_next in enumerate(vox_next_matches): @@ -137,6 +272,21 @@ def _assign_unique_matches(self, vox_prev_matches, vox_next_matches, distances): return vox_prev_matches_final, vox_next_matches_final def _distance_threshold(self, vox_prev_matched, vox_next_matched): + """ + Filters voxel matches by applying a distance threshold. + + Parameters + ---------- + vox_prev_matched : np.ndarray + Array of matched voxels from the previous timepoint. + vox_next_matched : np.ndarray + Array of matched voxels from the next timepoint. + + Returns + ------- + tuple + Arrays of valid voxel matches and corresponding distances. + """ distances = np.linalg.norm((vox_prev_matched - vox_next_matched) * self.flow_interpolator_fw.scaling, axis=1) distance_mask = distances < self.flow_interpolator_fw.max_distance_um vox_prev_matched_valid = vox_prev_matched[distance_mask] @@ -145,6 +295,23 @@ def _distance_threshold(self, vox_prev_matched, vox_next_matched): return vox_prev_matched_valid, vox_next_matched_valid, distances_valid def match_voxels(self, vox_prev, vox_next, t): + """ + Matches voxels between two consecutive timepoints using both forward and backward interpolation. + + Parameters + ---------- + vox_prev : np.ndarray + Voxels from the previous timepoint. + vox_next : np.ndarray + Voxels from the next timepoint. + t : int + Current timepoint index. + + Returns + ------- + tuple + Arrays of matched voxels from the previous and next timepoints. + """ # forward interpolation: # from t0 voxels and interpolated flow, get t1 centroids. # match nearby t1 voxels to t1 centroids, which are linked to t0 voxels. @@ -201,6 +368,9 @@ def match_voxels(self, vox_prev, vox_next, t): return np.array(vox_prev_matches_unique), np.array(vox_next_matches_unique) def _get_t(self): + """ + Gets the number of timepoints from the image metadata or sets it if not provided. + """ if self.num_t is None: if self.im_info.no_t: self.num_t = 1 @@ -210,6 +380,9 @@ def _get_t(self): return def _allocate_memory(self): + """ + Allocates memory for voxel reassignment, including initializing memory-mapped arrays for branch and object labels. + """ logger.debug('Allocating memory for voxel reassignment.') self.voxel_matches_path = self.im_info.pipeline_paths['voxel_matches'] @@ -230,6 +403,23 @@ def _allocate_memory(self): return_memmap=True) def _run_frame(self, t, all_mask_coords, reassigned_memmap): + """ + Reassigns voxels in a single timepoint based on voxel matches with the previous timepoint. + + Parameters + ---------- + t : int + Current timepoint index. + all_mask_coords : list + List of voxel coordinates for each timepoint. + reassigned_memmap : np.ndarray + Memory-mapped array for the reassigned labels. + + Returns + ------- + bool + Returns True if no matches are found, otherwise False. + """ logger.info(f'Reassigning pixels in frame {t + 1} of {self.num_t - 1}') vox_prev = all_mask_coords[t] @@ -250,6 +440,14 @@ def _run_frame(self, t, all_mask_coords, reassigned_memmap): return False def _run_reassignment(self, label_type): + """ + Runs voxel reassignment for all frames based on the specified label type (either 'branch' or 'obj'). + + Parameters + ---------- + label_type : str + The label type, either 'branch' or 'obj'. + """ # todo, be able to specify which frame to start at. if label_type == 'branch': label_memmap = self.branch_label_memmap @@ -272,6 +470,9 @@ def _run_reassignment(self, label_type): break def run(self): + """ + Main method to execute voxel reassignment for both branch and object labels. + """ if self.im_info.no_t: return self._get_t()