diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0dfb7f4e..e60ac765 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -6,6 +6,7 @@ on: tags: - '*' pull_request: + branches: [ '*' ] jobs: linting: diff --git a/benchmarks/filter_2d.py b/benchmarks/filter_2d.py new file mode 100644 index 00000000..fd16172e --- /dev/null +++ b/benchmarks/filter_2d.py @@ -0,0 +1,27 @@ +import numpy as np +from pyinstrument import Profiler + +from cellfinder_core.detect.filters.plane import TileProcessor +from cellfinder_core.detect.filters.setup_filters import setup_tile_filtering + +# Use random 16-bit integer data for signal plane +shape = (10000, 10000) + +signal_array_plane = np.random.randint( + low=0, high=65536, size=shape, dtype=np.uint16 +) + +clipping_value, threshold_value = setup_tile_filtering(signal_array_plane) +tile_processor = TileProcessor( + clipping_value=clipping_value, + threshold_value=threshold_value, + soma_diameter=16, + log_sigma_size=0.2, + n_sds_above_mean_thresh=10, +) +if __name__ == "__main__": + profiler = Profiler() + profiler.start() + plane, tiles = tile_processor.get_tile_mask(signal_array_plane) + profiler.stop() + profiler.print(show_all=True) diff --git a/benchmarks/filter_3d.py b/benchmarks/filter_3d.py new file mode 100644 index 00000000..83626bee --- /dev/null +++ b/benchmarks/filter_3d.py @@ -0,0 +1,50 @@ +import numpy as np +from pyinstrument import Profiler + +from cellfinder_core.detect.filters.volume.volume_filter import VolumeFilter + +# Use random data for signal data +ball_z_size = 3 + + +def gen_signal_array(ny, nx): + shape = (ball_z_size, ny, nx) + return np.random.randint(low=0, high=65536, size=shape, dtype=np.uint16) + + +signal_array = gen_signal_array(667, 510) + +soma_diameter = 8 +setup_params = ( + signal_array[0, :, :].T, + soma_diameter, + 3, # ball_xy_size, + ball_z_size, + 0.6, # ball_overlap_fraction, + 0, # start_plane, +) + 
+mp_3d_filter = VolumeFilter( + soma_diameter=soma_diameter, + setup_params=setup_params, + planes_paths_range=signal_array, +) + +# Use random data for mask data +mask = np.random.randint(low=0, high=2, size=(42, 32), dtype=np.uint8) + +# Fill up the 3D filter with planes +for plane in signal_array: + mp_3d_filter.ball_filter.append(plane, mask) + if mp_3d_filter.ball_filter.ready: + break + +# Run the 3D filter +profiler = Profiler() +profiler.start() +for i in range(10): + # Repeat same filter 10 times to increase runtime + mp_3d_filter._run_filter() + +profiler.stop() +profiler.print(show_all=True) diff --git a/pyproject.toml b/pyproject.toml index d38f3564..bd12ed9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "fancylog>=0.0.7", "imlib>=0.0.26", "natsort", + "numba", "numpy", "scikit-image", "scikit-learn", @@ -44,6 +45,7 @@ dev = [ "black", "gitpython", "pre-commit", + "pyinstrument", "pytest", "pytest-cov", "pytest-timeout", @@ -59,7 +61,6 @@ requires = [ "setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2", - "cython", ] build-backend = 'setuptools.build_meta' @@ -80,13 +81,15 @@ filterwarnings = [ "error", # Raised by tensorflow; should be removed when tensorflow 2.12.0 is released. 
Fix is: # https://github.com/tensorflow/tensorflow/commit/b23c5750c9f35a87872793eef7c56e74ec55d4a7 - "ignore:`np.bool8` is a deprecated alias for `np.bool_`" + "ignore:`np.bool8` is a deprecated alias for `np.bool_`", + # See https://github.com/numba/numba/issues/8676 + "ignore:.*:numba.core.errors.NumbaTypeSafetyWarning" ] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "serial", ] -log_level = "DEBUG" +log_level = "WARNING" [tool.black] target-version = ['py38', 'py39', 'py310'] diff --git a/setup.py b/setup.py deleted file mode 100644 index d2e73d19..00000000 --- a/setup.py +++ /dev/null @@ -1,32 +0,0 @@ -import Cython.Build -from setuptools import Extension, setup - -base_tile_filter_extension = Extension( - name="cellfinder_core.detect.filters.plane.base_tile_filter", - sources=["src/cellfinder_core/detect/filters/plane/base_tile_filter.pyx"], - language="c++", -) - -ball_filter_extension = Extension( - name="cellfinder_core.detect.filters.volume.ball_filter", - sources=["src/cellfinder_core/detect/filters/volume/ball_filter.pyx"], -) - -structure_detection_extension = Extension( - name="cellfinder_core.detect.filters.volume.structure_detection", - sources=[ - "src/cellfinder_core/detect/filters/volume/structure_detection.pyx" - ], - language="c++", -) - -extensions = [ - base_tile_filter_extension, - ball_filter_extension, - structure_detection_extension, -] - - -setup( - ext_modules=Cython.Build.cythonize(extensions), -) diff --git a/src/cellfinder_core/detect/filters/plane/base_tile_filter.py b/src/cellfinder_core/detect/filters/plane/base_tile_filter.py new file mode 100644 index 00000000..73a0521d --- /dev/null +++ b/src/cellfinder_core/detect/filters/plane/base_tile_filter.py @@ -0,0 +1,130 @@ +from typing import Dict + +import numpy as np +from numba import jit + + +def get_biggest_structure(sizes): + result = 0 + for val in sizes: + if val > result: + result = val + return result + + +class BaseTileFilter: + def 
__init__(self, out_of_brain_intensity_threshold=100): + """ + + :param int out_of_brain_intensity_threshold: Set to 0 to disable + """ + self.out_of_brain_intensity_threshold = ( + out_of_brain_intensity_threshold + ) + self.current_threshold = -1 + self.keep = True + self.size_analyser = SizeAnalyser() + + def set_tile(self, tile): + raise NotImplementedError + + def get_tile(self): + raise NotImplementedError + + def get_structures(self): + struct_sizes = [] + self.size_analyser.process(self._tile, self.current_threshold) + struct_sizes = self.size_analyser.get_sizes() + return get_biggest_structure(struct_sizes), struct_sizes.size() + + +@jit +def is_low_average(tile: np.ndarray, threshold: float) -> bool: + """ + Return `True` if the average value of *tile* is below *threshold*. + """ + avg = np.mean(tile) + return avg < threshold + + +class OutOfBrainTileFilter(BaseTileFilter): + def set_tile(self, tile): + self._tile = tile + + def get_tile(self): + return self._tile + + +class SizeAnalyser: + obsolete_ids: Dict[int, int] = {} + struct_sizes: Dict[int, int] = {} + + def process(self, tile, threshold): + tile = tile.copy() + self.clear_maps() + + last_structure_id = 1 + + for y in range(tile.shape[1]): + for x in range(tile.shape[0]): + # default struct_id to 0 so that it is not counted as + # structure in next iterations + id_west = id_north = struct_id = 0 + if tile[x, y] >= threshold: + # If in bounds look to neighbours + if x > 0: + id_west = tile[x - 1, y] + if y > 0: + id_north = tile[x, y - 1] + + id_west = self.sanitise_id(id_west) + id_north = self.sanitise_id(id_north) + + if id_west != 0: + if id_north != 0 and id_north != id_west: + struct_id = self.merge_structures( + id_west, id_north + ) + else: + struct_id = id_west + elif id_north != 0: + struct_id = id_north + else: # no neighbours, create new structure + struct_id = last_structure_id + self.struct_sizes[last_structure_id] = 0 + last_structure_id += 1 + + self.struct_sizes[struct_id] += 1 + + 
tile[x, y] = struct_id + + def get_sizes(self): + for iterator_pair in self.struct_sizes: + self.struct_sizes.push_back(iterator_pair.second) + return self.struct_sizes + + def clear_maps(self): + self.obsolete_ids.clear() + self.struct_sizes.clear() + + def sanitise_id(self, s_id): + while self.obsolete_ids.count( + s_id + ): # walk up the chain of obsolescence + s_id = self.obsolete_ids[s_id] + return s_id + + # id1 and id2 must be valid struct IDs (>0)! + def merge_structures(self, id1, id2): + # ensure id1 is the smaller of the two values + if id2 < id1: + tmp_id = id1 # swap + id1 = id2 + id2 = tmp_id + + self.struct_sizes[id1] += self.struct_sizes[id2] + self.struct_sizes.erase(id2) + + self.obsolete_ids[id2] = id1 + + return id1 diff --git a/src/cellfinder_core/detect/filters/plane/base_tile_filter.pyx b/src/cellfinder_core/detect/filters/plane/base_tile_filter.pyx deleted file mode 100644 index 3ec0024f..00000000 --- a/src/cellfinder_core/detect/filters/plane/base_tile_filter.pyx +++ /dev/null @@ -1,151 +0,0 @@ -# cython: language_level=3 -# distutils: language = c++ - -from libcpp.map cimport map as CppMap -from libcpp.pair cimport pair as CppPair -from libcpp.vector cimport vector as CppVector - -from cellfinder_core.detect.filters.typedefs cimport uint, ull, ulong, ushort - - -cdef get_biggest_structure(CppVector[ulong] sizes): - cdef uint result = 0 - cdef uint val - for val in sizes: - if val > result: - result = val - return result - - -cdef class BaseTileFilter: - - cdef: - readonly uint out_of_brain_intensity_threshold - public ushort[:,:] _tile - public ushort current_threshold - public bint keep - SizeAnalyser size_analyser - - def __init__(self, out_of_brain_intensity_threshold=100): - """ - - :param int out_of_brain_intensity_threshold: Set to 0 to disable - """ - self.out_of_brain_intensity_threshold = out_of_brain_intensity_threshold - self.current_threshold = -1 - self.keep = True - self.size_analyser = SizeAnalyser() - - cpdef set_tile(self, 
tile): - raise NotImplementedError - - cpdef get_tile(self): - raise NotImplementedError - - cpdef get_structures(self): - cdef CppVector[ulong] struct_sizes - self.size_analyser.process(self._tile, self.current_threshold) - struct_sizes = self.size_analyser.get_sizes() - return get_biggest_structure(struct_sizes), struct_sizes.size() - - cpdef is_low_average(self): # TODO: move to OutOfBrainTileFilter - cdef bint is_low - cdef double avg = 0 - cdef uint x, y - for x in range(self._tile.shape[0]): - for y in range(self._tile.shape[1]): - avg += self._tile[x, y] - avg /= self._tile.shape[0] * self._tile.shape[1] - is_low = avg < self.out_of_brain_intensity_threshold - self.keep = not is_low - return is_low - - -cdef class OutOfBrainTileFilter(BaseTileFilter): - - cpdef set_tile(self, tile): - self._tile = tile - - cpdef get_tile(self): - return self._tile - - -cdef class SizeAnalyser: - - cdef: - CppMap[ull, ull] obsolete_ids - CppMap[ull, ulong] struct_sizes - - cpdef process(self, ushort[:,:] tile, ushort threshold): - tile = tile.copy() - self.clear_maps() - - cdef ull struct_id, id_west, id_north, last_structure_id - last_structure_id = 1 - - cdef uint y, x - for y in range(tile.shape[1]): - for x in range(tile.shape[0]): - # default struct_id to 0 so that it is not counted as structure in next iterations - id_west = id_north = struct_id = 0 - if tile[x, y] >= threshold: - # If in bounds look to neighbours - if x > 0: - id_west = tile[x-1, y] - if y > 0: - id_north = tile[x, y-1] - - id_west = self.sanitise_id(id_west) - id_north = self.sanitise_id(id_north) - - if id_west != 0: - if id_north != 0 and id_north != id_west: - struct_id = self.merge_structures(id_west, id_north) - else: - struct_id = id_west - elif id_north != 0: - struct_id = id_north - else: # no neighbours, create new structure - struct_id = last_structure_id - self.struct_sizes[last_structure_id] = 0 - last_structure_id += 1 - - self.struct_sizes[struct_id] += 1 - - tile[x, y] = struct_id - - 
cpdef get_sizes(self): - cdef CppVector[ulong] struct_sizes - cdef ulong size - cdef ull s_id - cdef CppPair[ull, ulong] iterator_pair - for iterator_pair in self.struct_sizes: - struct_sizes.push_back(iterator_pair.second) - return struct_sizes - - cdef clear_maps(self): - self.obsolete_ids.clear() - self.struct_sizes.clear() - - cdef sanitise_id(self, ull s_id): - while self.obsolete_ids.count(s_id): # walk up the chain of obsolescence - s_id = self.obsolete_ids[s_id] - return s_id - - # id1 and id2 must be valid struct IDs (>0)! - cdef merge_structures(self, ull id1, ull id2): - cdef ulong new_size - cdef ull tmp_id - - # ensure id1 is the smaller of the two values - if id2 < id1: - tmp_id = id1 # swap - id1 = id2 - id2 = tmp_id - - self.struct_sizes[id1] += self.struct_sizes[id2] - self.struct_sizes.erase(id2) - - self.obsolete_ids[id2] = id1 - - return id1 diff --git a/src/cellfinder_core/detect/filters/plane/tile_walker.py b/src/cellfinder_core/detect/filters/plane/tile_walker.py index 2f895db7..0a5df05b 100644 --- a/src/cellfinder_core/detect/filters/plane/tile_walker.py +++ b/src/cellfinder_core/detect/filters/plane/tile_walker.py @@ -4,6 +4,7 @@ from cellfinder_core.detect.filters.plane.base_tile_filter import ( OutOfBrainTileFilter, + is_low_average, ) @@ -59,7 +60,10 @@ def walk_out_of_brain_only(self): self.y = y self.ftf.set_tile(tile) if self.ftf.out_of_brain_intensity_threshold: - if not self.ftf.is_low_average(): + self.ftf.keep = not is_low_average( + tile, self.ftf.out_of_brain_intensity_threshold + ) + if self.ftf.keep: mask_x = self.x // self.tile_width mask_y = self.y // self.tile_height self.good_tiles_mask[mask_x, mask_y] = True diff --git a/src/cellfinder_core/detect/filters/volume/ball_filter.py b/src/cellfinder_core/detect/filters/volume/ball_filter.py new file mode 100644 index 00000000..2e1cba9b --- /dev/null +++ b/src/cellfinder_core/detect/filters/volume/ball_filter.py @@ -0,0 +1,242 @@ +import numpy as np +from numba import jit + 
+from cellfinder_core.tools.array_operations import bin_mean_3d +from cellfinder_core.tools.geometry import make_sphere + +DEBUG = False + + +class BallFilter: + def __init__( + self, + layer_width, + layer_height, + ball_xy_size, + ball_z_size, + overlap_fraction=0.8, + tile_step_width=None, + tile_step_height=None, + threshold_value=None, + soma_centre_value=None, + ): + self.ball_xy_size = ball_xy_size + self.ball_z_size = ball_z_size + self.overlap_fraction = overlap_fraction + self.tile_step_width = tile_step_width + self.tile_step_height = tile_step_height + + self.THRESHOLD_VALUE = threshold_value + self.SOMA_CENTRE_VALUE = soma_centre_value + + # temporary kernel of scaling_factor*ball_x_y size to be then scaled + # to final ball size + x_upscale_factor = ( + y_upscale_factor + ) = z_upscale_factor = 7 # WARNING: needs to be integer + temp_kernel_shape = [ + x_upscale_factor * ball_xy_size, + y_upscale_factor * ball_xy_size, + z_upscale_factor * ball_z_size, + ] + tmp_ball_centre_position = [ + np.floor(d / 2) for d in temp_kernel_shape + ] # z_centre is xy_centre before resize + tmp_ball_radius = temp_kernel_shape[0] / 2.0 + tmp_kernel = make_sphere( + temp_kernel_shape, tmp_ball_radius, tmp_ball_centre_position + ) + tmp_kernel = tmp_kernel.astype(np.float64) + self.kernel = bin_mean_3d( + tmp_kernel, x_upscale_factor, y_upscale_factor, z_upscale_factor + ) + + assert ( + self.kernel.shape[2] == ball_z_size + ), "Kernel z dimension should be {}, got {}".format( + ball_z_size, self.kernel.shape[2] + ) + + self.overlap_threshold = ( + self.overlap_fraction + * np.array(self.kernel, dtype=np.float64).sum() + ) + + # Stores the current planes that are being filtered + self.volume = np.empty( + (layer_width, layer_height, ball_z_size), dtype=np.uint16 + ) + # Index of the middle plane in the volume + self.middle_z_idx = int(np.floor(ball_z_size / 2)) + + self.good_tiles_mask = np.empty( + ( + int( + np.ceil(layer_width / tile_step_width) + ), # TODO: lazy 
initialisation + int(np.ceil(layer_height / tile_step_height)), + ball_z_size, + ), + dtype=np.uint8, + ) + # Stores the z-index in volume at which new layers are inserted when + # append() is called + self.__current_z = -1 + + @property + def ready(self): + """ + Return `True` if enough layers have been appended to run the filter. + """ + return self.__current_z == self.ball_z_size - 1 + + def append(self, layer, mask): + """ + Add a new 2D layer to the filter. + """ + if DEBUG: + assert [e for e in layer.shape[:2]] == [ + e for e in self.volume.shape[:2] + ], 'layer shape mismatch, expected "{}", got "{}"'.format( + [e for e in self.volume.shape[:2]], + [e for e in layer.shape[:2]], + ) + assert [e for e in mask.shape[:2]] == [ + e for e in self.good_tiles_mask.shape[2] + ], 'mask shape mismatch, expected"{}", got {}"'.format( + [e for e in self.good_tiles_mask.shape[:2]], + [e for e in mask.shape[:2]], + ) + if not self.ready: + self.__current_z += 1 + else: + # Shift everything down by one to make way for the new layer + self.volume = np.roll( + self.volume, -1, axis=2 + ) # WARNING: not in place + self.good_tiles_mask = np.roll(self.good_tiles_mask, -1, axis=2) + # Add the new layer to the top of volume and good_tiles_mask + self.volume[:, :, self.__current_z] = layer[:, :] + self.good_tiles_mask[:, :, self.__current_z] = mask[:, :] + + def get_middle_plane(self): + """ + Get the plane in the middle of self.volume. 
+ """ + z = self.middle_z_idx + return np.array(self.volume[:, :, z], dtype=np.uint16) + + def walk(self): # Highly optimised because most time critical + ball_radius = self.ball_xy_size // 2 + tile_mask_covered_img_width = ( + self.good_tiles_mask.shape[0] * self.tile_step_width + ) + tile_mask_covered_img_height = ( + self.good_tiles_mask.shape[1] * self.tile_step_height + ) + # whole ball size because the cube is extracted with x + whole ball + # width + max_width = tile_mask_covered_img_width - self.ball_xy_size + # whole ball size because the cube is extracted with y + whole ball + # height + max_height = tile_mask_covered_img_height - self.ball_xy_size + _walk( + max_height, + max_width, + self.tile_step_width, + self.tile_step_height, + self.good_tiles_mask, + self.volume, + self.kernel, + self.ball_z_size, + ball_radius, + self.middle_z_idx, + self.overlap_threshold, + self.THRESHOLD_VALUE, + self.SOMA_CENTRE_VALUE, + ) + + +@jit(nopython=True, cache=True) +def _cube_overlaps( + cube, ball_z_size, overlap_threshold, THRESHOLD_VALUE, kernel +): # Highly optimised because most time critical + """ + + :param np.ndarray cube: The thresholded array to check for ball fit. 
+ values at CellDetector.THRESHOLD_VALUE are threshold + :return: True if the overlap exceeds self.overlap_fraction + """ + current_overlap_value = 0 + + middle = np.floor(ball_z_size / 2) + 1 + overlap_thresh = overlap_threshold * 0.4 # FIXME: do not hard code value + + for z in range(cube.shape[2]): + # TODO: OPTIMISE: step from middle to outer boundaries to check + # more data first + if z == middle and current_overlap_value < overlap_thresh: + return False # DEBUG: optimisation attempt + for y in range(cube.shape[1]): + for x in range(cube.shape[0]): + # includes self.SOMA_CENTRE_VALUE + if cube[x, y, z] >= THRESHOLD_VALUE: + current_overlap_value += kernel[x, y, z] + return current_overlap_value > overlap_threshold + + +@jit(nopython=True) +def _is_tile_to_check( + x, y, middle_z, tile_step_width, tile_step_height, good_tiles_mask +): # Highly optimised because most time critical + x_in_mask = x // tile_step_width # TEST: test bounds (-1 range) + y_in_mask = y // tile_step_height # TEST: test bounds (-1 range) + return good_tiles_mask[x_in_mask, y_in_mask, middle_z] + + +@jit(nopython=True) +def _walk( + max_height, + max_width, + tile_step_width, + tile_step_height, + good_tiles_mask, + volume, + kernel, + ball_z_size, + ball_radius, + middle_z, + overlap_threshold, + THRESHOLD_VALUE, + SOMA_CENTRE_VALUE, +): + """ + Warning: modifies volume in place! 
+ """ + for y in range(max_height): + for x in range(max_width): + ball_centre_x = x + ball_radius + ball_centre_y = y + ball_radius + if _is_tile_to_check( + ball_centre_x, + ball_centre_y, + middle_z, + tile_step_width, + tile_step_height, + good_tiles_mask, + ): + cube = volume[ + x : x + kernel.shape[0], + y : y + kernel.shape[1], + :, + ] + if _cube_overlaps( + cube, + ball_z_size, + overlap_threshold, + THRESHOLD_VALUE, + kernel, + ): + volume[ + ball_centre_x, ball_centre_y, middle_z + ] = SOMA_CENTRE_VALUE diff --git a/src/cellfinder_core/detect/filters/volume/ball_filter.pyx b/src/cellfinder_core/detect/filters/volume/ball_filter.pyx deleted file mode 100644 index c9d77dba..00000000 --- a/src/cellfinder_core/detect/filters/volume/ball_filter.pyx +++ /dev/null @@ -1,167 +0,0 @@ -# cython: language_level=3 - -cimport cython -cimport libc.math as cmath - -import numpy as np - -from cellfinder_core.detect.filters.typedefs cimport uint, ushort - -# only for __init__ - -from cellfinder_core.tools.array_operations import bin_mean_3d -from cellfinder_core.tools.geometry import make_sphere - -DEBUG = False - - - -cdef class BallFilter: - - cdef: - uint THRESHOLD_VALUE, SOMA_CENTRE_VALUE - uint ball_xy_size, ball_z_size, tile_step_width, tile_step_height - uint middle_z_idx - int __current_z - double overlap_fraction, overlap_threshold - - # Numpy arrays - double[:,:,:] kernel - ushort[:,:,:] volume - unsigned char[:,:,:] good_tiles_mask - - - def __init__(self, layer_width, layer_height, - ball_xy_size, ball_z_size, overlap_fraction=0.8, - tile_step_width=None, tile_step_height=None, - threshold_value=None, soma_centre_value=None): - self.ball_xy_size = ball_xy_size - self.ball_z_size = ball_z_size - self.overlap_fraction = overlap_fraction - self.tile_step_width = tile_step_width - self.tile_step_height = tile_step_height - - self.THRESHOLD_VALUE = threshold_value - self.SOMA_CENTRE_VALUE = soma_centre_value - - # temporary kernel of scaling_factor*ball_x_y size 
to be then scaled to final ball size - scaling_factor = 2 - x_upscale_factor = y_upscale_factor = z_upscale_factor = 7 # WARNING: needs to be integer - temp_kernel_shape = [x_upscale_factor * ball_xy_size, y_upscale_factor * ball_xy_size, z_upscale_factor * ball_z_size] - tmp_ball_centre_position = [cmath.floor(d / 2) for d in temp_kernel_shape] # z_centre is xy_centre before resize - tmp_ball_radius = temp_kernel_shape[0] / 2.0 - tmp_kernel = make_sphere(temp_kernel_shape, tmp_ball_radius, tmp_ball_centre_position) - tmp_kernel = tmp_kernel.astype(np.float64) - self.kernel = bin_mean_3d(tmp_kernel, x_upscale_factor, y_upscale_factor, z_upscale_factor) - - assert self.kernel.shape[2] == ball_z_size, 'Kernel z dimension should be {}, got {}'\ - .format(ball_z_size, self.kernel.shape[2]) - - self.overlap_threshold = self.overlap_fraction * np.array(self.kernel, dtype=np.float64).sum() - - # Stores the current planes that are being filtered - self.volume = np.empty((layer_width, layer_height, ball_z_size), dtype=np.uint16) - # Index of the middle plane in the volume - self.middle_z_idx = cmath.floor(ball_z_size / 2) - - self.good_tiles_mask = np.empty((int(cmath.ceil(layer_width / tile_step_width)), # TODO: lazy initialisation - int(cmath.ceil(layer_height / tile_step_height)), - ball_z_size), dtype=np.uint8) - # Stores the z-index in volume at which new layers are inserted when - # append() is called - self.__current_z = -1 - - @property - def ready(self): - """ - Return `True` if enough layers have been appended to run the filter. - """ - return self.__current_z == self.ball_z_size - 1 - - cpdef append(self, ushort[:,:] layer, unsigned char[:,:] mask): - """ - Add a new 2D layer to the filter. 
- """ - if DEBUG: - assert [e for e in layer.shape[:2]] == [e for e in self.volume.shape[:2]],\ - 'layer shape mismatch, expected "{}", got "{}"'\ - .format([e for e in self.volume.shape[:2]], [e for e in layer.shape[:2]]) - assert [e for e in mask.shape[:2]] == [e for e in self.good_tiles_mask.shape[2]], \ - 'mask shape mismatch, expected"{}", got {}"'\ - .format([e for e in self.good_tiles_mask.shape[:2]], [e for e in mask.shape[:2]]) - if not self.ready: - self.__current_z += 1 - else: - # Shift everything down by one to make way for the new layer - self.volume = np.roll(self.volume, -1, axis=2) # WARNING: not in place - self.good_tiles_mask = np.roll(self.good_tiles_mask, -1, axis=2) - # Add the new layer to the top of volume and good_tiles_mask - self.volume[:, :, self.__current_z] = layer[:,:] - self.good_tiles_mask[:, :, self.__current_z] = mask[:,:] - - def get_middle_plane(self): - """ - Get the plane in the middle of self.volume. - """ - cdef uint z = self.middle_z_idx - return np.array(self.volume[:, :, z], dtype=np.uint16) - - @cython.initializedcheck(False) - @cython.cdivision(True) - @cython.boundscheck(False) - cpdef walk(self): # Highly optimised because most time critical - cdef uint ball_centre_x, ball_centre_y - cdef uint ball_radius = self.ball_xy_size // 2 - cdef ushort[:,:,:] cube - cdef uint middle_z = self.middle_z_idx - - cdef uint max_width, max_height - tile_mask_covered_img_width = self.good_tiles_mask.shape[0] * self.tile_step_width - tile_mask_covered_img_height = self.good_tiles_mask.shape[1] * self.tile_step_height - max_width = tile_mask_covered_img_width - self.ball_xy_size # whole ball size because the cube is extracted with x + whole ball width - max_height = tile_mask_covered_img_height - self.ball_xy_size # whole ball size because the cube is extracted with y + whole ball height - cdef uint x, y - for y in range(max_height): - for x in range(max_width): - ball_centre_x = x + ball_radius - ball_centre_y = y + ball_radius - if 
self.__is_tile_to_check(ball_centre_x, ball_centre_y): - cube = self.volume[x:x + self.kernel.shape[0], y:y + self.kernel.shape[1], :] - if self.__cube_overlaps(cube): - self.volume[ball_centre_x, ball_centre_y, middle_z] = self.SOMA_CENTRE_VALUE - - @cython.initializedcheck(False) - @cython.cdivision(True) - @cython.boundscheck(False) - cdef __cube_overlaps(self, ushort[:,:,:] cube): # Highly optimised because most time critical - """ - - :param np.ndarray cube: The thresholded array to check for ball fit. values at CellDetector.THRESHOLD_VALUE are threshold - :return: True if the overlap exceeds self.overlap_fraction - """ - if DEBUG: - assert cube.max() <= 1 - assert cube.shape == self.kernel.shape - - cdef double current_overlap_value = 0 - - cdef uint x, y, z - for z in range(cube.shape[2]): # TODO: OPTIMISE: step from middle to outer boundaries to check more data first - if z == cmath.floor(self.ball_z_size / 2) + 1 and current_overlap_value < self.overlap_threshold * 0.4: # FIXME: do not hard code value - return False # DEBUG: optimisation attempt - for y in range(cube.shape[1]): - for x in range(cube.shape[0]): - if cube[x, y, z] >= self.THRESHOLD_VALUE: # includes self.SOMA_CENTRE_VALUE - current_overlap_value += self.kernel[x, y, z] - return (current_overlap_value > self.overlap_threshold) - - @cython.initializedcheck(False) - @cython.cdivision(True) - @cython.boundscheck(False) - cdef __is_tile_to_check(self, uint x, uint y): # Highly optimised because most time critical - cdef uint x_in_mask, y_in_mask, middle_plane_idx - cdef uint middle_z = self.middle_z_idx - - x_in_mask = x // self.tile_step_width # TEST: test bounds (-1 range) - y_in_mask = y // self.tile_step_height # TEST: test bounds (-1 range) - return self.good_tiles_mask[x_in_mask, y_in_mask, middle_z] diff --git a/src/cellfinder_core/detect/filters/volume/structure_detection.py b/src/cellfinder_core/detect/filters/volume/structure_detection.py new file mode 100644 index 00000000..7de5cded 
--- /dev/null +++ b/src/cellfinder_core/detect/filters/volume/structure_detection.py @@ -0,0 +1,347 @@ +from dataclasses import dataclass +from typing import List + +import numba +import numpy as np +from numba import jit +from numba.core import types +from numba.experimental import jitclass +from numba.typed import Dict +from numba.types import DictType + + +@dataclass +class Point: + x: int + y: int + z: int + + +UINT64_MAX = np.iinfo(np.uint64).max +N_NEIGHBOURS_4_CONNECTED = 3 # below, left and behind +N_NEIGHBOURS_8_CONNECTED = 13 # all the 9 below + the 4 before on same plane + + +@jit(nopython=True) +def get_non_zero_ull_min(values): + min_val = UINT64_MAX + for v in values: + if v != 0 and v < min_val: + min_val = v + return min_val + + +@jit(nopython=True) +def traverse_dict(d: dict, a): + """ + Traverse d, until a is not present as a key. + """ + if a in d: + return traverse_dict(d, d[a]) + else: + return a + + +def get_structure_centre(structure): + mean_x = 0 + mean_y = 0 + mean_z = 0 + s_len = len(structure) + + for p in structure: + mean_x += p.x / s_len + mean_y += p.y / s_len + mean_z += p.z / s_len + + return Point(round(mean_x), round(mean_y), round(mean_z)) + + +def get_structure_centre_wrapper(structure): # wrapper for testing purposes + s = [] + for p in structure: + if type(p) == dict: + s.append(Point(p["x"], p["y"], p["z"])) + elif isinstance(p, Point): + s.append(Point(p.x, p.y, p.z)) + else: + s.append(Point(p[0], p[1], p[2])) + return get_structure_centre(s) + + +# Type declaration has to come outside of the class, +# see https://github.com/numba/numba/issues/8808 +uint_2d_type = types.uint64[:, :] + + +spec = [ + ("connect_type", types.uint8), + ("SOMA_CENTRE_VALUE", types.uint64), + ("z", types.uint64), + ("relative_z", types.uint64), + ("next_structure_id", types.uint64), + ("shape", types.UniTuple(types.int64, 2)), + ("obsolete_ids", DictType(types.int64, types.int64)), + ("coords_maps", DictType(types.uint64, uint_2d_type)), + 
("previous_layer", types.uint64[:, :]),
+]
+
+
+@jitclass(spec=spec)
+class CellDetector:
+    """
+    Label connected structures plane by plane.
+
+    Every call to ``process`` handles one plane: pixels equal to
+    ``SOMA_CENTRE_VALUE`` are replaced by the ID of the structure they
+    belong to, every other pixel is set to zero. The coordinates of the
+    points of each structure are accumulated in ``coords_maps``.
+    """
+
+    def __init__(self, width: int, height: int, start_z: int, connect_type=4):
+        """
+        :param width: Size of each plane along its first axis.
+        :param height: Size of each plane along its second axis.
+        :param start_z: z index given to the first processed plane.
+        :param connect_type: In-plane connectivity rule, 4 or 8.
+        :raises ValueError: If *connect_type* is neither 4 nor 8.
+        """
+        self.shape = width, height
+        self.z = start_z
+
+        if connect_type not in (4, 8):
+            raise ValueError("Connection type must be one of [4, 8]")
+        self.connect_type = connect_type
+
+        self.SOMA_CENTRE_VALUE = UINT64_MAX
+
+        # position to append in stack
+        # FIXME: replace by keeping start_z and self.z > self.start_Z
+        self.relative_z = 0
+        self.next_structure_id = 1
+
+        # Mapping from obsolete IDs to the IDs that they have been
+        # made obsolete by
+        self.obsolete_ids = Dict.empty(
+            key_type=types.int64, value_type=types.int64
+        )
+        # Mapping from IDs to list of points in that structure
+        self.coords_maps = Dict.empty(
+            key_type=types.int64, value_type=uint_2d_type
+        )
+
+    def get_previous_layer(self):
+        """Return the most recently labelled plane as a uint64 array."""
+        return np.array(self.previous_layer, dtype=np.uint64)
+
+    def process(self, layer):
+        """
+        Label the structures of *layer* and move on to the next plane.
+
+        The plane is converted with ``astype``, which returns a new array,
+        so the labelling is done on that uint64 copy and not on the array
+        passed in. The labelled copy is kept as the previous layer for the
+        next call.
+
+        :param layer: Plane to process. Its first two dimensions must
+            match the shape given at construction.
+        :raises ValueError: If *layer* does not have the expected shape.
+        """
+        if [e for e in layer.shape[:2]] != [e for e in self.shape]:
+            raise ValueError("layer does not have correct shape")
+
+        source_dtype = layer.dtype
+        # Have to cast layer to a concrete data type in order to save it
+        # in the .previous_layer class attribute
+        layer = layer.astype(np.uint64)
+
+        # The 'magic numbers' below are chosen so that the maximum number
+        # representable in each data type is converted to 2**64 - 1, the
+        # maximum representable number in uint64.
+        nbits = np.iinfo(source_dtype).bits
+        if nbits == 8:
+            # 2**56 + 2**48 + 2**40 + 2**32 + 2**24 + 2**16 + 2**8 + 1
+            layer *= numba.uint64(72340172838076673)
+        elif nbits == 16:
+            # 2**48 + 2**32 + 2**16 + 1
+            layer *= numba.uint64(281479271743489)
+        elif nbits == 32:
+            # 2**32 + 1
+            layer *= numba.uint64(4294967297)
+
+        if self.connect_type == 4:
+            layer = self.connect_four(layer)
+            self.previous_layer = layer
+        else:
+            self.previous_layer = self.connect_eight(layer)
+
+        if self.relative_z == 0:
+            self.relative_z += 1
+
+        self.z += 1
+
+    def connect_four(self, layer):
+        """
+        For all the pixels in the current layer, finds all structures touching
+        this pixel using the
+        four connected (plus shape) rule and also looks at the pixel at the
+        same location in the previous layer.
+        If structures are found, they are added to the structure manager and
+        the pixel labeled accordingly.
+
+        :param layer: uint64 plane to label; it is modified in place.
+        :return: The same array, holding structure IDs (0 where there is
+            no structure).
+        """
+        for y in range(layer.shape[1]):
+            for x in range(layer.shape[0]):
+                if layer[x, y] == self.SOMA_CENTRE_VALUE:
+                    # Labels of structures at left, top, below
+                    neighbour_ids = np.zeros(
+                        N_NEIGHBOURS_4_CONNECTED, dtype=np.uint64
+                    )
+                    # If in bounds look at neighbours
+                    if x > 0:
+                        neighbour_ids[0] = layer[x - 1, y]
+                    if y > 0:
+                        neighbour_ids[1] = layer[x, y - 1]
+                    if self.relative_z > 0:
+                        neighbour_ids[2] = self.previous_layer[x, y]
+
+                    if is_new_structure(neighbour_ids):
+                        neighbour_ids[0] = self.next_structure_id
+                        self.next_structure_id += 1
+                    struct_id = self.add(x, y, self.z, neighbour_ids)
+                else:
+                    # reset so that grayscale value does not count as
+                    # structure in next iterations
+                    struct_id = 0
+
+                layer[x, y] = struct_id
+
+        return layer
+
+    def connect_eight(self, layer):
+        """
+        For all the pixels in the current layer, finds all structures touching
+        this pixel using the
+        eight connected (connected by edges or corners) rule and also looks at
+        the pixel at the same
+        location in the previous layer.
+        If structures are found, they are added to the structure manager and
+        the pixel labeled accordingly.
+
+        :param layer: uint64 plane to label; it is modified in place.
+        :return: The same array, holding structure IDs (0 where there is
+            no structure).
+        """
+        # Same container type as in connect_four; allocated once and
+        # cleared for every pixel below
+        neighbour_ids = np.zeros(N_NEIGHBOURS_8_CONNECTED, dtype=np.uint64)
+
+        for y in range(layer.shape[1]):
+            for x in range(layer.shape[0]):
+                if layer[x, y] == self.SOMA_CENTRE_VALUE:
+                    # Start from a clean slate: IDs collected around an
+                    # earlier pixel must not leak into this one
+                    neighbour_ids[:] = 0
+
+                    # If in bounds look at neighbours
+                    if x > 0 and y > 0:
+                        neighbour_ids[0] = layer[x - 1, y - 1]
+                    if x > 0:
+                        neighbour_ids[1] = layer[x - 1, y]
+                    if y > 0:
+                        neighbour_ids[2] = layer[x, y - 1]
+                        # x + 1 must stay inside the plane
+                        if x < layer.shape[0] - 1:
+                            neighbour_ids[3] = layer[x + 1, y - 1]
+                    if self.relative_z > 0:
+                        if x > 0 and y > 0:
+                            neighbour_ids[4] = self.previous_layer[
+                                x - 1, y - 1
+                            ]
+                        if x > 0:
+                            neighbour_ids[5] = self.previous_layer[x - 1, y]
+                            if y < layer.shape[1] - 1:
+                                neighbour_ids[6] = self.previous_layer[
+                                    x - 1, y + 1
+                                ]
+                        if y > 0:
+                            neighbour_ids[7] = self.previous_layer[x, y - 1]
+                            if x < layer.shape[0] - 1:
+                                neighbour_ids[8] = self.previous_layer[
+                                    x + 1, y - 1
+                                ]
+                        neighbour_ids[9] = self.previous_layer[x, y]
+                        if y < layer.shape[1] - 1:
+                            neighbour_ids[10] = self.previous_layer[x, y + 1]
+                        if x < layer.shape[0] - 1:
+                            neighbour_ids[11] = self.previous_layer[x + 1, y]
+                            if y < layer.shape[1] - 1:
+                                neighbour_ids[12] = self.previous_layer[
+                                    x + 1, y + 1
+                                ]
+
+                    if is_new_structure(neighbour_ids):
+                        neighbour_ids[0] = self.next_structure_id
+                        self.next_structure_id += 1
+                    struct_id = self.add(x, y, self.z, neighbour_ids)
+                else:
+                    # reset so that grayscale value does not count as
+                    # structure in next iterations
+                    struct_id = 0
+
+                layer[x, y] = struct_id
+        return layer
+
+    def get_cell_centres(self):
+        """Return the centre of every structure found so far."""
+        return self.structures_to_cells()
+
+    def get_coords_dict(self):
+        """Return the mapping from structure ID to its array of points."""
+        return self.coords_maps
+
+    def add_point(self, sid: int, point: np.ndarray) -> None:
+        """
+        Add *point* to the structure with the given *sid*.
+
+        *point* is stacked below the rows already stored for *sid*, so it
+        may hold a single row or several rows at once.
+        """
+        # np.vstack is the non-deprecated spelling of np.row_stack
+        self.coords_maps[sid] = np.vstack((self.coords_maps[sid], point))
+
+    def add(self, x: int, y: int, z: int, neighbour_ids: List[int]) -> int:
+        """
+        For the current coordinates takes all the neighbours and find the
+        minimum structure including obsolete structures mapping to any of
+        the neighbours recursively.
+
+        Once the correct structure id is found, append a point with the
+        current coordinates to the coordinates map entry for the correct
+        structure. Hence each entry of the map will be a vector of all the
+        pertaining points.
+
+        :param neighbour_ids: IDs found around the point; updated in place
+            by ``sanitise_ids``.
+        :return: ID of the structure the point was added to.
+        """
+        updated_id = self.sanitise_ids(neighbour_ids)
+        if updated_id not in self.coords_maps:
+            self.coords_maps[updated_id] = np.zeros(
+                shape=(0, 3), dtype=np.uint64
+            )
+        self.merge_structures(updated_id, neighbour_ids)
+
+        # Add point for that structure
+        point = np.array([[x, y, z]], dtype=np.uint64)
+        self.add_point(updated_id, point)
+        return updated_id
+
+    def sanitise_ids(self, neighbour_ids: List[int]) -> int:
+        """
+        Get the smallest ID of all the structures that are connected to IDs
+        in `neighbour_ids`.
+
+        For all the neighbour ids, walk up the chain of obsolescence (self.
+        obsolete_ids) to reassign the corresponding most obsolete structure
+        to the current neighbour.
+
+        Has no side effects on this class, but the entries of
+        `neighbour_ids` are overwritten in place with the resolved IDs.
+        """
+        for i, neighbour_id in enumerate(neighbour_ids):
+            # walk up the chain of obsolescence
+            neighbour_id = int(traverse_dict(self.obsolete_ids, neighbour_id))
+            neighbour_ids[i] = neighbour_id
+
+        # Get minimum of all non-obsolete IDs
+        updated_id = get_non_zero_ull_min(neighbour_ids)
+        return int(updated_id)
+
+    def merge_structures(
+        self, updated_id: int, neighbour_ids: List[int]
+    ) -> None:
+        """
+        For all the neighbours, reassign all the points of neighbour to
+        updated_id. Then deletes the now obsolete entry from the points
+        map and add that entry to the obsolete_ids.
+
+        Updates:
+            - self.coords_maps
+            - self.obsolete_ids
+        """
+        for neighbour_id in neighbour_ids:
+            # minimise ID so if neighbour with higher ID, reassign its points
+            # to current. The same ID can appear several times among the
+            # neighbours; once merged it is no longer in the map and must
+            # not be looked up again.
+            if neighbour_id > updated_id and neighbour_id in self.coords_maps:
+                self.add_point(updated_id, self.coords_maps[neighbour_id])
+                self.coords_maps.pop(neighbour_id)
+                self.obsolete_ids[neighbour_id] = updated_id
+
+    def structures_to_cells(self):
+        """Return a list with the centre of each structure in coords_maps."""
+        cell_centres = []
+        # Iterating the mapping itself would yield the IDs, not the points
+        for structure in self.coords_maps.values():
+            p = get_structure_centre(structure)
+            cell_centres.append(p)
+        return cell_centres
+
+
+@jit
+def is_new_structure(neighbour_ids):
+    """Return True if every entry of *neighbour_ids* is zero."""
+    for neighbour_id in neighbour_ids:
+        if neighbour_id != 0:
+            return False
+    return True
diff --git a/src/cellfinder_core/detect/filters/volume/structure_detection.pyi b/src/cellfinder_core/detect/filters/volume/structure_detection.pyi
deleted file mode 100644
index 070f1470..00000000
--- a/src/cellfinder_core/detect/filters/volume/structure_detection.pyi
+++ /dev/null
@@ -1,7 +0,0 @@
-class CellDetector:
-    def __init__(self, width, height, start_z, connect_type=4): ...
-    def process(self, layer) -> None: ...
-    def get_coords_list(self): ...
-
-def get_structure_centre_wrapper(structure): ...
-def get_non_zero_ull_min_wrapper(get_non_zero_ull_min) -> int: ...
diff --git a/src/cellfinder_core/detect/filters/volume/structure_detection.pyx b/src/cellfinder_core/detect/filters/volume/structure_detection.pyx deleted file mode 100644 index 462169af..00000000 --- a/src/cellfinder_core/detect/filters/volume/structure_detection.pyx +++ /dev/null @@ -1,339 +0,0 @@ -# distutils: language = c++ -# cython: language_level=3 - -import cython - -cimport libc.math as cmath -from libcpp.map cimport map as CppMap -from libcpp.pair cimport pair as CppPair -from libcpp.vector cimport vector as CppVector - -import numpy as np - -from cellfinder_core.detect.filters.typedefs cimport Point, uint, ull - -from imlib.cells.cells import Cell - -DEF ULLONG_MAX = 18446744073709551615 # (2**64) -1 -DEF N_NEIGHBOURS_4_CONNECTED = 3 # top left, below -DEF N_NEIGHBOURS_8_CONNECTED = 13 # all the 9 below + the 4 before on same plane - - - -cdef get_non_zero_ull_min(ull[:] values): - cdef ull min_val = ULLONG_MAX - cdef ull s_id - cdef uint i - for i in range(len(values)): - s_id = values[i] - if s_id != 0: - if s_id < min_val: - min_val = s_id - return min_val - -cpdef get_non_zero_ull_min_wrapper(values): # wrapper for testing purposes - assert len(values) == 10 - cdef ull c_values[10] - for i, v in enumerate(values): - c_values[i] = v - return get_non_zero_ull_min(c_values) - -cdef get_structure_centre(CppVector[Point] structure): - cdef double mean_z, mean_y, mean_x - mean_x = 0; mean_y = 0; mean_z = 0 - cdef double s_len = len(structure) - - cdef Point p - for p in structure: - mean_x += p.x / s_len - mean_y += p.y / s_len - mean_z += p.z / s_len - - return Point( cmath.round(mean_x), cmath.round(mean_y), cmath.round(mean_z)) - - -cpdef get_structure_centre_wrapper(structure): # wrapper for testing purposes - cdef CppVector[Point] s - for p in structure: - if type(p) == dict: - s.push_back(Point(p['x'], p['y'], p['z'])) - else: - s.push_back(Point(p.x, p.y, p.z)) - return get_structure_centre(s) - - - -cdef class CellDetector: - cdef: - public ull 
SOMA_CENTRE_VALUE # = range - 1 (e.g. 2**16 - 1) - - ull z - int relative_z - int connect_type - - ull[:,:] previous_layer - tuple shape - - StructureManager structure_manager - - ull next_structure_id - - def __init__(self, uint width, uint height, uint start_z, connect_type=4): - self.shape = width, height - self.z = start_z - - assert connect_type in (4, 8), 'Connection type must be one of 4,8 got "{}"'.format(connect_type) - self.connect_type = connect_type - - self.SOMA_CENTRE_VALUE = ULLONG_MAX - - self.relative_z = 0 # position to append in stack # FIXME: replace by keeping start_z and self.z > self.start_Z - self.next_structure_id = 1 - - self.structure_manager = StructureManager() - - cpdef get_previous_layer(self): - return np.array(self.previous_layer, dtype=np.uint64) - - cpdef process(self, layer): # WARNING: inplace # WARNING: ull may be overkill but ulong required - assert [e for e in layer.shape[:2]] == [e for e in self.shape], \ - 'CellDetector layer error, expected shape "{}", got "{}"'\ - .format(self.shape, [e for e in layer.shape[:2]]) - - source_dtype = layer.dtype - layer = layer.astype(np.uint64) - - cdef ull[:,:] c_layer - - # The 'magic numbers' below are chosen so that the maximum number - # representable in each data type is converted to 2**64 - 1, the - # maximum representable number in uint64. 
- if source_dtype == np.uint8: - # 2**56 + 2**48 + 2**40 + 2**32 + 2**24 + 2**16 + 2**8 + 1 - layer *= 72340172838076673 - elif source_dtype == np.uint16: - # 2**48 + 2**32 + 2**16 + 1 - layer *= 281479271743489 - elif source_dtype == np.uint32: - # 2**32 + 1 - layer *= 4294967297 - elif source_dtype == np.uint64: - pass - else: - raise ValueError('Expected layer of any type from np.uint8, np.uint16, np.uint32, np.uint64,' - 'got: {}'.format(source_dtype)) - c_layer = layer - - if self.connect_type == 4: - self.previous_layer = self.connect_four(c_layer) - else: - self.previous_layer = self.connect_eight(c_layer) - - if self.relative_z == 0: - self.relative_z += 1 - - self.z += 1 - - @cython.boundscheck(False) - cdef connect_four(self, ull[:,:] layer): - """ - For all the pixels in the current layer, finds all structures touching this pixel using the - four connected (plus shape) rule and also looks at the pixel at the same location in the previous layer. - If structures are found, they are added to the structure manager and the pixel labeled accordingly. 
- - :param layer: - :return: - """ - cdef ull struct_id - cdef ull neighbour_ids[N_NEIGHBOURS_4_CONNECTED] - cdef uint i - for i in range(N_NEIGHBOURS_4_CONNECTED): # reset - neighbour_ids[i] = 0 # Labels of structures at left, top, below - - cdef uint y, x - for y in range(layer.shape[1]): - for x in range(layer.shape[0]): - if layer[x, y] == self.SOMA_CENTRE_VALUE: - for i in range(N_NEIGHBOURS_4_CONNECTED): # reset - neighbour_ids[i] = 0 # Labels of structures at left, top, below - # If in bounds look at neighbours - if x > 0: - neighbour_ids[0] = layer[x-1, y] - if y > 0: - neighbour_ids[1] = layer[x, y-1] - if self.relative_z > 0: - neighbour_ids[2] = self.previous_layer[x, y] - - if self.is_new_structure(neighbour_ids): - neighbour_ids[0] = self.next_structure_id - self.next_structure_id += 1 - struct_id = self.structure_manager.add(x, y, self.z, neighbour_ids) - else: - struct_id = 0 # reset so that grayscale value does not count as structure in next iterations - - layer[x, y] = struct_id - return layer - - cdef connect_eight(self, ull[:,:] layer): - """ - For all the pixels in the current layer, finds all structures touching this pixel using the - eight connected (connected by edges or corners) rule and also looks at the pixel at the same - location in the previous layer. - If structures are found, they are added to the structure manager and the pixel labeled accordingly. 
- - :param layer: - :return: - """ - cdef ull struct_id - cdef ull neighbour_ids[N_NEIGHBOURS_8_CONNECTED] - cdef uint i - for i in range(N_NEIGHBOURS_8_CONNECTED): # reset - neighbour_ids[i] = 0 # Labels of neighbour structures touching before - - cdef uint y, x - for y in range(layer.shape[1]): - for x in range(layer.shape[0]): - if layer[x, y] == self.SOMA_CENTRE_VALUE: - for i in range(N_NEIGHBOURS_8_CONNECTED): # reset - neighbour_ids[i] = 0 - - # If in bounds look at neighbours - if x > 0 and y > 0: - neighbour_ids[0] = layer[x-1, y-1] - if x > 0: - neighbour_ids[1] = layer[x-1, y] - if y > 0: - neighbour_ids[2] = layer[x, y-1] - neighbour_ids[3] = layer[x+1, y-1] - if self.relative_z > 0: - if x > 0 and y > 0: - neighbour_ids[4] = self.previous_layer[x-1, y-1] - if x > 0: - neighbour_ids[5] = self.previous_layer[x-1, y] - if y < layer.shape[1] - 1: - neighbour_ids[6] = self.previous_layer[x-1, y+1] - if y > 0: - neighbour_ids[7] = self.previous_layer[x, y-1] - if x < layer.shape[0] - 1: - neighbour_ids[8] = self.previous_layer[x+1, y-1] - neighbour_ids[9] = self.previous_layer[x, y] - if y < layer.shape[1] - 1: - neighbour_ids[10] = self.previous_layer[x, y+1] - if x < layer.shape[0] - 1: - neighbour_ids[11] = self.previous_layer[x+1, y] - if y < layer.shape[1] - 1: - neighbour_ids[12] = self.previous_layer[x+1, y+1] - - if self.is_new_structure(neighbour_ids): - neighbour_ids[0] = self.next_structure_id - self.next_structure_id += 1 - struct_id = self.structure_manager.add(x, y, self.z, neighbour_ids) - else: - struct_id = 0 # reset so that grayscale value does not count as structure in next iterations - - layer[x, y] = struct_id - return layer - - @cython.boundscheck(False) - cdef is_new_structure(self, ull[:] neighbour_ids): # TEST: - cdef uint i - for i in range(len(neighbour_ids)): - if neighbour_ids[i] != 0: - return False - return True - - cpdef get_cell_centres(self): - cdef CppVector[Point] cell_centres - cell_centres = 
self.structure_manager.structures_to_cells() - return cell_centres - - cpdef get_coords_list(self): - coords = self.structure_manager.get_coords_dict() # TODO: cache (attribute) - return coords - -cdef class StructureManager: - cdef: - CppMap[ull, ull] obsolete_ids - CppMap[ull, CppVector[Point]] coords_maps - int default_cell_type - - def __init__(self): - self.default_cell_type = Cell.UNKNOWN - - cpdef get_coords_dict(self): - return self.coords_maps - - @cython.boundscheck(False) - cdef add(self, uint x, uint y, uint z, ull[:] neighbour_ids): - """ - For the current coordinates takes all the neighbours and find the minimum structure - including obsolete structures mapping to any of the neighbours recursively. - Once the correct structure id is found, append a point with the current coordinates to the coordinates map - entry for the correct structure. Hence each entry of the map will be a vector of all the pertaining points. - - :param x: - :param y: - :param z: - :param neighbour_ids: - :return: - """ - cdef ull updated_id - - updated_id = self.sanitise_ids(neighbour_ids) - self.merge_structures(updated_id, neighbour_ids) - - cdef Point p = Point(x, y, z) # Necessary to split definition on some machines - self.coords_maps[updated_id].push_back(p) # Add point for that structure - - return updated_id - - @cython.boundscheck(False) - cdef sanitise_ids(self, ull[:] neighbour_ids): - """ - For all the neighbour ids, walk up the chain of obsolescence (self.obsolete_ids) - to reassign the corresponding most obsolete structure to the current neighbour - - :param neighbour_ids: - :return: updated_id - """ - cdef ull updated_id, neighbour_id - cdef uint i - for i in range(len(neighbour_ids)): - neighbour_id = neighbour_ids[i] - while self.obsolete_ids.count(neighbour_id): # walk up the chain of obsolescence - neighbour_id = self.obsolete_ids[neighbour_id] - neighbour_ids[i] = neighbour_id - - updated_id = get_non_zero_ull_min(neighbour_ids) # FIXME: what happens if all 
neighbour_ids are 0 (raise) - return updated_id - - @cython.boundscheck(False) - cdef merge_structures(self, ull updated_id, ull[:] neighbour_ids): - """ - For all the neighbours, reassign all the points of neighbour to updated_id - Then deletes the now obsolete entry from the points map and add that entry to the obsolete_ids - - :param updated_id: - :param neighbour_ids: - """ - cdef ull neighbour_id - cdef Point p - cdef uint i - for i in range(len(neighbour_ids)): - neighbour_id = neighbour_ids[i] - if neighbour_id > updated_id: # minimise ID so if neighbour with higher ID, reassign its points to current - for p in self.coords_maps[neighbour_id]: - self.coords_maps[updated_id].push_back(p) - self.coords_maps.erase(neighbour_id) - self.obsolete_ids[neighbour_id] = updated_id - - cdef structures_to_cells(self): - cdef CppVector[Point] structure, cell_centres - cdef Point p - - cdef CppPair[ull, CppVector[Point]] iterator_pair - for iterator_pair in self.coords_maps: - structure = iterator_pair.second - p = get_structure_centre(structure) - cell_centres.push_back(p) - return cell_centres diff --git a/src/cellfinder_core/detect/filters/volume/volume_filter.py b/src/cellfinder_core/detect/filters/volume/volume_filter.py index 159f3a4d..0703da79 100644 --- a/src/cellfinder_core/detect/filters/volume/volume_filter.py +++ b/src/cellfinder_core/detect/filters/volume/volume_filter.py @@ -91,20 +91,7 @@ def process( self.ball_filter.append(plane, mask) if self.ball_filter.ready: - logger.debug(f"Ball filtering plane {self.z}") - self.ball_filter.walk() - - middle_plane = self.ball_filter.get_middle_plane() - if self.save_planes: - self.save_plane(middle_plane) - - logger.debug(f"Detecting structures for plane {self.z}") - self.cell_detector.process(middle_plane) - - logger.debug(f"Structures done for plane {self.z}") - logger.debug( - f"Skipping plane {self.z} for 3D filter" " (out of bounds)" - ) + self._run_filter() callback(self.z) self.z += 1 @@ -114,6 +101,22 @@ def 
process( logger.debug("3D filter done") return self.get_results() + def _run_filter(self): + logger.debug(f"Ball filtering plane {self.z}") + self.ball_filter.walk() + + middle_plane = self.ball_filter.get_middle_plane() + if self.save_planes: + self.save_plane(middle_plane) + + logger.debug(f"Detecting structures for plane {self.z}") + self.cell_detector.process(middle_plane) + + logger.debug(f"Structures done for plane {self.z}") + logger.debug( + f"Skipping plane {self.z} for 3D filter" " (out of bounds)" + ) + def save_plane(self, plane): plane_name = f"plane_{str(self.z).zfill(4)}.tif" f_path = os.path.join(self.plane_directory, plane_name) @@ -127,17 +130,14 @@ def get_results(self) -> List[Cell]: ) cells = [] - for ( - cell_id, - cell_points, - ) in self.cell_detector.get_coords_list().items(): + for cell_id, cell_points in self.cell_detector.coords_maps.items(): cell_volume = len(cell_points) if cell_volume < max_cell_volume: cell_centre = get_structure_centre_wrapper(cell_points) cells.append( Cell( - (cell_centre["x"], cell_centre["y"], cell_centre["z"]), + (cell_centre.x, cell_centre.y, cell_centre.z), Cell.UNKNOWN, ) ) @@ -155,9 +155,9 @@ def get_results(self) -> List[Cell]: cells.append( Cell( ( - cell_centre["x"], - cell_centre["y"], - cell_centre["z"], + cell_centre.x, + cell_centre.y, + cell_centre.z, ), Cell.UNKNOWN, ) @@ -167,9 +167,9 @@ def get_results(self) -> List[Cell]: cells.append( Cell( ( - cell_centre["x"], - cell_centre["y"], - cell_centre["z"], + cell_centre.x, + cell_centre.y, + cell_centre.z, ), Cell.ARTIFACT, ) diff --git a/tests/tests/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py b/tests/tests/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py index 25d43f03..60e5c47e 100644 --- a/tests/tests/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py +++ 
b/tests/tests/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py @@ -3,24 +3,25 @@ from cellfinder_core.detect.filters.volume.structure_detection import ( CellDetector, - get_non_zero_ull_min_wrapper, + Point, + get_non_zero_ull_min, get_structure_centre_wrapper, ) -def test_get_non_zero_ull_min(): - assert get_non_zero_ull_min_wrapper(list(range(10))) == 1 - assert get_non_zero_ull_min_wrapper([0] * 10) == (2**64) - 1 - +def coords_to_points(coords_arrays): + # Convert from arrays to dicts + coords = {} + for sid in coords_arrays: + coords[sid] = [] + for row in coords_arrays[sid]: + coords[sid].append(Point(row[0], row[1], row[2])) + return coords -class Point: - def __init__(self, x, y, z): - self.x = x - self.y = y - self.z = z - def __str__(self): - return "x: {}, y: {}, z: {}".format(self.x, self.y, self.z) +def test_get_non_zero_ull_min(): + assert get_non_zero_ull_min(np.arange(10, dtype=np.uint64)) == 1 + assert get_non_zero_ull_min(np.zeros(10, dtype=np.uint64)) == (2**64) - 1 @pytest.fixture() @@ -43,7 +44,7 @@ def structure(three_d_cross): def test_get_structure_centre(structure): result_point = get_structure_centre_wrapper(structure) - assert (result_point["x"], result_point["y"], result_point["z"]) == ( + assert (result_point.x, result_point.y, result_point.z) == ( 1, 1, 1, @@ -67,63 +68,49 @@ def test_get_structure_centre(structure): ( # Two pixels connected in a single structure along x [(0, 0, 0), (0, 1, 0)], - {1: [{"x": 0, "y": 0, "z": 0}, {"x": 1, "y": 0, "z": 0}]}, + {1: [Point(0, 0, 0), Point(1, 0, 0)]}, ), ( # Two pixels connected in a single structure along y [(0, 0, 0), (0, 0, 1)], - {1: [{"x": 0, "y": 0, "z": 0}, {"x": 0, "y": 1, "z": 0}]}, + {1: [Point(0, 0, 0), Point(0, 1, 0)]}, ), ( # Two pixels connected in a single structure along z [(0, 0, 0), (1, 0, 0)], - {1: [{"x": 0, "y": 0, "z": 0}, {"x": 0, "y": 0, "z": 1}]}, + {1: [Point(0, 0, 0), Point(0, 0, 1)]}, ), ( # Four pixels all connected and spread 
across x-y-z [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 0, 1)], - { - 1: [ - {"x": 0, "y": 0, "z": 0}, - {"x": 0, "y": 0, "z": 1}, - {"x": 1, "y": 0, "z": 1}, - {"x": 0, "y": 1, "z": 1}, - ] - }, + {1: [Point(0, 0, 0), Point(0, 0, 1), Point(1, 0, 1), Point(0, 1, 1)]}, ), ( # three initially disconnected pixels that then get merged # by a fourth pixel [(1, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)], - { - 1: [ - {"x": 1, "y": 1, "z": 0}, - {"x": 0, "y": 1, "z": 1}, - {"x": 1, "y": 0, "z": 1}, - {"x": 1, "y": 1, "z": 1}, - ] - }, + {1: [Point(1, 1, 0), Point(0, 1, 1), Point(1, 0, 1), Point(1, 1, 1)]}, ), ( # Three pixels in x-y plane that require structure merging [(1, 0, 0), (0, 1, 0), (1, 1, 0)], { 1: [ - {"x": 1, "y": 0, "z": 0}, - {"x": 0, "y": 0, "z": 1}, - {"x": 1, "y": 0, "z": 1}, + Point(1, 0, 0), + Point(0, 0, 1), + Point(1, 0, 1), ] }, ), ( # Two disconnected single-pixel structures [(0, 0, 0), (0, 2, 0)], - {1: [{"x": 0, "y": 0, "z": 0}], 2: [{"x": 2, "y": 0, "z": 0}]}, + {1: [Point(0, 0, 0)], 2: [Point(2, 0, 0)]}, ), ( # Two disconnected single-pixel structures along a diagonal [(0, 0, 0), (1, 1, 1)], - {1: [{"x": 0, "y": 0, "z": 0}], 2: [{"x": 1, "y": 1, "z": 1}]}, + {1: [Point(0, 0, 0)], 2: [Point(1, 1, 1)]}, ), ] @@ -142,5 +129,5 @@ def test_detection(dtype, pixels, expected_coords): for plane in data: detector.process(plane) - coords = detector.get_coords_list() - assert coords == expected_coords + coords = detector.get_coords_dict() + assert coords_to_points(coords) == expected_coords