Skip to content
This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Permalink
Merge pull request #115 from dstansby/clean-3d-filter
Browse files Browse the repository at this point in the history
Clean and document ball filter
  • Loading branch information
dstansby authored Apr 3, 2023
2 parents 46d4be6 + 121435c commit 8a1e9c2
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 91 deletions.
234 changes: 163 additions & 71 deletions src/cellfinder_core/detect/filters/volume/ball_filter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
from numba import jit

Expand All @@ -8,18 +10,51 @@


class BallFilter:
"""
A 3D ball filter.
This runs a spherical kernel across the (x, y) dimensions
of a *ball_z_size* stack of planes, and marks pixels in the middle
plane of the stack that have a high enough intensity within the
spherical kernel.
"""

def __init__(
self,
layer_width,
layer_height,
ball_xy_size,
ball_z_size,
overlap_fraction=0.8,
layer_width: int,
layer_height: int,
ball_xy_size: int,
ball_z_size: int,
overlap_fraction: float = 0.8,
tile_step_width=None,
tile_step_height=None,
threshold_value=None,
threshold_value: Optional[int] = None,
soma_centre_value=None,
):
"""
Parameters
----------
layer_width, layer_height :
Width/height of the layers.
ball_xy_size :
Diameter of the spherical kernel in the x/y dimensions.
ball_z_size :
Diameter of the spherical kernel in the z dimension.
Equal to the number of planes that stacked to filter
the central plane of the stack.
overlap_fraction :
The fraction of pixels within the spherical kernel that
have to be over *threshold_value* for a pixel to be marked
as having a high intensity.
tile_step_width, tile_step_height :
Width/height of individual tiles in the mask generated by
2D filtering.
threshold_value :
Value above which an individual pixel is considered to have
a high intensity.
soma_centre_value :
Value used to mark pixels with a high enough intensity.
"""
self.ball_xy_size = ball_xy_size
self.ball_z_size = ball_z_size
self.overlap_fraction = overlap_fraction
Expand All @@ -29,26 +64,36 @@ def __init__(
self.THRESHOLD_VALUE = threshold_value
self.SOMA_CENTRE_VALUE = soma_centre_value

# temporary kernel of scaling_factor*ball_x_y size to be then scaled
# to final ball size
x_upscale_factor = (
y_upscale_factor
) = z_upscale_factor = 7 # WARNING: needs to be integer
temp_kernel_shape = [
x_upscale_factor * ball_xy_size,
y_upscale_factor * ball_xy_size,
z_upscale_factor * ball_z_size,
]
tmp_ball_centre_position = [
np.floor(d / 2) for d in temp_kernel_shape
] # z_centre is xy_centre before resize
tmp_ball_radius = temp_kernel_shape[0] / 2.0
tmp_kernel = make_sphere(
temp_kernel_shape, tmp_ball_radius, tmp_ball_centre_position
# Create a spherical kernel.
#
# This is done by:
# 1. Generating a binary sphere at a resolution *upscale_factor* larger
# than desired.
# 2. Downscaling the binary sphere to get a 'fuzzy' sphere at the
# original intended scale
upscale_factor: int = 7
upscaled_kernel_shape = (
upscale_factor * ball_xy_size,
upscale_factor * ball_xy_size,
upscale_factor * ball_z_size,
)
upscaled_ball_centre_position = (
np.floor(upscaled_kernel_shape[0] / 2),
np.floor(upscaled_kernel_shape[1] / 2),
np.floor(upscaled_kernel_shape[2] / 2),
)
tmp_kernel = tmp_kernel.astype(np.float64)
upscaled_ball_radius = upscaled_kernel_shape[0] / 2.0
sphere_kernel = make_sphere(
upscaled_kernel_shape,
upscaled_ball_radius,
upscaled_ball_centre_position,
)
sphere_kernel = sphere_kernel.astype(np.float64)
self.kernel = bin_mean_3d(
tmp_kernel, x_upscale_factor, y_upscale_factor, z_upscale_factor
sphere_kernel,
bin_height=upscale_factor,
bin_width=upscale_factor,
bin_depth=upscale_factor,
)

assert (
Expand All @@ -57,10 +102,7 @@ def __init__(
ball_z_size, self.kernel.shape[2]
)

self.overlap_threshold = (
self.overlap_fraction
* np.array(self.kernel, dtype=np.float64).sum()
)
self.overlap_threshold = np.sum(self.overlap_fraction * self.kernel)

# Stores the current planes that are being filtered
self.volume = np.empty(
Expand All @@ -69,11 +111,10 @@ def __init__(
# Index of the middle plane in the volume
self.middle_z_idx = int(np.floor(ball_z_size / 2))

self.good_tiles_mask = np.empty(
# TODO: lazy initialisation
self.inside_brain_tiles = np.empty(
(
int(
np.ceil(layer_width / tile_step_width)
), # TODO: lazy initialisation
int(np.ceil(layer_width / tile_step_width)),
int(np.ceil(layer_height / tile_step_height)),
ball_z_size,
),
Expand Down Expand Up @@ -102,9 +143,9 @@ def append(self, layer, mask):
[e for e in layer.shape[:2]],
)
assert [e for e in mask.shape[:2]] == [
e for e in self.good_tiles_mask.shape[2]
e for e in self.inside_brain_tiles.shape[2]
], 'mask shape mismatch, expected"{}", got {}"'.format(
[e for e in self.good_tiles_mask.shape[:2]],
[e for e in self.inside_brain_tiles.shape[:2]],
[e for e in mask.shape[:2]],
)
if not self.ready:
Expand All @@ -114,10 +155,12 @@ def append(self, layer, mask):
self.volume = np.roll(
self.volume, -1, axis=2
) # WARNING: not in place
self.good_tiles_mask = np.roll(self.good_tiles_mask, -1, axis=2)
# Add the new layer to the top of volume and good_tiles_mask
self.inside_brain_tiles = np.roll(
self.inside_brain_tiles, -1, axis=2
)
# Add the new layer to the top of volume and inside_brain_tiles
self.volume[:, :, self.__current_z] = layer[:, :]
self.good_tiles_mask[:, :, self.__current_z] = mask[:, :]
self.inside_brain_tiles[:, :, self.__current_z] = mask[:, :]

def get_middle_plane(self):
"""
Expand All @@ -128,27 +171,25 @@ def get_middle_plane(self):

def walk(self): # Highly optimised because most time critical
ball_radius = self.ball_xy_size // 2
# Get extents of image that are covered by tiles
tile_mask_covered_img_width = (
self.good_tiles_mask.shape[0] * self.tile_step_width
self.inside_brain_tiles.shape[0] * self.tile_step_width
)
tile_mask_covered_img_height = (
self.good_tiles_mask.shape[1] * self.tile_step_height
self.inside_brain_tiles.shape[1] * self.tile_step_height
)
# whole ball size because the cube is extracted with x + whole ball
# width
# Get maximum offsets for the ball
max_width = tile_mask_covered_img_width - self.ball_xy_size
# whole ball size because the cube is extracted with y + whole ball
# height
max_height = tile_mask_covered_img_height - self.ball_xy_size

_walk(
max_height,
max_width,
self.tile_step_width,
self.tile_step_height,
self.good_tiles_mask,
self.inside_brain_tiles,
self.volume,
self.kernel,
self.ball_z_size,
ball_radius,
self.middle_z_idx,
self.overlap_threshold,
Expand All @@ -159,23 +200,45 @@ def walk(self): # Highly optimised because most time critical

@jit(nopython=True, cache=True)
def _cube_overlaps(
cube, ball_z_size, overlap_threshold, THRESHOLD_VALUE, kernel
): # Highly optimised because most time critical
cube: np.ndarray,
overlap_threshold: float,
THRESHOLD_VALUE: int,
kernel: np.ndarray,
) -> bool: # Highly optimised because most time critical
"""
For each pixel in cube that is greater than THRESHOLD_VALUE, sum
up the corresponding pixels in *kernel*. If the total is less than
overlap_threshold, return False, otherwise return True.
:param np.ndarray cube: The thresholded array to check for ball fit.
values at CellDetector.THRESHOLD_VALUE are threshold
:return: True if the overlap exceeds self.overlap_fraction
Halfway through scanning the z-planes, if the total overlap is
less than 0.4 * overlap_threshold, this will return False early
without scanning the second half of the z-planes.
Parameters
----------
cube :
3D array.
overlap_threshold :
Threshold above which to return True.
THRESHOLD_VALUE :
Value above which a pixel is marked as being part of a cell.
kernel :
3D array, with the same shape as *cube*.
"""
current_overlap_value = 0

middle = np.floor(ball_z_size / 2) + 1
overlap_thresh = overlap_threshold * 0.4 # FIXME: do not hard code value
middle = np.floor(cube.shape[2] / 2) + 1
halfway_overlap_thresh = (
overlap_threshold * 0.4
) # FIXME: do not hard code value

for z in range(cube.shape[2]):
# TODO: OPTIMISE: step from middle to outer boundaries to check
# more data first
if z == middle and current_overlap_value < overlap_thresh:
#
# If halfway through the array, and the overlap value isn't more than
# 0.4 * the overlap threshold, return
if z == middle and current_overlap_value < halfway_overlap_thresh:
return False # DEBUG: optimisation attempt
for y in range(cube.shape[1]):
for x in range(cube.shape[0]):
Expand All @@ -187,30 +250,60 @@ def _cube_overlaps(

@jit(nopython=True)
def _is_tile_to_check(
x, y, middle_z, tile_step_width, tile_step_height, good_tiles_mask
x: int,
y: int,
middle_z: int,
tile_step_width: int,
tile_step_height: int,
inside_brain_tiles: np.ndarray,
): # Highly optimised because most time critical
"""
Check if the tile containing pixel (x, y) is a tile that needs checking.
"""
x_in_mask = x // tile_step_width # TEST: test bounds (-1 range)
y_in_mask = y // tile_step_height # TEST: test bounds (-1 range)
return good_tiles_mask[x_in_mask, y_in_mask, middle_z]
return inside_brain_tiles[x_in_mask, y_in_mask, middle_z]


@jit(nopython=True)
def _walk(
max_height,
max_width,
tile_step_width,
tile_step_height,
good_tiles_mask,
volume,
kernel,
ball_z_size,
ball_radius,
middle_z,
overlap_threshold,
THRESHOLD_VALUE,
SOMA_CENTRE_VALUE,
):
max_height: int,
max_width: int,
tile_step_width: int,
tile_step_height: int,
inside_brain_tiles: np.ndarray,
volume: np.ndarray,
kernel: np.ndarray,
ball_radius: int,
middle_z: int,
overlap_threshold: float,
THRESHOLD_VALUE: int,
SOMA_CENTRE_VALUE: int,
) -> None:
"""
Scan through *volume*, and mark pixels where there are enough surrounding
pixels with high enough intensity.
The surrounding area is defined by the *kernel*.
Parameters
----------
max_height, max_width :
Maximum offsets for the ball filter.
inside_brain_tiles :
Array containing information on whether a tile is inside the brain
or not. Tiles outside the brain are skipped.
volume :
3D array containing the plane-filtered data.
kernel :
3D array
ball_radius :
Radius of the ball in the xy plane.
SOMA_CENTRE_VALUE :
Value that is used to mark pixels in *volume*.
Notes
-----
Warning: modifies volume in place!
"""
for y in range(max_height):
Expand All @@ -223,7 +316,7 @@ def _walk(
middle_z,
tile_step_width,
tile_step_height,
good_tiles_mask,
inside_brain_tiles,
):
cube = volume[
x : x + kernel.shape[0],
Expand All @@ -232,7 +325,6 @@ def _walk(
]
if _cube_overlaps(
cube,
ball_z_size,
overlap_threshold,
THRESHOLD_VALUE,
kernel,
Expand Down
Loading

0 comments on commit 8a1e9c2

Please sign in to comment.