diff --git a/.codecov.yml b/.codecov.yml index 70f8cec9..0e323f2f 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -10,6 +10,8 @@ flags: numba: paths: - cellfinder/core/detect/filters/plane/tile_walker.py + - cellfinder/core/detect/filters/plane/classical_filter.py + - cellfinder/core/detect/filters/plane/plane_filter.py - cellfinder/core/detect/filters/volume/ball_filter.py - cellfinder/core/detect/filters/volume/structure_detection.py carryforward: true diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 756b7d96..d071743d 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -35,11 +35,14 @@ jobs: test: needs: [linting, manifest] name: Run package tests - timeout-minutes: 60 + timeout-minutes: 120 runs-on: ${{ matrix.os }} env: KERAS_BACKEND: torch CELLFINDER_TEST_DEVICE: cpu + # pooch cache dir + BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache" + strategy: matrix: # Run all supported Python versions on linux @@ -53,6 +56,14 @@ jobs: python-version: "3.12" steps: + - uses: actions/checkout@v4 + - name: Cache pooch data + uses: actions/cache@v4 + with: + path: "~/.pooch_cache" + # hash on conftest in case url changes + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pooch_registry.txt') }} + # Cache the tensorflow model so we don't have to remake it every time - name: Cache brainglobe directory uses: actions/cache@v3 with: @@ -75,12 +86,16 @@ jobs: test_numba_disabled: needs: [linting, manifest] name: Run tests with numba disabled - timeout-minutes: 60 + timeout-minutes: 120 runs-on: ubuntu-latest env: - NUMBA_DISABLE_JIT: "1" + NUMBA_DISABLE_JIT: "1" + PYTORCH_JIT: "0" + # pooch cache dir + BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache" steps: + - uses: actions/checkout@v4 - name: Cache brainglobe directory uses: actions/cache@v3 with: @@ -88,6 +103,13 @@ jobs: ~/.brainglobe !~/.brainglobe/atlas.tar.gz key: brainglobe + + - name: Cache pooch data + uses: actions/cache@v4 + with: + path: "~/.pooch_cache" + key: ${{ runner.os }}-3.10-${{ hashFiles('**/pooch_registry.txt') }} + # Setup pyqt libraries - name: Setup qtpy libraries uses: tlambert03/setup-qt-libs@v1 @@ -105,7 +127,7 @@ jobs: test_brainmapper_cli: needs: [linting, manifest] name: Run brainmapper tests to check for breakages - timeout-minutes: 60 + timeout-minutes: 120 runs-on: ubuntu-latest env: KERAS_BACKEND: torch diff --git a/benchmarks/benchmark_tools.py b/benchmarks/benchmark_tools.py new file mode 100644 index 00000000..16c5d160 --- /dev/null +++ b/benchmarks/benchmark_tools.py @@ -0,0 +1,86 @@ +from pathlib import Path + +import pooch +import torch +from torch.profiler import ProfilerActivity, profile +from torch.utils.benchmark import Compare, Timer + +from cellfinder.core.tools.IO import fetch_pooch_directory + + +def get_test_data_path(path): + """ + Create a test data registry for BrainGlobe. + + Returns: + pooch.Pooch: The test data registry object. + + """ + registry = pooch.create( + path=pooch.os_cache("brainglobe_test_data"), + base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/cellfinder/", + env="BRAINGLOBE_TEST_DATA_DIR", + ) + + registry.load_registry( + Path(__file__).parent.parent / "tests" / "data" / "pooch_registry.txt" + ) + + return fetch_pooch_directory(registry, path) + + +def time_filters(repeat, run, run_args, label): + timer = Timer( + stmt="run(*args)", + globals={"run": run, "args": run_args}, + label=label, + num_threads=4, + description="", # must be not None due to pytorch bug + ) + return timer.timeit(number=repeat) + + +def compare_results(*results): + # prints the results of all the timed tests + compare = Compare(results) + compare.trim_significant_figures() + compare.colorize() + compare.print() + + +def profile_cpu(repeat, run, run_args): + with profile( + activities=[ProfilerActivity.CPU], + record_shapes=True, + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + for _ in range(repeat): + run(*run_args) + + print( + prof.key_averages(group_by_stack_n=1).table( + sort_by="self_cpu_time_total", row_limit=20 + ) + ) + + +def profile_cuda(repeat, run, run_args): + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + for _ in range(repeat): + run(*run_args) + # make sure it's fully done filtering + torch.cuda.synchronize("cuda") + + print( + prof.key_averages(group_by_stack_n=1).table( + sort_by="self_cuda_time_total", row_limit=20 + ) + ) diff --git a/benchmarks/filter_2d.py b/benchmarks/filter_2d.py index 1ae24d94..067fc5d8 100644 --- a/benchmarks/filter_2d.py +++ b/benchmarks/filter_2d.py @@ -1,27 +1,131 @@ +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + import numpy as np -from pyinstrument import Profiler +import torch +from benchmark_tools import ( + compare_results, + get_test_data_path, + profile_cpu, + profile_cuda, + time_filters, +) +from brainglobe_utils.IO.image.load import read_with_dask from cellfinder.core.detect.filters.plane import TileProcessor -from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering +from cellfinder.core.detect.filters.setup_filters import DetectionSettings -# Use random 16-bit integer data for signal plane -shape = (10000, 10000) -signal_array_plane = np.random.randint( - low=0, high=65536, size=shape, dtype=np.uint16 -) +def setup_filter( + signal_path, + batch_size=1, + num_z=None, + torch_device="cpu", + dtype=np.uint16, + use_scipy=False, +): + signal_array = read_with_dask(signal_path) + num_z = num_z or len(signal_array) + signal_array = np.asarray(signal_array[:num_z]).astype(dtype) + shape = signal_array.shape + + settings = DetectionSettings( + plane_original_np_dtype=dtype, + plane_shape=shape[1:], + torch_device=torch_device, + voxel_sizes=(5.06, 4.5, 4.5), + soma_diameter_um=30, + ball_xy_size_um=6, + ball_z_size_um=15, + ) + signal_array = settings.filter_data_converter_func(signal_array) + signal_array = torch.from_numpy(signal_array).to(torch_device) + + tile_processor = TileProcessor( + plane_shape=shape[1:], + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + soma_diameter=settings.soma_diameter, + log_sigma_size=settings.log_sigma_size, + n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + torch_device=torch_device, + dtype=settings.filtering_dtype.__name__, + use_scipy=use_scipy, + ) + + return tile_processor, signal_array, batch_size + + +def run_filter(tile_processor, signal_array, batch_size): + for i in range(0, len(signal_array), batch_size): + tile_processor.get_tile_mask(signal_array[i : i + batch_size]) + -clipping_value, threshold_value = setup_tile_filtering(signal_array_plane) -tile_processor = TileProcessor( - clipping_value=clipping_value, - threshold_value=threshold_value, - soma_diameter=16, - log_sigma_size=0.2, - n_sds_above_mean_thresh=10, -) if __name__ == "__main__": - profiler = Profiler() - profiler.start() - plane, tiles = tile_processor.get_tile_mask(signal_array_plane) - profiler.stop() - profiler.print(show_all=True) + with torch.inference_mode(True): + n = 5 + batch_size = 2 + signal_path = get_test_data_path("bright_brain/signal") + + compare_results( + time_filters( + n, + run_filter, + setup_filter( + signal_path, + batch_size=batch_size, + torch_device="cpu", + use_scipy=False, + ), + "cpu-no_scipy", + ), + time_filters( + n, + run_filter, + setup_filter( + signal_path, + batch_size=batch_size, + torch_device="cpu", + use_scipy=True, + ), + "cpu-scipy", + ), + time_filters( + n, + run_filter, + setup_filter( + signal_path, batch_size=batch_size, torch_device="cuda" + ), + "cuda", + ), + ) + + profile_cpu( + n, + run_filter, + setup_filter( + signal_path, + batch_size=batch_size, + torch_device="cpu", + use_scipy=False, + ), + ) + profile_cpu( + n, + run_filter, + setup_filter( + signal_path, + batch_size=batch_size, + torch_device="cpu", + use_scipy=True, + ), + ) + profile_cuda( + n, + run_filter, + setup_filter( + signal_path, batch_size=batch_size, torch_device="cuda" + ), + ) diff --git a/benchmarks/filter_3d.py b/benchmarks/filter_3d.py index 5509dd52..0a8dd0af 100644 --- a/benchmarks/filter_3d.py +++ b/benchmarks/filter_3d.py @@ -1,51 +1,129 @@ +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + import numpy as np -from pyinstrument import Profiler +import torch +from benchmark_tools import ( + compare_results, + get_test_data_path, + profile_cpu, + profile_cuda, + time_filters, +) +from brainglobe_utils.IO.image.load import read_with_dask -from cellfinder.core.detect.filters.volume.volume_filter import VolumeFilter +from cellfinder.core.detect.filters.setup_filters import DetectionSettings +from cellfinder.core.detect.filters.volume.ball_filter import BallFilter -# Use random data for signal data -ball_z_size = 3 +def setup_filter( + plane_path, + tiles_path, + batch_size=1, + num_z=None, + torch_device="cpu", + dtype=np.uint16, +): + filtered = read_with_dask(plane_path) + tiles = read_with_dask(tiles_path) + num_z = num_z or len(filtered) + filtered = np.asarray(filtered[:num_z]) + tiles = np.asarray(tiles[:num_z]) + shape = filtered.shape -def gen_signal_array(ny, nx): - shape = (ball_z_size, ny, nx) - return np.random.randint(low=0, high=65536, size=shape, dtype=np.uint16) + settings = DetectionSettings( + plane_original_np_dtype=dtype, + plane_shape=shape[1:], + torch_device=torch_device, + soma_diameter_um=30, + ball_xy_size_um=6, + ball_z_size_um=15, + ) + filtered = filtered.astype(settings.filtering_dtype) + filtered = torch.from_numpy(filtered).to(torch_device) + tiles = tiles.astype(np.bool_) + tiles = torch.from_numpy(tiles).to(torch_device) + ball_filter = BallFilter( + plane_height=settings.plane_height, + plane_width=settings.plane_width, + ball_xy_size=settings.ball_xy_size, + ball_z_size=settings.ball_z_size, + overlap_fraction=settings.ball_overlap_fraction, + threshold_value=settings.threshold_value, + soma_centre_value=settings.soma_centre_value, + tile_height=settings.tile_height, + tile_width=settings.tile_width, + dtype=settings.filtering_dtype.__name__, + batch_size=batch_size, + torch_device=torch_device, + use_mask=True, + ) -signal_array = gen_signal_array(667, 510) + return ball_filter, filtered, tiles, batch_size -soma_diameter = 8 -setup_params = ( - signal_array[0, :, :].T, - soma_diameter, - 3, # ball_xy_size, - ball_z_size, - 0.6, # ball_overlap_fraction, - 0, # start_plane, -) -mp_3d_filter = VolumeFilter( - soma_diameter=soma_diameter, - setup_params=setup_params, - n_planes=len(signal_array), - n_locks_release=1, -) +def run_filter(ball_filter, filtered, tiles, batch_size): + for i in range(0, len(filtered), batch_size): + ball_filter.append( + filtered[i : i + batch_size], tiles[i : i + batch_size] + ) + if ball_filter.ready: + ball_filter.walk() + ball_filter.get_processed_planes() -# Use random data for mask data -mask = np.random.randint(low=0, high=2, size=(42, 32), dtype=bool) -# Fill up the 3D filter with planes -for plane in signal_array: - mp_3d_filter.ball_filter.append(plane, mask) - if mp_3d_filter.ball_filter.ready: - break +if __name__ == "__main__": + with torch.inference_mode(True): + n = 5 + batch_size = 4 + plane_path = get_test_data_path("bright_brain/2d_filter") + tiles_path = get_test_data_path("bright_brain/tiles") -# Run the 3D filter -profiler = Profiler() -profiler.start() -for i in range(10): - # Repeat same filter 10 times to increase runtime - mp_3d_filter._run_filter() + compare_results( + time_filters( + n, + run_filter, + setup_filter( + plane_path, + tiles_path, + batch_size=batch_size, + torch_device="cpu", + ), + "cpu", + ), + time_filters( + n, + run_filter, + setup_filter( + plane_path, + tiles_path, + batch_size=batch_size, + torch_device="cuda", + ), + "cuda", + ), + ) -profiler.stop() -profiler.print(show_all=True) + profile_cpu( + n, + run_filter, + setup_filter( + plane_path, + tiles_path, + batch_size=batch_size, + torch_device="cpu", + ), + ) + profile_cuda( + n, + run_filter, + setup_filter( + plane_path, + tiles_path, + batch_size=batch_size, + torch_device="cuda", + ), + ) diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py index ec77190f..8e211bba 100644 --- a/cellfinder/core/classify/classify.py +++ b/cellfinder/core/classify/classify.py @@ -30,7 +30,7 @@ def main( max_workers: int = 3, *, callback: Optional[Callable[[int], None]] = None, -) -> List: +) -> List[Cell]: """ Parameters ---------- diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index 9d70e541..7562bbcb 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -13,76 +13,47 @@ - (max_val) is used to mark bright points during 3D filtering """ -import multiprocessing +import dataclasses from datetime import datetime -from queue import Queue -from threading import Lock -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar +from typing import Callable, List, Optional, Tuple import numpy as np +import torch from brainglobe_utils.cells.cells import Cell -from brainglobe_utils.general.system import get_num_processes -from numba import set_num_threads from cellfinder.core import logger, types from cellfinder.core.detect.filters.plane import TileProcessor -from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering +from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.volume_filter import VolumeFilter +from cellfinder.core.tools.tools import inference_wrapper -def calculate_parameters_in_pixels( - voxel_sizes: Tuple[float, float, float], - soma_diameter_um: float, - max_cluster_size_um3: float, - ball_xy_size_um: float, - ball_z_size_um: float, -) -> Tuple[int, int, int, int]: - """ - Convert the command-line arguments from real (um) units to pixels - """ - - mean_in_plane_pixel_size = 0.5 * ( - float(voxel_sizes[2]) + float(voxel_sizes[1]) - ) - voxel_volume = ( - float(voxel_sizes[2]) * float(voxel_sizes[1]) * float(voxel_sizes[0]) - ) - soma_diameter = int(round(soma_diameter_um / mean_in_plane_pixel_size)) - max_cluster_size = int(round(max_cluster_size_um3 / voxel_volume)) - ball_xy_size = int(round(ball_xy_size_um / mean_in_plane_pixel_size)) - ball_z_size = int(round(ball_z_size_um / float(voxel_sizes[0]))) - - if ball_z_size == 0: - raise ValueError( - "Ball z size has been calculated to be 0 voxels." - " This may be due to large axial spacing of your data or the " - "ball_z_size_um parameter being too small. " - "Please check input parameters are correct. " - "Note that cellfinder requires high resolution data in all " - "dimensions, so that cells can be detected in multiple " - "image planes." - ) - return soma_diameter, max_cluster_size, ball_xy_size, ball_z_size - - +@inference_wrapper def main( signal_array: types.array, - start_plane: int, - end_plane: int, - voxel_sizes: Tuple[float, float, float], - soma_diameter: float, - max_cluster_size: float, - ball_xy_size: float, - ball_z_size: float, - ball_overlap_fraction: float, - soma_spread_factor: float, - n_free_cpus: int, - log_sigma_size: float, - n_sds_above_mean_thresh: float, + start_plane: int = 0, + end_plane: int = -1, + voxel_sizes: Tuple[float, float, float] = (5, 2, 2), + soma_diameter: float = 16, + max_cluster_size: float = 100_000, + ball_xy_size: float = 6, + ball_z_size: float = 15, + ball_overlap_fraction: float = 0.6, + soma_spread_factor: float = 1.4, + n_free_cpus: int = 2, + log_sigma_size: float = 0.2, + n_sds_above_mean_thresh: float = 10, outlier_keep: bool = False, artifact_keep: bool = False, save_planes: bool = False, plane_directory: Optional[str] = None, + batch_size: Optional[int] = None, + torch_device: str = "cpu", + use_scipy: bool = True, + split_ball_xy_size: int = 3, + split_ball_z_size: int = 3, + split_ball_overlap_fraction: float = 0.8, + split_soma_diameter: int = 7, *, callback: Optional[Callable[[int], None]] = None, ) -> List[Cell]: @@ -101,7 +72,7 @@ def main( Index of the ending plane for detection. voxel_sizes : Tuple[float, float, float] - Tuple of voxel sizes in each dimension (x, y, z). + Tuple of voxel sizes in each dimension (z, y, x). soma_diameter : float Diameter of the soma in physical units. @@ -142,6 +113,18 @@ def main( plane_directory : str, optional Directory path to save the planes. Defaults to None. + batch_size : int, optional + The number of planes to process in each batch. Defaults to 1. + For CPU, there's no benefit for a larger batch size. Only a memory + usage increase. For CUDA, the larger the batch size the better the + performance. Until it fills up the GPU memory - after which it + becomes slower. + + torch_device : str, optional + The device on which to run the computation. By default, it's "cpu". + To run on a gpu, specify the PyTorch device name, such as "cuda" to + run on the first GPU. + callback : Callable[int], optional A callback function that is called every time a plane has finished being processed. Called with the plane number that has finished. @@ -151,151 +134,103 @@ def main( List[Cell] List of detected cells. """ - if not np.issubdtype(signal_array.dtype, np.integer): - raise ValueError( - "signal_array must be integer datatype, but has datatype " + start_time = datetime.now() + if batch_size is None: + if torch_device == "cpu": + batch_size = 4 + else: + batch_size = 1 + + if not np.issubdtype(signal_array.dtype, np.number): + raise TypeError( + "signal_array must be a numpy datatype, but has datatype " f"{signal_array.dtype}" ) - n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) - n_ball_procs = max(n_processes - 1, 1) - - # we parallelize 2d filtering, which typically lags behind the 3d - # processing so for n_ball_procs 2d filtering threads, ball_z_size will - # typically be in use while the others stall waiting for 3d processing - # so we can use those for other things, such as numba threading - set_num_threads(max(n_ball_procs - int(ball_z_size), 1)) - - start_time = datetime.now() - - ( - soma_diameter, - max_cluster_size, - ball_xy_size, - ball_z_size, - ) = calculate_parameters_in_pixels( - voxel_sizes, - soma_diameter, - max_cluster_size, - ball_xy_size, - ball_z_size, - ) - - if end_plane == -1: - end_plane = len(signal_array) - signal_array = signal_array[start_plane:end_plane] - signal_array = signal_array.astype(np.uint32) - - callback = callback or (lambda *args, **kwargs: None) if signal_array.ndim != 3: raise ValueError("Input data must be 3D") - setup_params = ( - signal_array[0, :, :], - soma_diameter, - ball_xy_size, - ball_z_size, - ball_overlap_fraction, - start_plane, - ) - - # Create 3D analysis filter - mp_3d_filter = VolumeFilter( - soma_diameter=soma_diameter, - setup_params=setup_params, - soma_size_spread_factor=soma_spread_factor, - n_planes=len(signal_array), - n_locks_release=n_ball_procs, - save_planes=save_planes, - plane_directory=plane_directory, + if end_plane < 0: + end_plane = len(signal_array) + end_plane = min(len(signal_array), end_plane) + + torch_device = torch_device.lower() + batch_size = max(batch_size, 1) + # brainmapper can pass them in as str + voxel_sizes = list(map(float, voxel_sizes)) + + settings = DetectionSettings( + plane_shape=signal_array.shape[1:], + plane_original_np_dtype=signal_array.dtype, + voxel_sizes=voxel_sizes, + soma_spread_factor=soma_spread_factor, + soma_diameter_um=soma_diameter, + max_cluster_size_um3=max_cluster_size, + ball_xy_size_um=ball_xy_size, + ball_z_size_um=ball_z_size, start_plane=start_plane, - max_cluster_size=max_cluster_size, + end_plane=end_plane, + n_free_cpus=n_free_cpus, + ball_overlap_fraction=ball_overlap_fraction, + log_sigma_size=log_sigma_size, + n_sds_above_mean_thresh=n_sds_above_mean_thresh, outlier_keep=outlier_keep, artifact_keep=artifact_keep, + save_planes=save_planes, + plane_directory=plane_directory, + batch_size=batch_size, + torch_device=torch_device, + ) + + # replicate the settings specific to splitting, before we access anything + # of the original settings, causing cached properties + kwargs = dataclasses.asdict(settings) + kwargs["ball_z_size_um"] = split_ball_z_size * settings.z_pixel_size + kwargs["ball_xy_size_um"] = ( + split_ball_xy_size * settings.in_plane_pixel_size ) + kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction + kwargs["soma_diameter_um"] = ( + split_soma_diameter * settings.in_plane_pixel_size + ) + # always run on cpu because copying to gpu overhead is likely slower than + # any benefit for detection on smallish volumes + kwargs["torch_device"] = "cpu" + # for splitting, we only do 3d filtering. Its input is a zero volume + # with cell voxels marked with threshold_value. So just use float32 + # for input because the filters will also use float(32). So there will + # not be need to convert the input a different dtype before passing to + # the filters. + kwargs["plane_original_np_dtype"] = np.float32 + splitting_settings = DetectionSettings(**kwargs) + + # Create 3D analysis filter + mp_3d_filter = VolumeFilter(settings=settings) - clipping_val, threshold_value = setup_tile_filtering(signal_array[0, :, :]) # Create 2D analysis filter mp_tile_processor = TileProcessor( - clipping_val, - threshold_value, - soma_diameter, - log_sigma_size, - n_sds_above_mean_thresh, + plane_shape=settings.plane_shape, + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + n_sds_above_mean_thresh=n_sds_above_mean_thresh, + log_sigma_size=log_sigma_size, + soma_diameter=settings.soma_diameter, + torch_device=torch_device, + dtype=settings.filtering_dtype.__name__, + use_scipy=use_scipy, ) - # Force spawn context - mp_ctx = multiprocessing.get_context("spawn") - with mp_ctx.Pool(n_ball_procs) as worker_pool: - async_results, locks = _map_with_locks( - mp_tile_processor.get_tile_mask, - signal_array, # type: ignore - worker_pool, - ) - - # Release the first set of locks for the 2D filtering - for i in range(min(n_ball_procs + ball_z_size, len(locks))): - logger.debug(f"🔓 Releasing lock for plane {i}") - locks[i].release() + orig_n_threads = torch.get_num_threads() + torch.set_num_threads(settings.n_torch_comp_threads) - # Start 3D filter - # - # This runs in the main thread, and blocks until the all the 2D and - # then 3D filtering has finished. As batches of planes are filtered - # by the 3D filter, it releases the locks of subsequent 2D filter - # processes. - mp_3d_filter.process(async_results, locks, callback=callback) + # process the data + mp_3d_filter.process(mp_tile_processor, signal_array, callback=callback) + cells = mp_3d_filter.get_results(splitting_settings) - # it's now done filtering, get results with pool - cells = mp_3d_filter.get_results(worker_pool) + torch.set_num_threads(orig_n_threads) time_elapsed = datetime.now() - start_time - logger.debug( - f"All Planes done. Found {len(cells)} cells in {format(time_elapsed)}" - ) - print("Detection complete - all planes done in : {}".format(time_elapsed)) + s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}" + logger.debug(s) + print(s) return cells - - -Tin = TypeVar("Tin") -Tout = TypeVar("Tout") - - -def _run_func_with_lock( - func: Callable[[Tin], Tout], arg: Tin, lock: Lock -) -> Tout: - """ - Run a function after acquiring a lock. - """ - lock.acquire(blocking=True) - return func(arg) - - -def _map_with_locks( - func: Callable[[Tin], Tout], - iterable: Sequence[Tin], - worker_pool: multiprocessing.pool.Pool, -) -> Tuple[Queue, List[Lock]]: - """ - Map a function to arguments, blocking execution. - - Maps *func* to args in *iterable*, but blocks all execution and - return a queue of asyncronous results and locks for each of the - results. Execution can be enabled by releasing the returned - locks in order. - """ - # Setup a manager to handle the locks - m = multiprocessing.Manager() - # Setup one lock per argument to be mapped - locks = [m.Lock() for _ in range(len(iterable))] - [lock.acquire(blocking=False) for lock in locks] - - async_results: Queue = Queue() - - for arg, lock in zip(iterable, locks): - async_result = worker_pool.apply_async( - _run_func_with_lock, args=(func, arg, lock) - ) - async_results.put(async_result) - - return async_results, locks diff --git a/cellfinder/core/detect/filters/plane/classical_filter.py b/cellfinder/core/detect/filters/plane/classical_filter.py index af331d52..6504d312 100644 --- a/cellfinder/core/detect/filters/plane/classical_filter.py +++ b/cellfinder/core/detect/filters/plane/classical_filter.py @@ -1,45 +1,347 @@ import numpy as np +import torch +import torch.nn.functional as F from scipy.ndimage import gaussian_filter, laplace from scipy.signal import medfilt2d -def enhance_peaks( - img: np.ndarray, clipping_value: float, gaussian_sigma: float = 2.5 -) -> np.ndarray: +@torch.jit.script +def normalize( + filtered_planes: torch.Tensor, + flip: bool, + max_value: float = 1.0, +) -> None: """ - Enhances the peaks (bright pixels) in an input image. + Normalizes the 3d tensor so each z-plane is independently scaled to be + in the [0, max_value] range. If `flip` is `True`, the sign of the tensor + values are flipped before any processing. - Parameters: + It is done to filtered_planes inplace. + """ + num_z = filtered_planes.shape[0] + filtered_planes_1d = filtered_planes.view(num_z, -1) + + if flip: + filtered_planes_1d.mul_(-1) + + planes_min = torch.min(filtered_planes_1d, dim=1, keepdim=True)[0] + filtered_planes_1d.sub_(planes_min) + # take max after subtraction + planes_max = torch.max(filtered_planes_1d, dim=1, keepdim=True)[0] + # if min = max = zero, divide by 1 - it'll stay zero + planes_max[planes_max == 0] = 1 + filtered_planes_1d.div_(planes_max) + + if max_value != 1.0: + # To leave room to label in the 3d detection. + filtered_planes_1d.mul_(max_value) + + +@torch.jit.script +def filter_for_peaks( + planes: torch.Tensor, + med_kernel: torch.Tensor, + gauss_kernel: torch.Tensor, + gauss_kernel_size: int, + lap_kernel: torch.Tensor, + device: str, + clipping_value: float, +) -> torch.Tensor: + """ + Takes the 3d z-stack and returns a new z-stack where the peaks are + highlighted. + + It applies a median filter -> gaussian filter -> laplacian filter. + """ + filtered_planes = planes.unsqueeze(1) # ZYX -> ZCYX input, C=channels + + # ------------------ median filter ------------------ + # extracts patches to compute median over for each pixel + # We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around + # each Z,X,Y voxel over which we compute the median + # Zero padding is ok here + filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same") + # we're going back to ZCYX=Z1YX by taking median of patches in C dim + filtered_planes = filtered_planes.median(dim=1, keepdim=True)[0] + + # ------------------ gaussian filter ------------------ + # normalize the input data to 0-1 range. Otherwise, if the values are + # large, we'd need a float64 so conv result is accurate + normalize(filtered_planes, flip=False) + + # we need to do reflection padding around the tensor for parity with scipy + # gaussian filtering. Scipy does reflection in a manner typically called + # symmetric: (dcba|abcd|dcba). Torch does it like this: (dcb|abcd|cba). So + # we manually do symmetric padding below + pad = gauss_kernel_size // 2 + padding_mode = "reflect" + # if data is too small for reflect, just use constant border value + if pad >= filtered_planes.shape[-1] or pad >= filtered_planes.shape[-2]: + padding_mode = "replicate" + filtered_planes = F.pad(filtered_planes, (pad,) * 4, padding_mode, 0.0) + # We reflected torch style, so copy/shift everything by one to be symmetric + filtered_planes[:, :, :pad, :] = filtered_planes[ + :, :, 1 : pad + 1, : + ].clone() + filtered_planes[:, :, -pad:, :] = filtered_planes[ + :, :, -pad - 1 : -1, : + ].clone() + filtered_planes[:, :, :, :pad] = filtered_planes[ + :, :, :, 1 : pad + 1 + ].clone() + filtered_planes[:, :, :, -pad:] = filtered_planes[ + :, :, :, -pad - 1 : -1 + ].clone() + + # We apply the 1D gaussian filter twice, once for Y and once for X. The + # filter shape passed in is 11K1 or 111K, depending on device. Where + # K=filter size + # see https://discuss.pytorch.org/t/performance-issue-for-conv2d-with-1d- + # filter-along-a-dim/201734/2 for the reason for the moveaxis depending + # on the device + if device == "cpu": + # kernel shape is 11K1. First do Y (second to last axis) + filtered_planes = F.conv2d( + filtered_planes, gauss_kernel, padding="valid" + ) + # To do X, exchange X,Y axis, filter, change back. On CPU, Y (second + # to last) axis is faster. + filtered_planes = F.conv2d( + filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid" + ).moveaxis(-1, -2) + else: + # kernel shape is 111K + # First do Y (second to last axis). Exchange X,Y axis, filter, change + # back. On CUDA, X (last) axis is faster. + filtered_planes = F.conv2d( + filtered_planes.moveaxis(-1, -2), gauss_kernel, padding="valid" + ).moveaxis(-1, -2) + # now do X, last axis + filtered_planes = F.conv2d( + filtered_planes, gauss_kernel, padding="valid" + ) + + # ------------------ laplacian filter ------------------ + # it's a 2d filter. Need to pad using symmetric for scipy parity. But, + # torch doesn't have it, and we used a kernel of size 3, so for padding of + # 1, replicate == symmetric. That's enough for parity with past scipy. If + # we change kernel size in the future, we may have to do as above + padding = lap_kernel.shape[-1] // 2 + filtered_planes = F.pad(filtered_planes, (padding,) * 4, "replicate") + filtered_planes = F.conv2d(filtered_planes, lap_kernel, padding="valid") + + # we don't need the channel axis + filtered_planes = filtered_planes[:, 0, :, :] + + # scale back to full scale, filtered values are negative so flip + normalize(filtered_planes, flip=True, max_value=clipping_value) + return filtered_planes + + +class PeakEnhancer: + """ + A class that filters each plane in a z-stack such that peaks are + visualized. + + It uses a series of 2D filters of median -> gaussian -> + laplacian. Then normalizes each plane to be between [0, clipping_value]. + + Parameters ---------- - img : np.ndarray - Input image. - clipping_value : float - Maximum value for the enhanced image. - gaussian_sigma : float, optional - Standard deviation for the Gaussian filter. Default is 2.5. - - Returns: - ------- - np.ndarray - Enhanced image with peaks. - - Notes: - ------ - The enhancement process includes the following steps: - 1. Applying a 2D median filter. - 2. Applying a Laplacian of Gaussian filter (LoG). - 3. Multiplying by -1 (bright spots respond negative in a LoG). - 4. Rescaling image values to range from 0 to clipping value. - """ - type_in = img.dtype - filtered_img = medfilt2d(img.astype(np.float64)) - filtered_img = gaussian_filter(filtered_img, gaussian_sigma) - filtered_img = laplace(filtered_img) - filtered_img *= -1 - - filtered_img -= filtered_img.min() - filtered_img /= filtered_img.max() - - # To leave room to label in the 3d detection. - filtered_img *= clipping_value - return filtered_img.astype(type_in) + torch_device: str + The device on which the data and processing occurs on. Can be e.g. + "cpu", "cuda" etc. Any data passed to the filter must be on this + device. Returned data will also be on this device. + dtype : torch.dtype + The data-type of the input planes and the type to use internally. + E.g. `torch.float32`. + clipping_value : int + The value such that after normalizing, the max value will be this + clipping_value. + laplace_gaussian_sigma : float + Size of the sigma for the gaussian filter. + use_scipy : bool + If running on the CPU whether to use the scipy filters or the same + pytorch filters used on CUDA. Scipy filters can be faster. + """ + + # binary kernel that generates square patches for each pixel so we can find + # the median around the pixel + med_kernel: torch.Tensor + + # gaussian 1D kernel with kernel/weight shape 11K1 or 111K, depending + # on device. Where K=filter size + gauss_kernel: torch.Tensor + + # 2D laplacian kernel with kernel/weight shape KxK. Where + # K=filter size + lap_kernel: torch.Tensor + + # the value such that after normalizing, the max value will be this + # clipping_value + clipping_value: float + + # sigma value for gaussian filter + laplace_gaussian_sigma: float + + # the torch device to run on. E.g. cpu/cuda. + torch_device: str + + # when running on CPU whether to use pytorch or scipy for filters + use_scipy: bool + + median_filter_size: int = 3 + """ + The median filter size in x/y direction. + + **Must** be odd. + """ + + def __init__( + self, + torch_device: str, + dtype: torch.dtype, + clipping_value: float, + laplace_gaussian_sigma: float, + use_scipy: bool, + ): + super().__init__() + self.torch_device = torch_device.lower() + self.clipping_value = clipping_value + self.laplace_gaussian_sigma = laplace_gaussian_sigma + self.use_scipy = use_scipy + + # all these kernels are odd in size + self.med_kernel = self._get_median_kernel(torch_device, dtype) + self.gauss_kernel = self._get_gaussian_kernel( + torch_device, dtype, laplace_gaussian_sigma + ) + self.lap_kernel = self._get_laplacian_kernel(torch_device, dtype) + + @property + def gaussian_filter_size(self) -> int: + """ + The gaussian filter 1d size. + + It is odd. + """ + return 2 * int(round(4 * self.laplace_gaussian_sigma)) + 1 + + def _get_median_kernel( + self, torch_device: str, dtype: torch.dtype + ) -> torch.Tensor: + """ + Gets a median patch generator kernel, already on the correct + device. + + Based on how kornia does it for median filtering. + """ + # must be odd kernel + kernel_n = self.median_filter_size + if not (kernel_n % 2): + raise ValueError("The median filter size must be odd") + + # extract patches to compute median over for each pixel. When passing + # input we go from ZCYX -> ZCYX, C=1 to C=9 and containing the elements + # around each Z,X,Y over which we can then compute the median + window_range = kernel_n * kernel_n # e.g. 3x3 + kernel = torch.zeros( + (window_range, window_range), device=torch_device, dtype=dtype + ) + idx = torch.arange(window_range, device=torch_device) + # diagonal of e.g. 9x9 is 1 + kernel[idx, idx] = 1.0 + # out channels, in channels, n*y, n*x. The kernel collects all the 3x3 + # elements around a pixel, using a binary mask for each element, as a + # separate channel. So we go from 1 to 9 channels in the output + kernel = kernel.view(window_range, 1, kernel_n, kernel_n) + + return kernel + + def _get_gaussian_kernel( + self, + torch_device: str, + dtype: torch.dtype, + laplace_gaussian_sigma: float, + ) -> torch.Tensor: + """Gets the 1D gaussian kernel used to filter the data.""" + # we do 2 1D filters, once on each y, x dim. + # shape of kernel will be 11K1 with dims Z, C, Y, X. C=1, Z is expanded + # to number of z during filtering. + kernel_size = self.gaussian_filter_size + + # to get the values of a 1D gaussian kernel, we pass a single impulse + # data through the filter, which recovers the filter values. We do this + # because scipy doesn't make their kernel available in public API and + # we want parity with scipy filtering + impulse = np.zeros(kernel_size) + # the impulse needs to be to the left of center + impulse[kernel_size // 2] = 1 + kernel = gaussian_filter( + impulse, laplace_gaussian_sigma, mode="constant", cval=0 + ) + # kernel should be fully symmetric + assert kernel[0] == kernel[-1] + gauss_kernel = torch.from_numpy(kernel).type(dtype).to(torch_device) + + # default shape is (y, x) with y axis filtered only - we transpose + # input to filter on x + gauss_kernel = gauss_kernel.view(1, 1, -1, 1) + + # see https://discuss.pytorch.org/t/performance-issue-for-conv2d- + # with-1d-filter-along-a-dim/201734. Conv2d is faster on a specific dim + # for 1D filters depending on CPU/CUDA. See also filter_for_peaks + # on CPU, we only do conv2d on the (1st) dim + if torch_device != "cpu": + # on CUDA, we only filter on the x dim, flipping input to filter y + gauss_kernel = gauss_kernel.view(1, 1, 1, -1) + + return gauss_kernel + + def _get_laplacian_kernel( + self, torch_device: str, dtype: torch.dtype + ) -> torch.Tensor: + """Gets a 2d laplacian kernel, based on scipy's laplace.""" + # for parity with scipy, scipy computes the laplacian with default + # parameters and kernel size 3 using filter coefficients [1, -2, 1]. + # Each filtered pixel is the sum of the filter around the pixel + # vertically and horizontally. We can do it in 2d at once with + # coefficients below (faster than 2x1D for such small filter) + return torch.as_tensor( + [[0, 1, 0], [1, -4, 1], [0, 1, 0]], + dtype=dtype, + device=torch_device, + ).view(1, 1, 3, 3) + + def enhance_peaks(self, planes: torch.Tensor) -> torch.Tensor: + """ + Applies the filtering and normalization to the 3d z-stack (not inplace) + and returns the filtered z-stack. + """ + if self.torch_device == "cpu" and self.use_scipy: + filtered_planes = planes.clone() + for i in range(planes.shape[0]): + img = planes[i, :, :].numpy() + img = medfilt2d(img) + img = gaussian_filter(img, self.laplace_gaussian_sigma) + img = laplace(img) + filtered_planes[i, :, :] = torch.from_numpy(img) + + # laplace makes values negative so flip + normalize( + filtered_planes, + flip=True, + max_value=self.clipping_value, + ) + return filtered_planes + + filtered_planes = filter_for_peaks( + planes, + self.med_kernel, + self.gauss_kernel, + self.gaussian_filter_size, + self.lap_kernel, + self.torch_device, + self.clipping_value, + ) + return filtered_planes diff --git a/cellfinder/core/detect/filters/plane/plane_filter.py b/cellfinder/core/detect/filters/plane/plane_filter.py index f7e7868e..82434d16 100644 --- a/cellfinder/core/detect/filters/plane/plane_filter.py +++ b/cellfinder/core/detect/filters/plane/plane_filter.py @@ -1,87 +1,169 @@ -from dataclasses import dataclass -from threading import Lock -from typing import Optional, Tuple +from dataclasses import dataclass, field +from typing import Tuple -import dask.array as da -import numpy as np +import torch -from cellfinder.core import types -from cellfinder.core.detect.filters.plane.classical_filter import enhance_peaks +from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer from cellfinder.core.detect.filters.plane.tile_walker import TileWalker @dataclass class TileProcessor: """ - Attributes + Processor that filters each plane to highlight the peaks and also + tiles and thresholds each plane returning a mask indicating which + tiles are inside the brain. + + Each input plane is: + + 1. Clipped to [0, clipping_value]. + 2. Tiled and compared to the corner tile. Any tile that is "bright" + according to `TileWalker` is marked as being in the brain. + 3. Filtered + 1. Run through the peak enhancement filter (see `PeakEnhancer`) + 2. Thresholded. Any values that are larger than + (mean + stddev * n_sds_above_mean_thresh) are set to + threshold_value. + + Parameters ---------- - clipping_value : - Upper value that the input plane is clipped to. - threshold_value : + plane_shape : tuple(int, int) + Height/width of the planes. + clipping_value : int + Upper value that the input planes are clipped to. Result is scaled so + max is this value. + threshold_value : int Value used to mark bright features in the input planes after they have been run through the 2D filter. + n_sds_above_mean_thresh : float + Number of standard deviations above the mean threshold to use for + determining whether a voxel is bright. + log_sigma_size : float + Size of the sigma for the gaussian filter. + soma_diameter : float + Diameter of the soma in voxels. + torch_device: str + The device on which the data and processing occurs on. Can be e.g. + "cpu", "cuda" etc. Any data passed to the filter must be on this + device. Returned data will also be on this device. + dtype : str + The data-type of the input planes and the type to use internally. + E.g. "float32". + use_scipy : bool + If running on the CPU whether to use the scipy filters or the same + pytorch filters used on CUDA. Scipy filters can be faster. """ + # Upper value that the input plane is clipped to. Result is scaled so + # max is this value clipping_value: int + # Value used to mark bright features in the input planes after they have + # been run through the 2D filter threshold_value: int - soma_diameter: int - log_sigma_size: float + # voxels who are this many std above mean or more are set to + # threshold_value n_sds_above_mean_thresh: float - def get_tile_mask( - self, plane: types.array, lock: Optional[Lock] = None - ) -> Tuple[np.ndarray, np.ndarray]: - """ - This thresholds the input plane, and returns a mask indicating which - tiles are inside the brain. + # filter that finds the peaks in the planes + peak_enhancer: PeakEnhancer = field(init=False) + # generates tiles of the planes, with each tile marked as being inside + # or outside the brain based on brightness + tile_walker: TileWalker = field(init=False) - The input plane is: + def __init__( + self, + plane_shape: Tuple[int, int], + clipping_value: int, + threshold_value: int, + n_sds_above_mean_thresh: float, + log_sigma_size: float, + soma_diameter: int, + torch_device: str, + dtype: str, + use_scipy: bool, + ): + self.clipping_value = clipping_value + self.threshold_value = threshold_value + self.n_sds_above_mean_thresh = n_sds_above_mean_thresh - 1. Clipped to [0, self.clipping_value] - 2. Run through a peak enhancement filter (see `classical_filter.py`) - 3. Thresholded. Any values that are larger than - (mean + stddev * self.n_sds_above_mean_thresh) are set to - self.threshold_value in-place. + laplace_gaussian_sigma = log_sigma_size * soma_diameter + self.peak_enhancer = PeakEnhancer( + torch_device=torch_device, + dtype=getattr(torch, dtype), + clipping_value=self.clipping_value, + laplace_gaussian_sigma=laplace_gaussian_sigma, + use_scipy=use_scipy, + ) + + self.tile_walker = TileWalker( + plane_shape=plane_shape, + soma_diameter=soma_diameter, + ) + + def get_tile_mask( + self, planes: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the filtering listed in the class description. Parameters ---------- - plane : - Input plane. - lock : - If given, block reading the plane into memory until the lock - can be acquired. + planes : torch.Tensor + Input planes (z-stack). Note, the input data is modified. Returns ------- - plane : - Thresholded plane. - inside_brain_tiles : + planes : torch.Tensor + Filtered and thresholded planes (z-stack). + inside_brain_tiles : torch.Tensor Boolean mask indicating which tiles are inside (1) or outside (0) the brain. + It's a z-stack whose planes are the shape of the number of tiles + in each planar axis. """ - laplace_gaussian_sigma = self.log_sigma_size * self.soma_diameter - plane = plane.T - np.clip(plane, 0, self.clipping_value, out=plane) - if lock is not None: - lock.acquire(blocking=True) - # Read plane from a dask array into memory as a numpy array - if isinstance(plane, da.Array): - plane = np.array(plane) - + torch.clip_(planes, 0, self.clipping_value) # Get tiles that are within the brain - walker = TileWalker(plane, self.soma_diameter) - walker.mark_bright_tiles() - inside_brain_tiles = walker.bright_tiles_mask - + inside_brain_tiles = self.tile_walker.get_bright_tiles(planes) # Threshold the image - thresholded_img = enhance_peaks( - plane.copy(), - self.clipping_value, - gaussian_sigma=laplace_gaussian_sigma, + enhanced_planes = self.peak_enhancer.enhance_peaks(planes) + + _threshold_planes( + planes, + enhanced_planes, + self.n_sds_above_mean_thresh, + self.threshold_value, ) - avg = np.mean(thresholded_img) - sd = np.std(thresholded_img) - threshold = avg + self.n_sds_above_mean_thresh * sd - plane[thresholded_img > threshold] = self.threshold_value - return plane, inside_brain_tiles + return planes, inside_brain_tiles + + def get_tiled_buffer(self, depth: int, device: str): + return self.tile_walker.get_tiled_buffer(depth, device) + + +@torch.jit.script +def _threshold_planes( + planes: torch.Tensor, + enhanced_planes: torch.Tensor, + n_sds_above_mean_thresh: float, + threshold_value: int, +) -> None: + """ + Sets each plane (in-place) to threshold_value, where the corresponding + enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be + set to zero elsewhere. + """ + planes_1d = enhanced_planes.view(enhanced_planes.shape[0], -1) + + # add back last dim + avg = torch.mean(planes_1d, dim=1, keepdim=True).unsqueeze(2) + sd = torch.std(planes_1d, dim=1, keepdim=True).unsqueeze(2) + threshold = avg + n_sds_above_mean_thresh * sd + + above = enhanced_planes > threshold + planes[above] = threshold_value + # subsequent steps only care about the values that are set to threshold or + # above in planes. We set values in *planes* to threshold based on the + # value in *enhanced_planes*. So, there could be values in planes that are + # at threshold already, but in enhanced_planes they are not. So it's best + # to zero all other values, so voxels previously at threshold don't count + planes[torch.logical_not(above)] = 0 diff --git a/cellfinder/core/detect/filters/plane/tile_walker.py b/cellfinder/core/detect/filters/plane/tile_walker.py index e4abe5cf..e8b994b6 100644 --- a/cellfinder/core/detect/filters/plane/tile_walker.py +++ b/cellfinder/core/detect/filters/plane/tile_walker.py @@ -1,8 +1,8 @@ import math -from typing import Generator, Tuple +from typing import Tuple -import numpy as np -from numba import njit +import torch +import torch.nn.functional as F class TileWalker: @@ -15,74 +15,140 @@ class TileWalker: The mean and standard deviation of this tile is calculated, and the threshold set at 1 + mean + (2 * stddev). - Attributes + Parameters ---------- - bright_tiles_mask : - An boolean array whose entries correspond to whether each tile is - bright (1) or dark (0). The values are set in - self.mark_bright_tiles(). + plane_shape : tuple(int, int) + Height/width of the planes. + soma_diameter : float + Diameter of the soma in voxels. """ - def __init__(self, img: np.ndarray, soma_diameter: int) -> None: - self.img = img - self.img_width, self.img_height = img.shape - self.tile_width = soma_diameter * 2 - self.tile_height = soma_diameter * 2 + def __init__( + self, plane_shape: Tuple[int, int], soma_diameter: int + ) -> None: - n_tiles_width = math.ceil(self.img_width / self.tile_width) - n_tiles_height = math.ceil(self.img_height / self.tile_height) - self.bright_tiles_mask = np.zeros( - (n_tiles_width, n_tiles_height), dtype=bool - ) + self.img_height, self.img_width = plane_shape + self.tile_height = soma_diameter * 2 + self.tile_width = soma_diameter * 2 - corner_tile = img[0 : self.tile_width, 0 : self.tile_height] - corner_intensity = np.mean(corner_tile) - corner_sd = np.std(corner_tile) - # add 1 to ensure not 0, as disables - self.out_of_brain_threshold = (corner_intensity + (2 * corner_sd)) + 1 + self.n_tiles_height = math.ceil(self.img_height / self.tile_height) + self.n_tiles_width = math.ceil(self.img_width / self.tile_width) - def _get_tiles(self) -> Generator[Tuple[int, int, np.ndarray], None, None]: + def get_bright_tiles(self, planes: torch.Tensor) -> torch.Tensor: """ - Generator that yields tiles of the 2D image. + Takes a 3d z-stack. For each z it computes the mean/std of the corner + tile and uses that to get a in/out of brain threshold for each z. - Notes - ----- - The final tile in each dimension can have a smaller size than the - rest of the tiles if the tile shape does not exactly divide the - image shape. - """ - for y in range( - 0, self.img_height - self.tile_height, self.tile_height - ): - for x in range( - 0, self.img_width - self.tile_width, self.tile_width - ): - tile = self.img[ - x : x + self.tile_width, y : y + self.tile_height - ] - yield x, y, tile - - def mark_bright_tiles(self) -> None: - """ - Loop through tiles, and if the average value of a tile is - greater than the intensity threshold mark the tile as bright - in self.bright_tiles_mask. + Parameters + ---------- + planes : torch.Tensor + 3d z-stack. + + Returns + ------- + out_of_brain_thresholds : torch.Tensor + 3d z-stack whose planar shape is the number of tiles in a plane. + The returned data will be on the same torch device as the input + planes. """ - threshold = self.out_of_brain_threshold - if threshold == 0: - return + return _get_bright_tiles( + planes, + self.n_tiles_height, + self.n_tiles_width, + self.tile_height, + self.tile_width, + ) - for x, y, tile in self._get_tiles(): - if not is_low_average(tile, threshold): - mask_x = x // self.tile_width - mask_y = y // self.tile_height - self.bright_tiles_mask[mask_x, mask_y] = True + def get_tiled_buffer(self, depth: int, device: str): + return torch.zeros( + (depth, self.n_tiles_height, self.n_tiles_width), + dtype=torch.bool, + device=device, + ) -@njit -def is_low_average(tile: np.ndarray, threshold: float) -> bool: +@torch.jit.script +def _get_out_of_brain_threshold( + planes: torch.Tensor, tile_height: int, tile_width: int +) -> torch.Tensor: """ - Return `True` if the average value of *tile* is below *threshold*. + Takes a 3d z-stack. For each z it computes the mean/std of the corner tile + and uses that to get a in/out of brain threshold for each z-stack. + + Parameters + ---------- + planes : + 3d z-stack. + tile_height : + Height of each tile. + tile_width : + Width of each tile. + + Returns + ------- + out_of_brain_thresholds : + 1d z-stack. + """ + # get corner tile + corner_tiles = planes[:, 0:tile_height, 0:tile_width] + # convert from ZYX -> ZK, where K is the elements in the corner tile + corner_tiles = corner_tiles.reshape((planes.shape[0], -1)) + + # need to operate in float64, in case the values are large + corner64 = corner_tiles.type(torch.float64) + corner_intensity = torch.mean(corner64, dim=1).type(planes.dtype) + # for parity with past when we used np.std, which defaults to ddof=0 + corner_sd = torch.std(corner64, dim=1, correction=0).type(planes.dtype) + # add 1 to ensure not 0, as disables + out_of_brain_thresholds = corner_intensity + 2 * corner_sd + 1 + + return out_of_brain_thresholds + + +@torch.jit.script +def _get_bright_tiles( + planes: torch.Tensor, + n_tiles_height: int, + n_tiles_width: int, + tile_height: int, + tile_width: int, +) -> torch.Tensor: """ - avg = np.mean(tile) - return avg < threshold + Loop through the tiles of the plane for each plane. And if the average + value of a tile is greater than the intensity threshold of that plain, + mark the tile as bright. + """ + bright_tiles_mask = torch.zeros( + (planes.shape[0], n_tiles_height, n_tiles_width), + dtype=torch.bool, + device=planes.device, + ) + # if we don't have enough size for a single tile, it's all outside + if planes.shape[1] < tile_height or planes.shape[2] < tile_width: + return bright_tiles_mask + + # for each plane, the threshold + out_of_brain_thresholds = _get_out_of_brain_threshold( + planes, tile_height, tile_width + ) + # thresholds Z -> ZYX shape + thresholds = out_of_brain_thresholds.view(-1, 1, 1) + + # ZYX -> ZCYX required for function (C=1) + planes = planes.unsqueeze(1) + # get the average of each tile + tile_avg = F.avg_pool2d( + planes, + (tile_height, tile_width), + ceil_mode=False, # default is False, but to make sure + ) + # go back from ZCYX -> ZYX + tile_avg = tile_avg[:, 0, :, :] + + bright = tile_avg >= thresholds + # tile_avg and bright may be smaller than bright_tiles_mask because + # avg_pool2d first subtracts the kernel size before computing # tiles. + # So contain view to that size + bright_tiles_mask[:, : bright.shape[1], : bright.shape[2]][bright] = True + + return bright_tiles_mask diff --git a/cellfinder/core/detect/filters/setup_filters.py b/cellfinder/core/detect/filters/setup_filters.py index d68387fc..15db72ff 100644 --- a/cellfinder/core/detect/filters/setup_filters.py +++ b/cellfinder/core/detect/filters/setup_filters.py @@ -1,70 +1,427 @@ +""" +Container for all the settings used during 2d/3d filtering and cell detection. +""" + import math -from typing import Tuple +from dataclasses import dataclass +from functools import cached_property +from typing import Callable, Optional, Tuple, Type import numpy as np +from brainglobe_utils.general.system import get_num_processes -from cellfinder.core.detect.filters.volume.ball_filter import BallFilter -from cellfinder.core.detect.filters.volume.structure_detection import ( - CellDetector, +from cellfinder.core.tools.tools import ( + get_data_converter, + get_max_possible_int_value, ) -from cellfinder.core.tools.tools import get_max_possible_value - - -def get_ball_filter( - *, - plane: np.ndarray, - soma_diameter: int, - ball_xy_size: int, - ball_z_size: int, - ball_overlap_fraction: float = 0.6, -) -> BallFilter: - # thrsh_val is used to clip the data in plane to make sure - # a number is available to mark cells. soma_centre_val is the - # number used to mark cells. - max_value = get_max_possible_value(plane) - thrsh_val = max_value - 1 - soma_centre_val = max_value - - tile_width = soma_diameter * 2 - plane_height, plane_width = plane.shape - - ball_filter = BallFilter( - plane_width, - plane_height, - ball_xy_size, - ball_z_size, - overlap_fraction=ball_overlap_fraction, - tile_step_width=tile_width, - tile_step_height=tile_width, - threshold_value=thrsh_val, - soma_centre_value=soma_centre_val, - ) - return ball_filter - - -def get_cell_detector( - *, plane_shape: Tuple[int, int], ball_z_size: int, z_offset: int = 0 -) -> CellDetector: - plane_height, plane_width = plane_shape - start_z = z_offset + int(math.floor(ball_z_size / 2)) - return CellDetector(plane_width, plane_height, start_z=start_z) - - -def setup_tile_filtering(plane: np.ndarray) -> Tuple[int, int]: - """ - Setup values that are used to threshold the plane during 2D filtering. - - Returns - ------- - clipping_value : - Upper value used to clip planes before 2D filtering. This is chosen - to leave two numbers left that can later be used to mark bright points - during the 2D and 3D filtering stages. - threshold_value : - Value used to mark bright pixels after 2D filtering. - """ - max_value = get_max_possible_value(plane) - clipping_value = max_value - 2 - thrsh_val = max_value - 1 - - return clipping_value, thrsh_val + +MAX_TORCH_COMP_THREADS = 12 +# As seen in the benchmarks in the original PR, when running on CPU using +# more than ~12 cores it starts to result in slowdowns. So limit to this +# many cores when doing computational work (e.g. torch.functional.Conv2D). +# +# This prevents thread contention. + + +@dataclass +class DetectionSettings: + """ + Configuration class with all the parameters used during 2d and 3d filtering + and structure splitting. + """ + + plane_original_np_dtype: Type[np.number] = np.uint16 + """ + The numpy data type of the input data that will be passed to the filtering + pipeline. + + Throughout filtering at key stages, the data range is kept such + that we can convert the data back to this data type without having to + scale. I.e. the min/max of the data fits in this data type. + + Except for the cell detection stage, in that stage the data range can be + larger because the values are cell IDs and not intensity data anymore. + + During structure splitting, we do just 3d filtering/cell detection. This is + again the data type used as input to the filtering. + + Defaults to `uint16` + """ + + detection_dtype: Type[np.number] = np.uint64 + """ + The numpy data type that the cell detection code expects our filtered + data to be in. + + After filtering, where the voxels are intensity values, we pass the data + to cell detection where the voxels turn into cell IDs. So the data type + needs to be large enough to support the number of cells in the data. + + To get the data from the filtering data type to the detection data type + use `detection_data_converter_func`. + + Defaults to `uint64`. + """ + + plane_shape: Tuple[int, int] = (1, 1) + """ + The shape of each plane of the input data as (height, width) - i.e. + (axis 1, axis 2) in the z-stack where z is the first axis. + """ + + start_plane: int = 0 + """The index of first plane to process, in the input data (inclusive).""" + + end_plane: int = 1 + """ + The index of the last plane at which to stop processing the input data + (not inclusive). + """ + + voxel_sizes: Tuple[float, float, float] = (1.0, 1.0, 1.0) + """ + Tuple of voxel sizes in each dimension (z, y, x). We use this to convert + from `um` to pixel sizes. + """ + + soma_spread_factor: float = 1.4 + """Spread factor for soma size - how much it may stretch in the images.""" + + soma_diameter_um: float = 16 + """ + Diameter of a typical soma in um. Bright areas larger than this will be + split. + """ + + max_cluster_size_um3: float = 100_000 + """ + Maximum size of a cluster (bright area) that will be processed, in um. + Larger bright areas are skipped as artifacts. + """ + + ball_xy_size_um: float = 6 + """ + Diameter of the 3d spherical kernel filter in the x/y dimensions in um. + See `ball_xy_size` for size in voxels. + """ + + ball_z_size_um: float = 15 + """ + Diameter of the 3d spherical kernel filter in the z dimension in um. + See `ball_z_size` for size in voxels. + + `ball_z_size` also determines to the minimum number of planes that are + stacked before can filter the central plane of the stack. + """ + + ball_overlap_fraction: float = 0.6 + """ + Fraction of overlap between a bright area and the spherical kernel, + for the area to be considered a single ball. + """ + + log_sigma_size: float = 0.2 + """Size of the sigma for the 2d Gaussian filter.""" + + n_sds_above_mean_thresh: float = 10 + """ + Number of standard deviations above the mean intensity to use for a + threshold to define bright areas. Below it, it's not considered bright. + """ + + outlier_keep: bool = False + """Whether to keep outlier structures during detection.""" + + artifact_keep: bool = False + """Whether to keep artifact structures during detection.""" + + save_planes: bool = False + """ + Whether to save the 2d/3d filtered planes during after filtering. + + It is saved as tiffs of data type `plane_original_np_dtype`. + """ + + plane_directory: Optional[str] = None + """Directory path where to save the planes, if saving.""" + + batch_size: int = 1 + """ + The number of planes to process in each batch of the 2d/3d filters. + + For CPU, each plane in a batch is 2d filtered (the slowest filters) in its + own sub-process. But 3d filtering happens in a single thread. So larger + batches will use more processes but can speed up filtering until IO/3d + filters become the bottleneck. + + For CUDA, 2d and 3d filtering happens on the GPU and the larger the batch + size, the better the performance. Until it fills up the GPU memory - after + which it becomes slower. + + In all cases, higher batch size means more RAM used. + """ + + num_prefetch_batches: int = 2 + """ + The number of batches to load into memory. + + This many batches are loaded in memory so the next batch is ready to be + sent to the filters as soon as the previous batch is done. + + The higher the number the more RAM used, but it can also speed up + processing if IO becomes a limiting factor. + """ + + torch_device: str = "cpu" + """ + The device on which to run the 2d and/or 3d filtering. + + Either `"cpu"` or PyTorch's GPU device name, such as `"cuda"` or `"cuda:0"` + to run on the first GPU. + """ + + n_free_cpus: int = 2 + """ + Number of free CPU cores to keep available and not use during parallel + processing. Internally, more cores may actually be used by the system, + which we don't control. + """ + + n_splitting_iter: int = 10 + """ + During the structure splitting phase we iteratively shrink the bright areas + and re-filter with the 3d filter. This is the number of iterations to do. + + This is a maximum because we also stop if there are no more structures left + during any iteration. + """ + + def __getstate__(self): + d = self.__dict__.copy() + # when sending across processes, we need to be able to pickle. This + # property cannot be pickled (and doesn't need to be) + if "filter_data_converter_func" in d: + del d["filter_data_converter_func"] + return d + + @cached_property + def filter_data_converter_func(self) -> Callable[[np.ndarray], np.ndarray]: + """ + A callable that takes a numpy array of type + `plane_original_np_dtype` and converts it into the `filtering_dtype` + type. + + We use this to convert the input data into the data type used for + filtering. + """ + return get_data_converter( + self.plane_original_np_dtype, self.filtering_dtype + ) + + @cached_property + def filtering_dtype(self) -> Type[np.floating]: + """ + The numpy data type that the 2d/3d filters expect our data to be in. + Use `filter_data_converter_func` to convert to this type. + + The data will be used in the form of torch tensors, but it'll be this + data type. + + Currently, it's either float32 or float64. + """ + original_dtype = self.plane_original_np_dtype + original_max_int = get_max_possible_int_value(original_dtype) + + # does original data fit in float32 + if original_max_int <= get_max_possible_int_value(np.float32): + return np.float32 + # what about float64 + if original_max_int <= get_max_possible_int_value(np.float64): + return np.float64 + raise TypeError("Input array data type is too big for a float64") + + @cached_property + def clipping_value(self) -> int: + """ + The maximum value used to clip the input to, as well as the value to + which the filtered data is scaled to during filtering. + + This ensures the filtered data fits in the `plane_original_np_dtype`. + """ + return get_max_possible_int_value(self.plane_original_np_dtype) - 2 + + @cached_property + def threshold_value(self) -> int: + """ + The value used to set bright areas as indicating it's above a + brightness threshold, during 2d filtering. + """ + return get_max_possible_int_value(self.plane_original_np_dtype) - 1 + + @cached_property + def soma_centre_value(self) -> int: + """ + The value used to mark bright areas as the location of a soma center, + during 3d filtering. + """ + return get_max_possible_int_value(self.plane_original_np_dtype) + + @cached_property + def detection_soma_centre_value(self) -> int: + """ + The value used to mark bright areas as the location of a soma center, + during detection. See `detection_data_converter_func`. + """ + return get_max_possible_int_value(self.detection_dtype) + + @cached_property + def detection_data_converter_func( + self, + ) -> Callable[[np.ndarray], np.ndarray]: + """ + A callable that takes a numpy array of type + `filtering_dtype` and converts it into the `detection_dtype` + type. + + It takes the filtered data where somas are marked with the + `soma_centre_value` and returns a volume of the same size where the + somas are marked with `detection_soma_centre_value`. Other voxels are + zeroed. + + We use this to convert the output of the 3d filter into the data + passed to cell detection. + """ + + def convert_for_cell_detection(data: np.ndarray) -> np.ndarray: + detection_data = np.zeros_like(data, dtype=self.detection_dtype) + detection_data[data == self.soma_centre_value] = ( + self.detection_soma_centre_value + ) + return detection_data + + return convert_for_cell_detection + + @property + def tile_height(self) -> int: + """ + The height of each tile of the tiled input image, used during filtering + to mark individual tiles as inside/outside the brain. + """ + return self.soma_diameter * 2 + + @property + def tile_width(self) -> int: + """ + The width of each tile of the tiled input image, used during filtering + to mark individual tiles as inside/outside the brain. + """ + return self.soma_diameter * 2 + + @property + def plane_height(self) -> int: + """The height of each input plane of the z-stack.""" + return self.plane_shape[0] + + @property + def plane_width(self) -> int: + """The width of each input plane of the z-stack.""" + return self.plane_shape[1] + + @property + def n_planes(self) -> int: + """The number of planes in the z-stack.""" + return self.end_plane - self.start_plane + + @property + def n_processes(self) -> int: + """The maximum number of process we can use during detection.""" + n = get_num_processes(min_free_cpu_cores=self.n_free_cpus) + return max(n - 1, 1) + + @property + def n_torch_comp_threads(self) -> int: + """ + The maximum number of process we should use during filtering, + using pytorch. + + This is less than `n_processes` because we account for thread + contention. Specifically it's limited by `MAX_TORCH_COMP_THREADS`. + """ + # Reserve batch_size cores for batch multiprocess parallelization on + # CPU, 1 per plane. for GPU it doesn't matter either way because it + # doesn't use threads. Also reserve for data feeding thread and + # cell detection. Don't let it go below 4. + n = max(4, self.n_processes - self.batch_size - 2) + n = min(n, self.n_processes) + return min(n, MAX_TORCH_COMP_THREADS) + + @property + def in_plane_pixel_size(self) -> float: + """Returns the average in-plane (xy) um/pixel.""" + voxel_sizes = self.voxel_sizes + return (voxel_sizes[2] + voxel_sizes[1]) / 2 + + @cached_property + def soma_diameter(self) -> int: + """The `soma_diameter_um`, but in voxels.""" + return int(round(self.soma_diameter_um / self.in_plane_pixel_size)) + + @cached_property + def max_cluster_size(self) -> int: + """The `max_cluster_size_um3`, but in voxels.""" + voxel_sizes = self.voxel_sizes + voxel_volume = ( + float(voxel_sizes[2]) + * float(voxel_sizes[1]) + * float(voxel_sizes[0]) + ) + return int(round(self.max_cluster_size_um3 / voxel_volume)) + + @cached_property + def ball_xy_size(self) -> int: + """The `ball_xy_size_um`, but in voxels.""" + return int(round(self.ball_xy_size_um / self.in_plane_pixel_size)) + + @property + def z_pixel_size(self) -> float: + """Returns the um/pixel in the z direction.""" + return self.voxel_sizes[0] + + @cached_property + def ball_z_size(self) -> int: + """The `ball_z_size_um`, but in voxels.""" + ball_z_size = int(round(self.ball_z_size_um / self.z_pixel_size)) + + if not ball_z_size: + raise ValueError( + "Ball z size has been calculated to be 0 voxels." + " This may be due to large axial spacing of your data or the " + "ball_z_size_um parameter being too small. " + "Please check input parameters are correct. " + "Note that cellfinder requires high resolution data in all " + "dimensions, so that cells can be detected in multiple " + "image planes." + ) + return ball_z_size + + @property + def max_cell_volume(self) -> float: + """ + The maximum cell volume to consider as a single cell, in voxels. + + If we find a bright area larger than that, we will split it. + """ + radius = self.soma_spread_factor * self.soma_diameter / 2 + return (4 / 3) * math.pi * radius**3 + + @property + def plane_prefix(self) -> str: + """ + The prefix of the filename to use to save the 2d/3d filtered planes. + + To save plane `k`, do `plane_prefix.format(n=k)`. You can then add + an extension etc. + """ + n = max(4, int(math.ceil(math.log10(self.n_planes)))) + name = f"plane_{{n:0{n}d}}" + return name diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index c5f5f5b8..a2f1ef1c 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -1,29 +1,35 @@ +import math from functools import lru_cache +from typing import Optional import numpy as np -from numba import njit, objmode, prange -from numba.core import types -from numba.experimental import jitclass +import torch +import torch.nn.functional as F from cellfinder.core.tools.array_operations import bin_mean_3d from cellfinder.core.tools.geometry import make_sphere -DEBUG = False -uint32_3d_type = types.uint32[:, :, :] -bool_3d_type = types.bool_[:, :, :] -float_3d_type = types.float64[:, :, :] +class InvalidVolume(ValueError): + """ + Raised when the volume passed to BallFilter is too small or does not meet + requirements. + """ + + pass @lru_cache(maxsize=50) def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray: - # Create a spherical kernel. - # - # This is done by: - # 1. Generating a binary sphere at a resolution *upscale_factor* larger - # than desired. - # 2. Downscaling the binary sphere to get a 'fuzzy' sphere at the - # original intended scale + """ + Create a spherical kernel. + + This is done by: + 1. Generating a binary sphere at a resolution *upscale_factor* larger + than desired. + 2. Downscaling the binary sphere to get a 'fuzzy' sphere at the + original intended scale + """ upscale_factor: int = 7 upscaled_kernel_shape = ( upscale_factor * ball_xy_size, @@ -42,11 +48,11 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray: upscaled_ball_radius, upscaled_ball_centre_position, ) - sphere_kernel = sphere_kernel.astype(np.float64) + sphere_kernel = sphere_kernel.astype(np.float32) kernel = bin_mean_3d( sphere_kernel, - bin_height=upscale_factor, bin_width=upscale_factor, + bin_height=upscale_factor, bin_depth=upscale_factor, ) @@ -59,359 +65,351 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray: return kernel -# volume indices/size is 64 bit for very large brains(!) -spec = [ - ("ball_xy_size", types.uint32), - ("ball_z_size", types.uint32), - ("tile_step_width", types.uint64), - ("tile_step_height", types.uint64), - ("THRESHOLD_VALUE", types.uint32), - ("SOMA_CENTRE_VALUE", types.uint32), - ("overlap_fraction", types.float64), - ("overlap_threshold", types.float64), - ("middle_z_idx", types.uint32), - ("_num_z_added", types.uint32), - ("kernel", float_3d_type), - ("volume", uint32_3d_type), - ("inside_brain_tiles", bool_3d_type), -] - - -@jitclass(spec=spec) class BallFilter: """ A 3D ball filter. - This runs a spherical kernel across the (x, y) dimensions + This runs a spherical kernel across the 2d planar dimensions of a *ball_z_size* stack of planes, and marks pixels in the middle - plane of the stack that have a high enough intensity within the - spherical kernel. + plane of the stack that have a high enough intensity over the + the spherical kernel. + + Parameters + ---------- + plane_height, plane_width : int + Height/width of the planes. + ball_xy_size : int + Diameter of the spherical kernel in the x/y dimensions. + ball_z_size : int + Diameter of the spherical kernel in the z dimension. + Equal to the number of planes stacked to filter + the central plane of the stack. + overlap_fraction : float + The fraction of pixels within the spherical kernel that + have to be over *threshold_value* for a pixel to be marked + as having a high intensity. + threshold_value : int + Value above which an individual pixel is considered to have + a high intensity. + soma_centre_value : int + Value used to mark pixels with a high enough intensity. + tile_height, tile_width : int + Width/height of individual tiles in the mask generated by + 2D filtering. + dtype : str + The data-type of the input planes and the type to use internally. + E.g. "float32". + batch_size: int + The number of planes that will be typically passed in a single batch to + `append`. This is only used to calculate `num_batches_before_ready`. + Defaults to 1. + torch_device: str + The device on which the data and processing occurs on. Can be e.g. + "cpu", "cuda" etc. Defaults to "cpu". Any data passed to the filter + must be on this device. Returned data will also be on this device. + use_mask : bool + Whether tiling masks will be used in `append`. If False, tile masks + won't be passed in and/or will be ignored. Defaults to True. + """ + + num_batches_before_ready: int """ + The number of batches of size `batch_size` passed to `append` + before `ready` would return True. + """ + + # the inside brain tiled mask, if tiles are used (use_mask is True) + inside_brain_tiles: Optional[torch.Tensor] = None def __init__( self, - plane_width: int, plane_height: int, + plane_width: int, ball_xy_size: int, ball_z_size: int, overlap_fraction: float, - tile_step_width: int, - tile_step_height: int, threshold_value: int, soma_centre_value: int, + tile_height: int, + tile_width: int, + dtype: str, + batch_size: int = 1, + torch_device: str = "cpu", + use_mask: bool = True, ): - """ - Parameters - ---------- - plane_width, plane_height : - Width/height of the planes. - ball_xy_size : - Diameter of the spherical kernel in the x/y dimensions. - ball_z_size : - Diameter of the spherical kernel in the z dimension. - Equal to the number of planes that stacked to filter - the central plane of the stack. - overlap_fraction : - The fraction of pixels within the spherical kernel that - have to be over *threshold_value* for a pixel to be marked - as having a high intensity. - tile_step_width, tile_step_height : - Width/height of individual tiles in the mask generated by - 2D filtering. - threshold_value : - Value above which an individual pixel is considered to have - a high intensity. - soma_centre_value : - Value used to mark pixels with a high enough intensity. - """ self.ball_xy_size = ball_xy_size self.ball_z_size = ball_z_size self.overlap_fraction = overlap_fraction - self.tile_step_width = tile_step_width - self.tile_step_height = tile_step_height + self.tile_step_height = tile_height + self.tile_step_width = tile_width + + d1 = plane_height + d2 = plane_width + ball_xy_size = self.ball_xy_size + if d1 < ball_xy_size or d2 < ball_xy_size: + raise InvalidVolume( + f"Invalid plane size {d1}x{d2}. Needs to be at least " + f"{ball_xy_size} in each dimension" + ) self.THRESHOLD_VALUE = threshold_value self.SOMA_CENTRE_VALUE = soma_centre_value - # getting kernel is not jitted - with objmode(kernel=float_3d_type): - kernel = get_kernel(ball_xy_size, ball_z_size) - self.kernel = kernel - - self.overlap_threshold = np.sum(self.overlap_fraction * self.kernel) + # kernel comes in as XYZ, change to ZYX so it aligns with data + kernel = np.moveaxis(get_kernel(ball_xy_size, self.ball_z_size), 2, 0) + self.overlap_threshold = np.sum(self.overlap_fraction * kernel) + self.kernel_xy_size = kernel.shape[-2:] + self.kernel_z_size = self.ball_z_size + + # convert to right type and pin for faster copying + kernel = torch.from_numpy(kernel).type(getattr(torch, dtype)) + if torch_device != "cpu": + # torch at one point threw a cuda memory error when splitting cells + # on cpu due to pinning. It's best to only pin on using cuda + kernel.pin_memory() + # add 2 dimensions at the start so we have 11ZYX We need this shape in + # the conv step + self.kernel = ( + kernel.unsqueeze(0) + .unsqueeze(0) + .to(device=torch_device, non_blocking=True) + ) - # Stores the current planes that are being filtered - # first axis is z for faster rotating the z-axis - self.volume = np.empty( - (ball_z_size, plane_width, plane_height), - dtype=np.uint32, + self.num_batches_before_ready = int( + math.ceil(self.ball_z_size / batch_size) + ) + # Stores the current planes that are being filtered. Start with no data + self.volume = torch.empty( + (0, plane_height, plane_width), + dtype=getattr(torch, dtype), ) # Index of the middle plane in the volume - self.middle_z_idx = int(np.floor(ball_z_size / 2)) - self._num_z_added = 0 + self.middle_z_idx = int(np.floor(self.ball_z_size / 2)) + if not use_mask: + return # first axis is z - self.inside_brain_tiles = np.empty( + n_vertical_tiles = int(np.ceil(plane_height / self.tile_step_height)) + n_horizontal_tiles = int(np.ceil(plane_width / self.tile_step_width)) + # Stores tile masks. We start with no data + self.inside_brain_tiles = torch.empty( ( - ball_z_size, - int(np.ceil(plane_width / tile_step_width)), - int(np.ceil(plane_height / tile_step_height)), + 0, + n_vertical_tiles, + n_horizontal_tiles, ), - dtype=np.bool_, + dtype=torch.bool, ) @property - def ready(self) -> bool: + def first_valid_plane(self) -> int: """ - Return `True` if enough planes have been appended to run the filter. + The index in `self.volume` (or the planes passed in) that will be the + first plane returned from `get_processed_planes`. + + E.g. if `ball_z_size` is 3, then this may return 1. Meaning the second + plane passed to `append` (index 1), will be the first returned plane + by `get_processed_planes`. """ - return self._num_z_added >= self.ball_z_size + return int(math.floor(self.ball_z_size / 2)) - def append(self, plane: np.ndarray, mask: np.ndarray) -> None: + @property + def remaining_planes(self) -> int: """ - Add a new 2D plane to the filter. + The number of planes in `self.volume` (or the planes passed in) that + will remain unprocessed after all the planes have been `walk`ed + and `get_processed_planes` called. + + E.g. if `ball_z_size` is 3, then this may return 1. Meaning the last + plane passed to `append`, will never be returned by + `get_processed_planes` because the filter "center" never overlapped + with it. """ - if DEBUG: - assert [e for e in plane.shape[:2]] == [ - e for e in self.volume.shape[1:] - ], 'plane shape mismatch, expected "{}", got "{}"'.format( - [e for e in self.volume.shape[1:]], - [e for e in plane.shape[:2]], - ) - assert [e for e in mask.shape[:2]] == [ - e for e in self.inside_brain_tiles.shape[1:] - ], 'mask shape mismatch, expected"{}", got {}"'.format( - [e for e in self.inside_brain_tiles.shape[1:]], - [e for e in mask.shape[:2]], - ) + return self.ball_z_size - self.first_valid_plane - 1 - if self.ready: - # Shift everything down by one to make way for the new plane - # this is faster than np.roll, especially with z-axis first - self.volume[:-1, :, :] = self.volume[1:, :, :] - self.inside_brain_tiles[:-1, :, :] = self.inside_brain_tiles[ - 1:, :, : - ] + @property + def ready(self) -> bool: + """ + Return whether enough planes have been appended to run the filter + using `walk`. + """ + return self.volume.shape[0] >= self.kernel_z_size + + def append( + self, planes: torch.Tensor, masks: Optional[torch.Tensor] = None + ) -> None: + """ + Add a new z-stack to the filter. - # index for *next* slice is num we added *so far* until max - idx = min(self._num_z_added, self.ball_z_size - 1) - self._num_z_added += 1 + Previous stacks passed to `append` are removed, except enough planes + at the top of the previous z-stack to provide padding so we can filter + starting from the first plane in `planes`. The first time we call + `append`, `first_valid_plane` is the first plane to actually be + filtered in the z-stack due to lack of padding. - # Add the new plane to the top of volume and inside_brain_tiles - self.volume[idx, :, :] = plane - self.inside_brain_tiles[idx, :, :] = mask + So make sure to call `walk`/`get_processed_planes` before calling + `append` again. - def get_middle_plane(self) -> np.ndarray: - """ - Get the plane in the middle of self.volume. + Parameters + ---------- + planes : torch.Tensor + The z-stack. There can be one or more planes in the stack, but it + must have 3 dimensions. Input data is not modified. + masks : torch.Tensor + A z-stack tile mask, indicating for each tile whether it's in or + outside the brain. If the latter it's excluded. + + If `use_mask` was True, this must be provided. If False, this + parameter will be ignored. + + Input data is not modified. """ - return self.volume[self.middle_z_idx, :, :].copy() - - def walk(self, parallel: bool = False) -> None: - # **don't** pass parallel as keyword arg - numba struggles with it - # Highly optimised because most time critical - ball_radius = self.ball_xy_size // 2 - # Get extents of image that are covered by tiles - tile_mask_covered_img_width = ( - self.inside_brain_tiles.shape[1] * self.tile_step_width - ) - tile_mask_covered_img_height = ( - self.inside_brain_tiles.shape[2] * self.tile_step_height - ) - # Get maximum offsets for the ball - max_width = tile_mask_covered_img_width - self.ball_xy_size - max_height = tile_mask_covered_img_height - self.ball_xy_size - - # we have to pass the raw volume so walk doesn't use its edits as it - # processes the volume. self.volume is the one edited in place - input_volume = self.volume.copy() - - if parallel: - _walk_parallel( - max_height, - max_width, - self.tile_step_width, - self.tile_step_height, - self.inside_brain_tiles, - input_volume, - self.volume, - self.kernel, - ball_radius, - self.middle_z_idx, - self.overlap_threshold, - self.THRESHOLD_VALUE, - self.SOMA_CENTRE_VALUE, + if self.volume.shape[0]: + if self.volume.shape[0] < self.kernel_z_size: + num_remaining_with_padding = 0 + else: + num_remaining = self.kernel_z_size - (self.middle_z_idx + 1) + num_remaining_with_padding = num_remaining + self.middle_z_idx + + self.volume = torch.cat( + [self.volume[-num_remaining_with_padding:, :, :], planes], + dim=0, ) + + if self.inside_brain_tiles is not None: + self.inside_brain_tiles = torch.cat( + [ + self.inside_brain_tiles[ + -num_remaining_with_padding:, :, : + ], + masks, + ], + dim=0, + ) else: - _walk_single( - max_height, - max_width, - self.tile_step_width, - self.tile_step_height, - self.inside_brain_tiles, - input_volume, - self.volume, - self.kernel, - ball_radius, - self.middle_z_idx, - self.overlap_threshold, - self.THRESHOLD_VALUE, - self.SOMA_CENTRE_VALUE, - ) + self.volume = planes.clone() + if self.inside_brain_tiles is not None: + self.inside_brain_tiles = masks.clone() + def get_processed_planes(self) -> np.ndarray: + """ + After passing enough planes to `append`, and after `walk`, this returns + a copy of the processed planes as a numpy z-stack. -@njit(cache=True) -def _cube_overlaps( - volume: np.ndarray, - x_start: int, - x_end: int, - y_start: int, - y_end: int, - overlap_threshold: float, - threshold_value: int, - kernel: np.ndarray, -) -> bool: # Highly optimised because most time critical - """ - For each pixel in cube in volume that is greater than THRESHOLD_VALUE, sum - up the corresponding pixels in *kernel*. If the total is less than - overlap_threshold, return False, otherwise return True. + It only starts returning planes corresponding to plane + `first_valid_plane` relative to the first planes passed to `append`. + E.g. if `ball_z_size` is 3 and `first_valid_plane` is 1, and you passed + 5 planes total to `append`, then this will have returned planes [1, 3]. - Halfway through scanning the z-planes, if the total overlap is - less than 0.4 * overlap_threshold, this will return False early - without scanning the second half of the z-planes. + Notice the last plane was not included, because we return only "middle" + planes - planes that can correspond to the center of a ball. + """ + if not self.ready: + raise TypeError("Not enough planes were appended") + + num_processed = self.volume.shape[0] - self.kernel_z_size + 1 + assert num_processed + middle = self.middle_z_idx + planes = ( + self.volume[middle : middle + num_processed, :, :] + .cpu() + .numpy() + .copy() + ) + return planes - Parameters - ---------- - volume : - 3D array. - x_start, x_end, y_start, y_end : - The start and end indices in volume that form the cube. End is - exclusive - overlap_threshold : - Threshold above which to return True. - threshold_value : - Value above which a pixel is marked as being part of a cell. - kernel : - 3D array, with the same shape as *cube* in the volume. - """ - current_overlap_value = 0.0 - - middle = np.floor(volume.shape[0] / 2) + 1 - halfway_overlap_thresh = ( - overlap_threshold * 0.4 - ) # FIXME: do not hard code value - - for z in range(volume.shape[0]): - # TODO: OPTIMISE: step from middle to outer boundaries to check - # more data first - # - # If halfway through the array, and the overlap value isn't more than - # 0.4 * the overlap threshold, return - if z == middle and current_overlap_value < halfway_overlap_thresh: - return False # DEBUG: optimisation attempt - - for y in range(y_start, y_end): - for x in range(x_start, x_end): - # includes self.SOMA_CENTRE_VALUE - if volume[z, x, y] >= threshold_value: - # x/y must be shifted in kernel because we x/y is relative - # to the full volume, so shift it to relative to the cube - current_overlap_value += kernel[ - x - x_start, y - y_start, z - ] - - return current_overlap_value > overlap_threshold - - -@njit -def _is_tile_to_check( - x: int, - y: int, - middle_z: int, - tile_step_width: int, - tile_step_height: int, - inside_brain_tiles: np.ndarray, -) -> bool: # Highly optimised because most time critical - """ - Check if the tile containing pixel (x, y) is a tile that needs checking. - """ - x_in_mask = x // tile_step_width # TEST: test bounds (-1 range) - y_in_mask = y // tile_step_height # TEST: test bounds (-1 range) - return inside_brain_tiles[middle_z, x_in_mask, y_in_mask] + def walk(self) -> None: + """ + Applies the filter to all the planes passed to `append`. + May only be called if `ready` was True. -def _walk_base( - max_height: int, - max_width: int, - tile_step_width: int, + You can get the processed planes from `get_processed_planes`. + """ + if not self.ready: + raise TypeError("Called walk before enough planes were appended") + + _walk( + self.kernel_z_size, + self.middle_z_idx, + self.tile_step_height, + self.tile_step_width, + self.overlap_threshold, + self.THRESHOLD_VALUE, + self.SOMA_CENTRE_VALUE, + self.kernel, + self.volume, + self.inside_brain_tiles, + ) + + +@torch.jit.script +def _walk( + kernel_z_size: int, + middle: int, tile_step_height: int, - inside_brain_tiles: np.ndarray, - input_volume: np.ndarray, - volume: np.ndarray, - kernel: np.ndarray, - ball_radius: int, - middle_z: int, + tile_step_width: int, overlap_threshold: float, threshold_value: int, soma_centre_value: int, -) -> None: - """ - Scan through *volume*, and mark pixels where there are enough surrounding - pixels with high enough intensity. + kernel: torch.Tensor, + volume: torch.Tensor, + inside_brain_tiles: Optional[torch.Tensor], +): + num_process = volume.shape[0] - kernel_z_size + 1 + height, width = volume.shape[1:] + num_z = kernel.shape[2] + + # threshold volume so it's zero/one. And add two dims at start so + # it's 11ZYX + volume_tresh = ( + (volume >= threshold_value) + .unsqueeze(0) + .unsqueeze(0) + .type(kernel.dtype) + ) - The surrounding area is defined by the *kernel*. + # we do a plane at a time, volume: i:i+num_z, for plane i+middle + for i in range(num_process): + # spherical kernel is symmetric so convolution=correlation. Use + # binary threshold mask over the kernel to sum the value of the + # kernel at voxels that are bright + overlap = F.conv3d( + volume_tresh[:, :, i : i + num_z, :, :], + kernel, + stride=1, + padding="valid", + )[0, 0, 0, :, :] + overlaps = overlap > overlap_threshold + + # only edit the volume that is valid - conv excludes edges so we + # only edit the plane parts returned by conv3d + height_valid, width_valid = overlaps.shape + height_offset = (height - height_valid) // 2 + width_offset = (width - width_valid) // 2 + sub_volume = volume[ + i + middle, + height_offset : height_offset + height_valid, + width_offset : width_offset + width_valid, + ] + + # do we use tile masks or just overlapping? + if inside_brain_tiles is not None: + # unfold tiles to cover the full voxel area each tile covers + inside = ( + inside_brain_tiles[i + middle, :, :] + .repeat_interleave(tile_step_height, dim=0) + .repeat_interleave(tile_step_width, dim=1) + ) + # again only process pixels in the valid area + inside = inside[ + height_offset : height_offset + height_valid, + width_offset : width_offset + width_valid, + ] - Parameters - ---------- - max_height, max_width : - Maximum offsets for the ball filter. - inside_brain_tiles : - 3d array containing information on whether a tile is - inside the brain or not. Tiles outside the brain are skipped. - input_volume : - 3D array containing the plane-filtered data passed to the function - before walking. volume is edited in place, so this is the original - volume to prevent the changes for some cubes affective other cubes - during a single walk call. - volume : - 3D array containing the plane-filtered data - edited in place. - kernel : - 3D array - ball_radius : - Radius of the ball in the xy plane. - soma_centre_value : - Value that is used to mark pixels in *volume*. - - Notes - ----- - Warning: modifies volume in place! - """ - for y in prange(max_height): - for x in prange(max_width): - ball_centre_x = x + ball_radius - ball_centre_y = y + ball_radius - if _is_tile_to_check( - ball_centre_x, - ball_centre_y, - middle_z, - tile_step_width, - tile_step_height, - inside_brain_tiles, - ): - if _cube_overlaps( - input_volume, - x, - x + kernel.shape[0], - y, - y + kernel.shape[1], - overlap_threshold, - threshold_value, - kernel, - ): - volume[middle_z, ball_centre_x, ball_centre_y] = ( - soma_centre_value - ) - - -_walk_parallel = njit(parallel=True)(_walk_base) -_walk_single = njit(parallel=False)(_walk_base) + # must have enough ball overlap to be bright/tile is in brain + sub_volume[torch.logical_and(overlaps, inside)] = soma_centre_value + + else: + # must have enough ball overlap to be bright + sub_volume[overlaps] = soma_centre_value diff --git a/cellfinder/core/detect/filters/volume/structure_detection.py b/cellfinder/core/detect/filters/volume/structure_detection.py index 536f00ad..da99633d 100644 --- a/cellfinder/core/detect/filters/volume/structure_detection.py +++ b/cellfinder/core/detect/filters/volume/structure_detection.py @@ -4,18 +4,20 @@ import numba.typed import numpy as np import numpy.typing as npt -from numba import njit, typed +from numba import njit, objmode, typed from numba.core import types from numba.experimental import jitclass from numba.types import DictType +from cellfinder.core.tools.tools import get_max_possible_int_value + T = TypeVar("T") # type used for the domain of the volume - the size of the vol vol_np_type = np.int64 vol_numba_type = types.int64 # type used for the structure id -sid_np_type = np.int64 -sid_numba_type = types.int64 +sid_np_type = np.uint64 +sid_numba_type = types.uint64 @dataclass @@ -26,14 +28,17 @@ class Point: @njit -def get_non_zero_dtype_min(values: np.ndarray) -> int: +def get_non_zero_dtype_min(values: np.ndarray) -> sid_numba_type: """ Get the minimum of non-zero entries in *values*. If all entries are zero, returns maximum storeable number in the values array. """ - min_val = np.iinfo(values.dtype).max + # we don't know how big the int is, so make it as large as possible (64) + with objmode(min_val="u8"): + min_val = get_max_possible_int_value(values.dtype) + for v in values: if v != 0 and v < min_val: min_val = v @@ -97,6 +102,7 @@ def _get_structure_centre(structure: types.ListType) -> np.ndarray: spec = [ ("z", vol_numba_type), ("next_structure_id", sid_numba_type), + ("soma_centre_value", sid_numba_type), # as large as possible ("shape", types.UniTuple(vol_numba_type, 2)), ("obsolete_ids", DictType(sid_numba_type, sid_numba_type)), ("coords_maps", DictType(sid_numba_type, list_of_points_type)), @@ -133,18 +139,25 @@ class CellDetector: points. """ - def __init__(self, width: int, height: int, start_z: int): + def __init__( + self, + height: int, + width: int, + start_z: int, + soma_centre_value: sid_numba_type, + ): """ Parameters ---------- - width, height + height, width: Shape of the planes input to self.process() start_z: The z-coordinate of the first processed plane. """ - self.shape = width, height + self.shape = height, width self.z = start_z self.next_structure_id = 1 + self.soma_centre_value = soma_centre_value # Mapping from obsolete IDs to the IDs that they have been # made obsolete by @@ -156,11 +169,18 @@ def __init__(self, width: int, height: int, start_z: int): key_type=sid_numba_type, value_type=list_of_points_type ) + def _set_soma(self, soma_centre_value: sid_numba_type): + # Due to https://github.com/numba/numba/issues/9576. For testing we try + # different data types. Because of that issue we cannot pass a uint64 + # soma_centre_value to constructor after we pass a uint32. This is the + # only way for now until numba fixes the issue + self.soma_centre_value = soma_centre_value + def process( self, plane: np.ndarray, previous_plane: Optional[np.ndarray] ) -> np.ndarray: """ - Process a new plane. + Process a new plane (should be in Y, X axis order). """ if plane.shape[:2] != self.shape: raise ValueError("plane does not have correct shape") @@ -185,21 +205,21 @@ def connect_four( ------- plane : Plane with pixels either set to zero (no structure) or labelled - with their structure ID. + with their structure ID. Plane is in Y, X axis order. """ - SOMA_CENTRE_VALUE = np.iinfo(plane.dtype).max - for y in range(plane.shape[1]): - for x in range(plane.shape[0]): - if plane[x, y] == SOMA_CENTRE_VALUE: + soma_centre_value = self.soma_centre_value + for y in range(plane.shape[0]): + for x in range(plane.shape[1]): + if plane[y, x] == soma_centre_value: # Labels of structures below, left and behind neighbour_ids = np.zeros(3, dtype=sid_np_type) # If in bounds look at neighbours - if x > 0: - neighbour_ids[0] = plane[x - 1, y] if y > 0: - neighbour_ids[1] = plane[x, y - 1] + neighbour_ids[0] = plane[y - 1, x] + if x > 0: + neighbour_ids[1] = plane[y, x - 1] if previous_plane is not None: - neighbour_ids[2] = previous_plane[x, y] + neighbour_ids[2] = previous_plane[y, x] if is_new_structure(neighbour_ids): neighbour_ids[0] = self.next_structure_id @@ -210,17 +230,20 @@ def connect_four( # structure in next iterations struct_id = 0 - plane[x, y] = struct_id + plane[y, x] = struct_id return plane def get_cell_centres(self) -> np.ndarray: + """ + Returns the 2D array of cell centers. It's (N, 3) with X, Y, Z columns. + """ return self.structures_to_cells() def get_structures(self) -> Dict[int, np.ndarray]: """ Gets the structures as a dict of structure IDs mapped to the 2D array - of structure points. + of structure points (points vs x, y, z columns). """ d = {} for sid, points in self.coords_maps.items(): @@ -228,7 +251,10 @@ def get_structures(self) -> Dict[int, np.ndarray]: # `item = np.array(points, dtype=vol_np_type)` so we need to create # array and then fill in the point item = np.empty((len(points), 3), dtype=vol_np_type) - d[sid] = item + # need to cast to int64, otherwise when dict is used we can get + # warnings as in numba issue #8829 b/c it assumes it's uint64. + # Python uses int(64) as the type + d[types.int64(sid)] = item for i, point in enumerate(points): item[i, :] = point @@ -239,33 +265,41 @@ def add_point( self, sid: int, point: Union[tuple, list, np.ndarray] ) -> None: """ - Add single 3d *point* to the structure with the given *sid*. + Add single 3d (x, y, z) *point* to the structure with the given *sid*. """ - if sid not in self.coords_maps: - self.coords_maps[sid] = typed.List.empty_list(tuple_point_type) + # cast in case user passes in int64 (default type for int in python) + # and numba complains + key = sid_numba_type(sid) + if key not in self.coords_maps: + self.coords_maps[key] = typed.List.empty_list(tuple_point_type) - self._add_point(sid, (int(point[0]), int(point[1]), int(point[2]))) + self._add_point(key, (int(point[0]), int(point[1]), int(point[2]))) def add_points(self, sid: int, points: np.ndarray): """ Adds ndarray of *points* to the structure with the given *sid*. - Each row is a 3d point. + Each row is a 3-column (x, y, z) point. """ - if sid not in self.coords_maps: - self.coords_maps[sid] = typed.List.empty_list(tuple_point_type) + # cast in case user passes in int64 (default type for int in python) + # and numba complains + key = sid_numba_type(sid) + if key not in self.coords_maps: + self.coords_maps[key] = typed.List.empty_list(tuple_point_type) - append = self.coords_maps[sid].append + append = self.coords_maps[key].append pts = np.round(points).astype(vol_np_type) for point in pts: append((point[0], point[1], point[2])) - def _add_point(self, sid: int, point: Tuple[int, int, int]) -> None: + def _add_point( + self, sid: sid_numba_type, point: Tuple[int, int, int] + ) -> None: # sid must exist self.coords_maps[sid].append(point) def add( self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type] - ) -> int: + ) -> sid_numba_type: """ For the current coordinates takes all the neighbours and find the minimum structure including obsolete structures mapping to any of @@ -287,7 +321,9 @@ def add( self._add_point(updated_id, (int(x), int(y), int(z))) return updated_id - def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int: + def sanitise_ids( + self, neighbour_ids: npt.NDArray[sid_np_type] + ) -> sid_numba_type: """ Get the smallest ID of all the structures that are connected to IDs in `neighbour_ids`. @@ -300,15 +336,17 @@ def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int: """ for i, neighbour_id in enumerate(neighbour_ids): # walk up the chain of obsolescence - neighbour_id = int(traverse_dict(self.obsolete_ids, neighbour_id)) + neighbour_id = traverse_dict(self.obsolete_ids, neighbour_id) neighbour_ids[i] = neighbour_id # Get minimum of all non-obsolete IDs updated_id = get_non_zero_dtype_min(neighbour_ids) - return int(updated_id) + return updated_id def merge_structures( - self, updated_id: int, neighbour_ids: npt.NDArray[sid_np_type] + self, + updated_id: sid_numba_type, + neighbour_ids: npt.NDArray[sid_np_type], ) -> None: """ For all the neighbours, reassign all the points of neighbour to diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 7480e8ec..e1d4a5f0 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -1,9 +1,14 @@ -from typing import List, Tuple +from typing import List, Tuple, Type import numpy as np +import torch from cellfinder.core import logger -from cellfinder.core.detect.filters.volume.ball_filter import BallFilter +from cellfinder.core.detect.filters.setup_filters import DetectionSettings +from cellfinder.core.detect.filters.volume.ball_filter import ( + BallFilter, + InvalidVolume, +) from cellfinder.core.detect.filters.volume.structure_detection import ( CellDetector, get_structure_centre, @@ -14,41 +19,62 @@ class StructureSplitException(Exception): pass -def get_shape(xs: np.ndarray, ys: np.ndarray, zs: np.ndarray) -> List[int]: +def get_shape( + xs: np.ndarray, ys: np.ndarray, zs: np.ndarray +) -> Tuple[int, int, int]: + """ + Takes a list of x, y, z coordinates and returns a volume size such that + all the points will fit into it. With axis order = x, y, z. + """ # +1 because difference. TEST: - shape = [int((dim.max() - dim.min()) + 1) for dim in (xs, ys, zs)] + shape = tuple(int((dim.max() - dim.min()) + 1) for dim in (xs, ys, zs)) return shape def coords_to_volume( - xs: np.ndarray, ys: np.ndarray, zs: np.ndarray, ball_radius: int = 1 -) -> np.ndarray: + xs: np.ndarray, + ys: np.ndarray, + zs: np.ndarray, + volume_shape: Tuple[int, int, int], + ball_radius: int, + dtype: Type[np.number], + threshold_value: int, +) -> torch.Tensor: + """ + Takes the series of x, y, z points along with the shape of the volume + that fully enclose them (also x, y, z order). It than expands the + shape by the ball diameter in each axis. Then, each point, shifted + by the radius internally is set to the threshold value. + + The volume is then transposed and returned in the Z, Y, X order. + """ + # it's faster doing the work in numpy and then returning as torch array, + # than doing the work in torch ball_diameter = ball_radius * 2 # Expanded to ensure the ball fits even at the border - expanded_shape = [ - dim_size + ball_diameter for dim_size in get_shape(xs, ys, zs) - ] - volume = np.zeros(expanded_shape, dtype=np.uint32) + expanded_shape = [dim_size + ball_diameter for dim_size in volume_shape] + # volume is now x, y, z order + volume = np.zeros(expanded_shape, dtype=dtype) x_min, y_min, z_min = xs.min(), ys.min(), zs.min() + # shift the points so any sphere centered on it would not have its + # radius expand beyond the volume relative_xs = np.array((xs - x_min + ball_radius), dtype=np.int64) relative_ys = np.array((ys - y_min + ball_radius), dtype=np.int64) relative_zs = np.array((zs - z_min + ball_radius), dtype=np.int64) - # OPTIMISE: vectorize + # set each point as the center with a value of threshold for rel_x, rel_y, rel_z in zip(relative_xs, relative_ys, relative_zs): - volume[rel_x, rel_y, rel_z] = np.iinfo(volume.dtype).max - 1 - return volume + volume[rel_x, rel_y, rel_z] = threshold_value + + volume = volume.swapaxes(0, 2) + return torch.from_numpy(volume) def ball_filter_imgs( - volume: np.ndarray, - threshold_value: int, - soma_centre_value: int, - ball_xy_size: int = 3, - ball_z_size: int = 3, -) -> Tuple[np.ndarray, np.ndarray]: + volume: torch.Tensor, settings: DetectionSettings +) -> np.ndarray: """ Apply ball filtering to a 3D volume and detect cell centres. @@ -56,105 +82,118 @@ def ball_filter_imgs( and the `CellDetector` class to detect cell centres. Args: - volume (np.ndarray): The 3D volume to be filtered. - threshold_value (int): The threshold value for ball filtering. - soma_centre_value (int): The value representing the soma centre. - ball_xy_size (int, optional): - The size of the ball filter in the XY plane. Defaults to 3. - ball_z_size (int, optional): - The size of the ball filter in the Z plane. Defaults to 3. + volume (torch.Tensor): The 3D volume to be filtered (Z, Y, X order). + settings (DetectionSettings): + The settings to use. Returns: - Tuple[np.ndarray, np.ndarray]: - A tuple containing the filtered volume and the cell centres. + The 2D array of cell centres (N, 3) - X, Y, Z order. """ - # OPTIMISE: reuse ball filter instance - - good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=np.bool_) - - plane_width, plane_height = volume.shape[:2] - current_z = ball_z_size // 2 - - bf = BallFilter( - plane_width, - plane_height, - ball_xy_size, - ball_z_size, - overlap_fraction=0.8, - tile_step_width=plane_width, - tile_step_height=plane_height, - threshold_value=threshold_value, - soma_centre_value=soma_centre_value, + detection_convert = settings.detection_data_converter_func + batch_size = settings.batch_size + + # make sure volume is not less than kernel etc + try: + bf = BallFilter( + plane_height=settings.plane_height, + plane_width=settings.plane_width, + ball_xy_size=settings.ball_xy_size, + ball_z_size=settings.ball_z_size, + overlap_fraction=settings.ball_overlap_fraction, + threshold_value=settings.threshold_value, + soma_centre_value=settings.soma_centre_value, + tile_height=settings.tile_height, + tile_width=settings.tile_width, + dtype=settings.filtering_dtype.__name__, + batch_size=batch_size, + torch_device=settings.torch_device, + use_mask=False, # we don't need a mask here + ) + except InvalidVolume: + return np.empty((0, 3)) + + start_z = bf.first_valid_plane + cell_detector = CellDetector( + settings.plane_height, + settings.plane_width, + start_z=start_z, + soma_centre_value=settings.detection_soma_centre_value, ) - cell_detector = CellDetector(plane_width, plane_height, start_z=current_z) - # FIXME: hard coded type - ball_filtered_volume = np.zeros(volume.shape, dtype=np.uint32) previous_plane = None - for z in range(volume.shape[2]): - bf.append(volume[:, :, z].astype(np.uint32), good_tiles_mask[:, :, z]) + for z in range(0, volume.shape[0], batch_size): + bf.append(volume[z : z + batch_size, :, :]) + if bf.ready: bf.walk() - middle_plane = bf.get_middle_plane() - # first valid middle plane is the current_z, not z - ball_filtered_volume[:, :, current_z] = middle_plane[:] - current_z += 1 + middle_planes = bf.get_processed_planes() + n = middle_planes.shape[0] - # DEBUG: TEST: transpose - previous_plane = cell_detector.process( - middle_plane.copy(), previous_plane + # we edit volume, but only for planes already processed that won't + # be passed to the filter in this run + volume[start_z : start_z + n, :, :] = torch.from_numpy( + middle_planes ) - return ball_filtered_volume, cell_detector.get_cell_centres() + start_z += n + + # convert to type needed for detection + middle_planes = detection_convert(middle_planes) + for plane in middle_planes: + previous_plane = cell_detector.process(plane, previous_plane) + + return cell_detector.get_cell_centres() def iterative_ball_filter( - volume: np.ndarray, n_iter: int = 10 + volume: torch.Tensor, settings: DetectionSettings ) -> Tuple[List[int], List[np.ndarray]]: """ Apply iterative ball filtering to the given volume. The volume is eroded at each iteration, by subtracting 1 from the volume. Parameters: - volume (np.ndarray): The input volume. - n_iter (int): The number of iterations to perform. Default is 10. + volume (torch.Tensor): The input volume. It is edited inplace. + Of shape Z, Y, X. + settings (DetectionSettings): The settings to use. Returns: - Tuple[List[int], List[np.ndarray]]: A tuple containing two lists: - The structures found in each iteration. + tuple: A tuple containing two lists: + The number of structures found in each iteration. The cell centres found in each iteration. """ ns = [] centres = [] - threshold_value = np.iinfo(volume.dtype).max - 1 - soma_centre_value = np.iinfo(volume.dtype).max - - vol = volume.copy() # TODO: check if required - - for i in range(n_iter): - vol, cell_centres = ball_filter_imgs( - vol, threshold_value, soma_centre_value - ) - - # vol is unsigned, so can't let zeros underflow to max value - vol[:, :, :] = np.where(vol != 0, vol - 1, 0) + for i in range(settings.n_splitting_iter): + cell_centres = ball_filter_imgs(volume, settings) + volume.sub_(1) n_structures = len(cell_centres) ns.append(n_structures) centres.append(cell_centres) if n_structures == 0: break + return ns, centres def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: """ - Checks whether a coordinate is in a cuboid - :param centre: x,y,z coordinate - :param max_coords: far corner of cuboid - :return: True if within cuboid, otherwise False + Checks whether a coordinate is in a cuboid. + + Parameters + ---------- + + centre : np.ndarray + x, y, z coordinate. + max_coords : np.ndarray + Far corner of cuboid. + + Returns + ------- + True if within cuboid, otherwise False. """ relative_coords = centre if (relative_coords > max_coords).all(): @@ -168,7 +207,7 @@ def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: def split_cells( - cell_points: np.ndarray, outlier_keep: bool = False + cell_points: np.ndarray, settings: DetectionSettings ) -> np.ndarray: """ Split the given cell points into individual cell centres. @@ -177,28 +216,24 @@ def split_cells( cell_points (np.ndarray): Array of cell points with shape (N, 3), where N is the number of cell points and each point is represented by its x, y, and z coordinates. - outlier_keep (bool, optional): Flag indicating whether to keep outliers - during the splitting process. Defaults to False. + settings (DetectionSettings) : The settings to use for splitting. It is + modified inplace. Returns: np.ndarray: Array of absolute cell centres with shape (M, 3), where M is the number of individual cells and each centre is represented by its x, y, and z coordinates. """ + # these points are in x, y, z order columnwise, in absolute pixels orig_centre = get_structure_centre(cell_points) xs = cell_points[:, 0] ys = cell_points[:, 1] zs = cell_points[:, 2] - orig_corner = np.array( - [ - orig_centre[0] - (orig_centre[0] - xs.min()), - orig_centre[1] - (orig_centre[1] - ys.min()), - orig_centre[2] - (orig_centre[2] - zs.min()), - ] - ) - + # corner coordinates in absolute pixels + orig_corner = np.array([xs.min(), ys.min(), zs.min()]) + # volume center relative to corner relative_orig_centre = np.array( [ orig_centre[0] - orig_corner[0], @@ -207,22 +242,51 @@ def split_cells( ] ) + # total volume enclosing all points original_bounding_cuboid_shape = get_shape(xs, ys, zs) - ball_radius = 1 - vol = coords_to_volume(xs, ys, zs, ball_radius=ball_radius) + ball_radius = settings.ball_xy_size // 2 + # they should be the same dtype so as to not need a conversion before + # passing the input data with marked cells to the filters (we currently + # set both to float32) + assert settings.filtering_dtype == settings.plane_original_np_dtype + # volume will now be z, y, x order + vol = coords_to_volume( + xs, + ys, + zs, + volume_shape=original_bounding_cuboid_shape, + ball_radius=ball_radius, + dtype=settings.filtering_dtype, + threshold_value=settings.threshold_value, + ) + + # get an estimate of how much memory processing a single batch of original + # input planes takes. For this much smaller volume, our batch will be such + # that it uses at most that much memory + total_vol_size = ( + settings.plane_height * settings.plane_width * settings.batch_size + ) + batch_size = total_vol_size // (vol.shape[1] * vol.shape[2]) + batch_size = min(batch_size, vol.shape[0]) + + # update settings with our volume data + settings.plane_shape = vol.shape[1:] + settings.start_plane = 0 + settings.end_plane = vol.shape[0] + settings.batch_size = batch_size # centres is a list of arrays of centres (1 array of centres per ball run) - ns, centres = iterative_ball_filter(vol) + # in x, y, z order + ns, centres = iterative_ball_filter(vol, settings) ns.insert(0, 1) centres.insert(0, np.array([relative_orig_centre])) best_iteration = ns.index(max(ns)) - # TODO: put constraint on minimum centres distance ? relative_centres = centres[best_iteration] - if not outlier_keep: + if not settings.outlier_keep: # TODO: change to checking whether in original cluster shape original_max_coords = np.array(original_bounding_cuboid_shape) relative_centres = np.array( @@ -234,7 +298,7 @@ def split_cells( ) absolute_centres = np.empty((len(relative_centres), 3)) - # FIXME: extract functionality + # convert centers to absolute pixels absolute_centres[:, 0] = orig_corner[0] + relative_centres[:, 0] absolute_centres[:, 1] = orig_corner[1] + relative_centres[:, 1] absolute_centres[:, 2] = orig_corner[2] + relative_centres[:, 2] diff --git a/cellfinder/core/detect/filters/volume/volume_filter.py b/cellfinder/core/detect/filters/volume/volume_filter.py index 949f9f91..1cf432bf 100644 --- a/cellfinder/core/detect/filters/volume/volume_filter.py +++ b/cellfinder/core/detect/filters/volume/volume_filter.py @@ -1,151 +1,461 @@ -import math -import multiprocessing.pool +import multiprocessing as mp import os from functools import partial -from queue import Queue -from threading import Lock -from typing import Any, Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import numpy as np +import torch from brainglobe_utils.cells.cells import Cell from tifffile import tifffile from tqdm import tqdm -from cellfinder.core import logger -from cellfinder.core.detect.filters.setup_filters import ( - get_ball_filter, - get_cell_detector, -) +from cellfinder.core import logger, types +from cellfinder.core.detect.filters.plane import TileProcessor +from cellfinder.core.detect.filters.setup_filters import DetectionSettings +from cellfinder.core.detect.filters.volume.ball_filter import BallFilter from cellfinder.core.detect.filters.volume.structure_detection import ( + CellDetector, get_structure_centre, ) from cellfinder.core.detect.filters.volume.structure_splitting import ( StructureSplitException, split_cells, ) +from cellfinder.core.tools.threading import ( + EOFSignal, + ProcessWithException, + ThreadWithException, +) +from cellfinder.core.tools.tools import inference_wrapper + + +@inference_wrapper +def _plane_filter( + process: ProcessWithException, + tile_processor: TileProcessor, + n_threads: int, + buffers: List[Tuple[torch.Tensor, torch.Tensor]], +): + """ + When running on cpu, we spin up a process for each plane in the batch. + This function runs in the process. + + For every new batch, main process sends a buffer token and plane index + to this function. We process that plane and let the main process know + we are done. + """ + # more than about 4 threads seems to slow down computation + torch.set_num_threads(min(n_threads, 4)) + + while True: + msg = process.get_msg_from_mainthread() + if msg == EOFSignal: + return + # with torch multiprocessing, tensors are shared in memory - so + # just update in place + token, i = msg + tensor, masks = buffers[token] + + plane, mask = tile_processor.get_tile_mask(tensor[i : i + 1, :, :]) + tensor[i : i + 1, :, :] = plane + masks[i : i + 1, :, :] = mask + + # tell the main thread we processed all the planes for this tensor + process.send_msg_to_mainthread(None) + + +class VolumeFilter: + """ + Filters and detects cells in the input data. + + This will take a 3d data array, filter each plane first with 2d filters + finding bright spots. Then it filters the stack with a ball filter to + find voxels that are potential cells. Then it runs cell detection on it + to actually identify the cells. + + Parameters + ---------- + settings : DetectionSettings + Settings object that contains all the configuration data. + """ + + def __init__(self, settings: DetectionSettings): + self.settings = settings + + self.ball_filter = BallFilter( + plane_height=settings.plane_height, + plane_width=settings.plane_width, + ball_xy_size=settings.ball_xy_size, + ball_z_size=settings.ball_z_size, + overlap_fraction=settings.ball_overlap_fraction, + threshold_value=settings.threshold_value, + soma_centre_value=settings.soma_centre_value, + tile_height=settings.tile_height, + tile_width=settings.tile_width, + dtype=settings.filtering_dtype.__name__, + batch_size=settings.batch_size, + torch_device=settings.torch_device, + use_mask=True, + ) + self.z = settings.start_plane + self.ball_filter.first_valid_plane -class VolumeFilter(object): - def __init__( - self, - *, - soma_diameter: float, - soma_size_spread_factor: float = 1.4, - setup_params: Tuple[np.ndarray, Any, int, int, float, Any], - n_planes: int, - n_locks_release: int, - save_planes: bool = False, - plane_directory: Optional[str] = None, - start_plane: int = 0, - max_cluster_size: int = 5000, - outlier_keep: bool = False, - artifact_keep: bool = True, - ): - self.soma_diameter = soma_diameter - self.soma_size_spread_factor = soma_size_spread_factor - self.n_planes = n_planes - self.z = start_plane - self.save_planes = save_planes - self.plane_directory = plane_directory - self.max_cluster_size = max_cluster_size - self.outlier_keep = outlier_keep - self.n_locks_release = n_locks_release - - self.artifact_keep = artifact_keep - - self.clipping_val = None - self.threshold_value = None - self.setup_params = setup_params - - self.previous_plane: Optional[np.ndarray] = None - - self.ball_filter = get_ball_filter( - plane=self.setup_params[0], - soma_diameter=self.setup_params[1], - ball_xy_size=self.setup_params[2], - ball_z_size=self.setup_params[3], - ball_overlap_fraction=self.setup_params[4], + self.cell_detector = CellDetector( + settings.plane_height, + settings.plane_width, + start_z=self.z, + soma_centre_value=settings.detection_soma_centre_value, ) - - self.cell_detector = get_cell_detector( - plane_shape=self.setup_params[0].shape, # type: ignore - ball_z_size=self.setup_params[3], - z_offset=self.setup_params[5], + # make sure we load enough data to filter. Otherwise, we won't be ready + # to filter and the data loading thread will wait for data to be + # processed before loading more data, but that will never happen + self.n_queue_buffer = max( + self.settings.num_prefetch_batches, + self.ball_filter.num_batches_before_ready, ) + def _get_filter_buffers( + self, cpu: bool, tile_processor: TileProcessor + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Generates buffers to use for data loading and filtering. + + It creates pinned tensors ahead of time for faster copying to gpu. + Pinned tensors are kept in RAM and are faster to copy to GPU because + they can't be paged. So loaded data is copied to the tensor and then + sent to the device. + + For CPU even though we don't pin, it's useful to create the buffers + ahead of time and reuse it so we can filter in sub-processes + (see `_plane_filter`). + For tile masks, we only create buffers for CPU. On CUDA, they are + generated every time new on the device. + """ + batch_size = self.settings.batch_size + torch_dtype = getattr(torch, self.settings.filtering_dtype.__name__) + + buffers = [] + for _ in range(self.n_queue_buffer): + # the tensor used for data loading + tensor = torch.empty( + (batch_size, *self.settings.plane_shape), + dtype=torch_dtype, + pin_memory=not cpu, + device="cpu", + ) + + # tile mask buffer - only for cpu + masks = None + if cpu: + masks = tile_processor.get_tiled_buffer( + batch_size, self.settings.torch_device + ) + + buffers.append((tensor, masks)) + + return buffers + + @inference_wrapper + def _feed_signal_batches( + self, + thread: ThreadWithException, + data: types.array, + processors: List[ProcessWithException], + buffers: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + """ + Runs in its own thread. It loads the input data planes, converts them + to torch tensors of the right data-type, and sends them to cuda or to + subprocesses for cpu to be filtered etc. + """ + batch_size = self.settings.batch_size + device = self.settings.torch_device + start_plane = self.settings.start_plane + end_plane = start_plane + self.settings.n_planes + data_converter = self.settings.filter_data_converter_func + cpu = self.settings.torch_device == "cpu" + # should only have 2d filter processors on the cpu + assert bool(processors) == cpu + + # seed the queue with tokens for the buffers + for token in range(len(buffers)): + thread.send_msg_to_thread(token) + + for z in range(start_plane, end_plane, batch_size): + # convert the data to the right type + np_data = data_converter(data[z : z + batch_size, :, :]) + # if we ran out of batches, we are done! + n = np_data.shape[0] + assert n + + # thread/underlying queues get first crack at msg. Unless we get + # eof, this will block until a buffer is returned from the main + # thread for reuse + token = thread.get_msg_from_mainthread() + if token is EOFSignal: + return + + # buffer is free, get it from token + tensor, masks = buffers[token] + + # for last batch, it can be smaller than normal so only set up to n + tensor[:n, :, :] = torch.from_numpy(np_data) + tensor = tensor[:n, :, :] + if not cpu: + # send to device - it won't block here because we pinned memory + tensor = tensor.to(device=device, non_blocking=True) + + # if used, send each plane in batch to processor + used_processors = [] + if cpu: + used_processors = processors[:n] + for i, process in enumerate(used_processors): + process.send_msg_to_thread((token, i)) + + # tell the main thread to wait for processors (if used) + msg = token, tensor, masks, used_processors, n + + if n < batch_size: + # on last batch, we are also done after this + thread.send_msg_to_mainthread(msg) + return + # send the data to the main thread + thread.send_msg_to_mainthread(msg) + def process( self, - async_result_queue: Queue, - locks: List[Lock], + tile_processor: TileProcessor, + signal_array, *, - callback: Callable[[int], None], + callback: Optional[Callable[[int], None]], ) -> None: - progress_bar = tqdm(total=self.n_planes, desc="Processing planes") - for z in range(self.n_planes): - # Get result from the queue. - # - # It is important to remove the result from the queue here - # to free up memory once this plane has been processed by - # the 3D filter here - logger.debug(f"🏐 Waiting for plane {z}") - result = async_result_queue.get() - # .get() blocks until the result is available - plane, mask = result.get() - logger.debug(f"🏐 Got plane {z}") - - self.ball_filter.append(plane, mask) - - if self.ball_filter.ready: - # Let the next 2D filter run - z_release = z + self.n_locks_release + 1 - if z_release < len(locks): - logger.debug(f"🔓 Releasing lock for plane {z_release}") - locks[z_release].release() - - self._run_filter() + """ + Takes the processor and the data and passes them through the filtering + and cell detection stages. + + If the callback is provided, we call it after every plane with the + current z index to update the status. It may be called from secondary + threads. + """ + progress_bar = tqdm( + total=self.settings.n_planes, desc="Processing planes" + ) + cpu = self.settings.torch_device == "cpu" + n_threads = self.settings.n_torch_comp_threads + + # we re-use these tensors for data loading, so we have a fixed number + # of planes in memory. The feeder thread will wait to load more data + # until a tensor is free to be reused. + # We have to keep the tensors in memory in main process while it's + # in used elsewhere + buffers = self._get_filter_buffers(cpu, tile_processor) + + # on cpu these processes will 2d filter each plane in the batch + plane_processes = [] + if cpu: + for _ in range(self.settings.batch_size): + process = ProcessWithException( + target=_plane_filter, + args=(tile_processor, n_threads, buffers), + pass_self=True, + ) + process.start() + plane_processes.append(process) + + # thread that reads and sends us data + feed_thread = ThreadWithException( + target=self._feed_signal_batches, + args=(signal_array, plane_processes, buffers), + pass_self=True, + ) + feed_thread.start() - callback(self.z) - self.z += 1 - progress_bar.update() + # thread that takes the 3d filtered data and does cell detection + cells_thread = ThreadWithException( + target=self._run_filter_thread, + args=(callback, progress_bar), + pass_self=True, + ) + cells_thread.start() + + try: + self._process(feed_thread, cells_thread, tile_processor, cpu) + finally: + # if we end, make sure to tell the threads to stop + feed_thread.notify_to_end_thread() + cells_thread.notify_to_end_thread() + for process in plane_processes: + process.notify_to_end_thread() + + # the notification above ensures this won't block forever + feed_thread.join() + cells_thread.join() + for process in plane_processes: + process.join() + + # in case these threads sent us an exception but we didn't yet read + # it, make sure to process them + feed_thread.clear_remaining() + cells_thread.clear_remaining() + for process in plane_processes: + process.clear_remaining() progress_bar.close() logger.debug("3D filter done") - def _run_filter(self) -> None: - logger.debug(f"🏐 Ball filtering plane {self.z}") - # filtering original images, the images should be large enough in x/y - # to benefit from parallelization. Note: don't pass arg as keyword arg - # because numba gets stuck (probably b/c class jit is new) - self.ball_filter.walk(True) - - middle_plane = self.ball_filter.get_middle_plane() - if self.save_planes: - self.save_plane(middle_plane) - - logger.debug(f"🏫 Detecting structures for plane {self.z}") - self.previous_plane = self.cell_detector.process( - middle_plane, self.previous_plane - ) + def _process( + self, + feed_thread: ThreadWithException, + cells_thread: ThreadWithException, + tile_processor: TileProcessor, + cpu: bool, + ) -> None: + """ + Processes the loaded data from feeder thread. If on cpu it is already + 2d filtered so just 3d filter. On cuda we need to do both 2d and 3d + filtering. Then, it sends the filtered data off to the detection thread + for cell detection. + """ + processing_tokens = [] + + while True: + # thread/underlying queues get first crack at msg. Unless we get + # eof, this will block until we get more loaded data until no more + # data or exception + msg = feed_thread.get_msg_from_thread() + # feeder thread exits at the end, causing a eof to be sent + if msg is EOFSignal: + break + token, tensor, masks, used_processors, n = msg + # this token is in use until we return it + processing_tokens.append(token) + + if cpu: + # we did 2d filtering in different process. Make sure all the + # planes are done filtering. Each msg from feeder thread has + # corresponding msg for each used processor (unless exception) + for process in used_processors: + process.get_msg_from_thread() + # batch size can change at the end so resize buffer + planes = tensor[:n, :, :] + masks = masks[:n, :, :] + else: + # we're not doing 2d filtering in different process + planes, masks = tile_processor.get_tile_mask(tensor) - logger.debug(f"🏫 Structures done for plane {self.z}") + self.ball_filter.append(planes, masks) + if self.ball_filter.ready: + self.ball_filter.walk() + middle_planes = self.ball_filter.get_processed_planes() + + # at this point we know input tensor can be reused - return + # it so feeder thread can load more data into it + for token in processing_tokens: + feed_thread.send_msg_to_thread(token) + processing_tokens.clear() + + # thread/underlying queues get first crack at msg. Unless + # we get eof, this will block until we get a token, + # indicating we can send more data. The cells thread has a + # fixed token supply, ensuring we don't send it too much + # data, in case detection takes longer than filtering + # Also, error messages incoming are at most # tokens behind + token = cells_thread.get_msg_from_thread() + if token is EOFSignal: + break + # send it more data and return the token + cells_thread.send_msg_to_thread((middle_planes, token)) + + @inference_wrapper + def _run_filter_thread( + self, thread: ThreadWithException, callback, progress_bar + ) -> None: + """ + Runs in its own thread and takes the filtered planes and passes them + through the cell detection system. Also saves the planes as needed. + """ + detector = self.cell_detector + original_dtype = self.settings.plane_original_np_dtype + detection_converter = self.settings.detection_data_converter_func + save_planes = self.settings.save_planes + previous_plane = None + bf = self.ball_filter + + # these many planes are not processed at start because 3d filter uses + # it as padding at the start of filter + progress_bar.update(bf.first_valid_plane) + + # main thread needs a token to send us planes - populate with some + for _ in range(self.n_queue_buffer): + thread.send_msg_to_mainthread(object()) + + while True: + # thread/underlying queues get first crack at msg. Unless we get + # eof, this will block until we get more data + msg = thread.get_msg_from_mainthread() + # requested that we return. This can mean the main thread finished + # sending data and it appended eof - so we get eof after all planes + if msg is EOFSignal: + # these many planes are not processed at the end because 3d + # filter uses it as padding at the end of the filter + progress_bar.update(bf.remaining_planes) + return + + # convert plane to the type needed by detection system + # we should not need scaling because throughout + # filtering we make sure result fits in this data type + middle_planes, token = msg + detection_middle_planes = detection_converter(middle_planes) + + logger.debug(f"🏫 Detecting structures for planes {self.z}+") + for plane, detection_plane in zip( + middle_planes, detection_middle_planes + ): + if save_planes: + self.save_plane(plane.astype(original_dtype)) + + previous_plane = detector.process( + detection_plane, previous_plane + ) + + if callback is not None: + callback(self.z) + self.z += 1 + progress_bar.update() + + # we must return the token, otherwise the main thread will run out + # and won't send more data to us + thread.send_msg_to_mainthread(token) + logger.debug(f"🏫 Structures done for planes {self.z}+") def save_plane(self, plane: np.ndarray) -> None: - if self.plane_directory is None: + """ + Saves the plane as an image according to the settings. + """ + if self.settings.plane_directory is None: raise ValueError( "plane_directory must be set to save planes to file" ) - plane_name = f"plane_{str(self.z).zfill(4)}.tif" - f_path = os.path.join(self.plane_directory, plane_name) - tifffile.imsave(f_path, plane.T) - - def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]: + # self.z is zero based, we should save names as 1-based. + plane_name = self.settings.plane_prefix.format(n=self.z + 1) + ".tif" + f_path = os.path.join(self.settings.plane_directory, plane_name) + tifffile.imwrite(f_path, plane) + + def get_results(self, settings: DetectionSettings) -> List[Cell]: + """ + Returns the detected cells. + + After filtering, this parses the resulting cells and splits large + bright regions into individual cells. + """ logger.info("Splitting cell clusters and writing results") - max_cell_volume = sphere_volume( - self.soma_size_spread_factor * self.soma_diameter / 2 - ) + root_settings = self.settings + max_cell_volume = settings.max_cell_volume + # valid cells cells = [] + # regions that must be split into cells needs_split = [] structures = self.cell_detector.get_structures().items() logger.debug(f"Processing {len(structures)} found cells") @@ -158,7 +468,7 @@ def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]: cell_centre = get_structure_centre(cell_points) cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) else: - if cell_volume < self.max_cluster_size: + if cell_volume < settings.max_cluster_size: needs_split.append((cell_id, cell_points)) else: cell_centre = get_structure_centre(cell_points) @@ -174,13 +484,22 @@ def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]: total=len(needs_split), desc="Splitting cell clusters" ) - # we are not returning Cell instances from func because it'd be pickled - # by multiprocess which slows it down - func = partial(_split_cells, outlier_keep=self.outlier_keep) - for cell_centres in worker_pool.imap_unordered(func, needs_split): - for cell_centre in cell_centres: - cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) - progress_bar.update() + # the settings is pickled and re-created for each process, which is + # important because splitting can modify the settings, so we don't want + # parallel modifications for same object + f = partial(_split_cells, settings=settings) + ctx = mp.get_context("spawn") + # we can't use the context manager because of coverage issues: + # https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html + pool = ctx.Pool(processes=root_settings.n_processes) + try: + for cell_centres in pool.imap_unordered(f, needs_split): + for cell_centre in cell_centres: + cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) + progress_bar.update() + finally: + pool.close() + pool.join() progress_bar.close() logger.debug( @@ -190,13 +509,15 @@ def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]: return cells -def _split_cells(arg, outlier_keep): +@inference_wrapper +def _split_cells(arg, settings: DetectionSettings): + # runs in its own process for a bright region to be split. + # For splitting cells, we only run with one thread. Because the volume is + # likely small and using multiple threads would cost more in overhead than + # is worth. num threads can be set only at processes level. + torch.set_num_threads(1) cell_id, cell_points = arg try: - return split_cells(cell_points, outlier_keep=outlier_keep) + return split_cells(cell_points, settings=settings) except (ValueError, AssertionError) as err: raise StructureSplitException(f"Cell {cell_id}, error; {err}") - - -def sphere_volume(radius: float) -> float: - return (4 / 3) * math.pi * radius**3 diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 5aad49f7..f9cc5c48 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -26,7 +26,7 @@ def main( ball_z_size: int = 15, ball_overlap_fraction: float = 0.6, log_sigma_size: float = 0.2, - n_sds_above_mean_thresh: int = 10, + n_sds_above_mean_thresh: float = 10, soma_spread_factor: float = 1.4, max_cluster_size: int = 100000, cube_width: int = 50, @@ -36,11 +36,13 @@ def main( skip_detection: bool = False, skip_classification: bool = False, detected_cells: List[Cell] = None, + classification_batch_size: Optional[int] = None, + classification_torch_device: str = "cpu", *, detect_callback: Optional[Callable[[int], None]] = None, classify_callback: Optional[Callable[[int], None]] = None, detect_finished_callback: Optional[Callable[[list], None]] = None, -) -> List: +) -> List[Cell]: """ Parameters ---------- @@ -74,6 +76,8 @@ def main( n_free_cpus, log_sigma_size, n_sds_above_mean_thresh, + batch_size=classification_batch_size, + torch_device=classification_torch_device, callback=detect_callback, ) diff --git a/cellfinder/core/tools/IO.py b/cellfinder/core/tools/IO.py new file mode 100644 index 00000000..109580bb --- /dev/null +++ b/cellfinder/core/tools/IO.py @@ -0,0 +1,45 @@ +import pooch + + +def fetch_pooch_directory( + registry: pooch.Pooch, + directory_name: str, + processor=None, + downloader=None, + progressbar=False, +): + """ + Fetches files from the Pooch registry that belong to a specific directory. + Parameters: + registry (pooch.Pooch): The Pooch registry object. + directory_name (str): + The remote relative path of the directory to fetch files from. + processor (callable, optional): + A function to process the fetched files. Defaults to None. + downloader (callable, optional): + A function to download the files. Defaults to None. + progressbar (bool, optional): + Whether to display a progress bar during the fetch. + Defaults to False. + Returns: + str: The local absolute path to the fetched directory. + """ + names = [] + for name in registry.registry_files: + if name.startswith(f"{directory_name}/"): + names.append(name) + + if not names: + raise FileExistsError( + f"Unable to find files in directory {directory_name}" + ) + + for name in names: + registry.fetch( + name, + processor=processor, + downloader=downloader, + progressbar=progressbar, + ) + + return str(registry.abspath / directory_name) diff --git a/cellfinder/core/tools/threading.py b/cellfinder/core/tools/threading.py new file mode 100644 index 00000000..5c3e54ff --- /dev/null +++ b/cellfinder/core/tools/threading.py @@ -0,0 +1,380 @@ +""" +Provides classes that can run a function in another thread or process and +allow passing data to and from the threads/processes. It also passes on any +exceptions that occur in the secondary thread/sub-process in the main thread +or when it exits. + +If using a sub-process and pytorch Tensors are sent from/to the main +process, pytorch will memory map the tensor so the same data is shared and +edits in the main process will be visible in the sub-process. However, there +are limitations (such so not re-sharing a tensor shared with us). See +https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors for +details. + +Typical example:: + + from cellfinder.core.tools.threading import ThreadWithException, \\ + EOFSignal, ProcessWithException + import torch + + + def worker(thread: ThreadWithException, power: float): + while True: + # if the main thread wants us to exit, it'll wake us up + msg = thread.get_msg_from_mainthread() + # we were asked to exit + if msg == EOFSignal: + return + + tensor_id, tensor, add = msg + # tensors are memory mapped for subprocess (and obv threads) so we + # can do the work inplace and result will be visible in main thread + tensor += add + tensor.pow_(power) + + # we should not share a tensor shared with us as per pytorch docs, + # just send the id back + thread.send_msg_to_mainthread(tensor_id) + + # we can also handle errors here, which will be re-raised in the main + # process + if tensor_id == 7: + raise ValueError("I fell asleep") + + + if __name__ == "__main__": + data = torch.rand((5, 10)) + thread = ThreadWithException( + target=worker, args=(2.5,), pass_self=True + ) + # use thread or sub-process + # thread = ProcessWithException( + # target=worker, args=(2.5,), pass_self=True + # ) + thread.start() + + try: + for i in range(10): + thread.send_msg_to_thread((i, data, i / 2)) + # if the thread raises an exception, get_msg_from_thread will + # re-raise it here + msg = thread.get_msg_from_thread() + if msg == EOFSignal: + # thread exited for whatever reason (not exception) + break + + print(f"Thread processed tensor {i}") + finally: + # whatever happens, make sure thread is told to finish so it + # doesn't get stuck + thread.notify_to_end_thread() + thread.join() + +When run, this prints:: + + Thread processed tensor 0 + Thread processed tensor 1 + Thread processed tensor 2 + Thread processed tensor 3 + Thread processed tensor 4 + Thread processed tensor 5 + Thread processed tensor 6 + Thread processed tensor 7 + Traceback (most recent call last): + File "threading.py", line 139, in user_func_runner + self.user_func(self, *self.args) + File "play.py", line 24, in worker + raise ValueError("I fell asleep") + ValueError: I fell asleep + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + File "play.py", line 38, in + msg = thread.get_msg_from_thread() + File "threading.py", line 203, in get_msg_from_thread + raise ExecutionFailure( + ExecutionFailure: Reporting failure from other thread/process +""" + +from queue import Queue +from threading import Thread +from typing import Any, Callable, Optional + +import torch.multiprocessing as mp + + +class EOFSignal: + """ + This class object (not instance) is returned by the thread/process as an + indicator that someone exited or that you should exit. + """ + + pass + + +class ExecutionFailure(Exception): + """ + Exception class raised in the main thread when the function running in the + thread/process raises an exception. + + Get the original exception using the `__cause__` property of this + exception. + """ + + pass + + +class ExceptionWithQueueMixIn: + """ + A mix-in that can be used with a secondary thread or sub-process to + facilitate communication with them and the passing back of exceptions and + signals such as when the thread/sub-process exits or when the main + thread/process want the sub-process/thread to end. + + Communication happens bi-directionally via queues. These queues are not + limited in buffer size, so if there's potential for the queue to be backed + up because the main thread or sub-thread/process is unable to read and + process messages quickly enough, you must implement some kind of + back-pressure to prevent the queue from growing unlimitedly in size. + E.g. using a limited number of tokens that must be held to send a message + to prevent a sender from thrashing the queue. And these tokens can be + passed between main thread to sub-thread/process etc on each message and + when the message is done processing. + + The main process uses `get_msg_from_thread` to raise exceptions in the + main thread that occurred in the sub-thread/process. So that must be + called in the main thread if you want to know if the sub-thread/process + exited. + """ + + to_thread_queue: Queue + """ + Internal queue used to send messages to the thread from the main thread. + """ + + from_thread_queue: Queue + """ + Internal queue used to send messages from the thread to the main thread. + """ + + args: tuple = () + """ + The args provided by the caller that will be passed to the function running + in the thread/process when it starts. + """ + + # user_func_runner must end with an eof msg to the main thread. This tracks + # whether the main thread saw the eof. If it didn't, we know there are more + # msgs in the queue waiting to be read + _saw_eof: bool = False + + def __init__(self, target: Callable, pass_self: bool = False): + self.user_func = target + self.pass_self = pass_self + + def user_func_runner(self) -> None: + """ + The internal function that runs the target function provided by the + user. + """ + try: + if self.user_func is not None: + if self.pass_self: + self.user_func(self, *self.args) + else: + self.user_func(*self.args) + except BaseException as e: + self.from_thread_queue.put( + ("exception", e), block=True, timeout=None + ) + finally: + self.from_thread_queue.put(("eof", None), block=True, timeout=None) + + def send_msg_to_mainthread(self, value: Any) -> None: + """ + Sends the `value` to the main thread, from the sub-thread/process. + + The value must be pickleable if it's running in a sub-process. The main + thread then can read it using `get_msg_from_thread`. + + The queue is not limited in size, so if there's potential for the + queue to be backed up because the main thread is unable to read quickly + enough, you must implement some kind of back-pressure to prevent + the queue from growing unlimitedly in size. + """ + self.from_thread_queue.put(("user", value), block=True, timeout=None) + + def clear_remaining(self) -> None: + """ + Celled in the main-thread as part of cleanup when we expect the + secondary thread to have exited (e.g. we sent it a message telling it + to). + + It will drop any waiting messages sent by the secondary thread, but + more importantly, it will handle exceptions raised in the secondary + thread before it exited, that may not have yet been processed in the + main thread (e.g. we stopped listening to messages from the secondary + thread before we got an eof from it). + """ + while not self._saw_eof: + self.get_msg_from_thread() + + def get_msg_from_thread(self, timeout: Optional[float] = None) -> Any: + """ + Gets a message from the sub-thread/process sent to the main thread. + + This blocks forever until the message is sent by the sub-thread/process + and received by us. If `timeout` is not None, that's how long we block + here, in seconds, before raising an `Empty` Exception if no message was + received by then. + + If the return value is the `EOFSignal` object, it means the + sub-thread/process has or is about to exit. + + If the sub-thread/process has raised an exception, that exception is + caught and re-raised in the main thread when this method is called. + The exception raised is an `ExecutionFailure` and its `__cause__` + property is the original exception raised in the sub-thread/process. + + A typical pattern is:: + + >>> try: + ... msg = thread.get_msg_from_thread() + ... if msg == EOFSignal: + ... # thread exited + ... pass + ... else: + ... # do something with the msg + ... pass + ... except ExecutionFailure as e: + ... print(f"got exception {type(e.__cause__)}") + ... print(f"with message {e.__cause__.args[0]}") + """ + msg, value = self.from_thread_queue.get(block=True, timeout=timeout) + if msg == "eof": + self._saw_eof = True + return EOFSignal + if msg == "exception": + raise ExecutionFailure( + "Reporting failure from other thread/process" + ) from value + + return value + + def send_msg_to_thread(self, value: Any) -> None: + """ + Sends the `value` to the sub-thread/process, from the main thread. + + The value must be pickleable if it's sent to a sub-process. The thread + then can read it using `get_msg_from_mainthread`. + + The queue is not limited in size, so if there's potential for the + queue to be backed up because the thread is unable to read quickly + enough, you must implement some kind of back-pressure to prevent + the queue from growing unlimitedly in size. + """ + self.to_thread_queue.put(("user", value), block=True, timeout=None) + + def notify_to_end_thread(self) -> None: + """ + Sends a message to the sub-process/thread that the main process wants + them to end. The sub-process/thread sees it by receiving an `EOFSignal` + message from `get_msg_from_mainthread` and it should exit asap. + """ + self.to_thread_queue.put(("eof", None), block=True, timeout=None) + + def get_msg_from_mainthread(self, timeout: Optional[float] = None) -> Any: + """ + Gets a message from the main thread sent to the sub-thread/process. + + This blocks forever until the message is sent by the main thread + and received by us. If `timeout` is not None, that's how long we block + here, in seconds, before raising an `Empty` Exception if no message was + received by then. + + If the return value is the `EOFSignal` object, it means the + main thread has sent us an EOF message using `notify_to_end_thread` + because it wants us to exit. + + A typical pattern is:: + + >>> msg = thread.get_msg_from_mainthread() + ... if msg == EOFSignal: + ... # we should exit asap + ... return + ... # do something with the msg + """ + msg, value = self.to_thread_queue.get(block=True, timeout=timeout) + if msg == "eof": + return EOFSignal + + return value + + +class ThreadWithException(ExceptionWithQueueMixIn): + """ + Runs a target function in a secondary thread. + """ + + thread: Thread = None + """The thread running the function.""" + + def __init__(self, target, args=(), **kwargs): + super().__init__(target=target, **kwargs) + self.to_thread_queue = Queue(maxsize=0) + self.from_thread_queue = Queue(maxsize=0) + self.args = args + self.thread = Thread(target=self.user_func_runner) + + def start(self) -> None: + """Starts the thread that runs the target function.""" + self.thread.start() + + def join(self, timeout: Optional[float] = None) -> None: + """ + Waits and blocks until the thread exits. If timeout is given, + it's the duration to wait, in seconds, before returning. + + To know if it exited, you need to check `is_alive` of the `thread`. + """ + self.thread.join(timeout=timeout) + + +class ProcessWithException(ExceptionWithQueueMixIn): + """ + Runs a target function in a sub-process. + + Any data sent between the processes must be pickleable. + + We run the function using `torch.multiprocessing`. Any tensors sent between + the main process and sub-process is memory mapped so that it doesn't copy + the tensor. So any edits in the main process/sub-process is seen in the + other as well. See https://pytorch.org/docs/stable/multiprocessing.html + for more details on this. + """ + + process: mp.Process = None + """The sub-process running the function.""" + + def __init__(self, target, args=(), **kwargs): + super().__init__(target=target, **kwargs) + ctx = mp.get_context("spawn") + self.to_thread_queue = ctx.Queue(maxsize=0) + self.from_thread_queue = ctx.Queue(maxsize=0) + self.process = ctx.Process(target=self.user_func_runner) + + self.args = args + + def start(self) -> None: + """Starts the sub-process that runs the target function.""" + self.process.start() + + def join(self, timeout: Optional[float] = None) -> None: + """ + Waits and blocks until the process exits. If timeout is given, + it's the duration to wait, in seconds, before returning. + + To know if it exited, you need to check `is_alive` of the `process`. + """ + self.process.join(timeout=timeout) diff --git a/cellfinder/core/tools/tools.py b/cellfinder/core/tools/tools.py index f7a628d5..231515de 100644 --- a/cellfinder/core/tools/tools.py +++ b/cellfinder/core/tools/tools.py @@ -1,19 +1,141 @@ +from functools import wraps from random import getrandbits, uniform -from typing import Optional +from typing import Callable, Optional, Type import numpy as np +import torch from natsort import natsorted -def get_max_possible_value(obj_in: np.ndarray) -> int: +def inference_wrapper(func): """ - Returns the maximum allowed value for a numpy array of integer data type. + Decorator that makes the decorated function/method run with + `torch.inference_mode` set to True. + """ + + @wraps(func) + def inner_function(*args, **kwargs): + with torch.inference_mode(True): + return func(*args, **kwargs) + + return inner_function + + +def get_max_possible_int_value(dtype: Type[np.number]) -> int: + """ + Returns the maximum allowed integer for a numpy array of given type. + + If dtype is of integer type, it's the maximum value. If it's a floating + type, it's the maximum integer that can be accurately represented. + E.g. for float32, only integers up to 2**24 can be represented (due to + the number of bits representing the mantissa (significand). """ - dtype = obj_in.dtype if np.issubdtype(dtype, np.integer): return np.iinfo(dtype).max - else: - raise ValueError("obj_in must be a numpy array of integer data type.") + if np.issubdtype(dtype, np.floating): + mant = np.finfo(dtype).nmant + return 2**mant + raise ValueError("datatype must be of integer or floating data type") + + +def get_min_possible_int_value(dtype: Type[np.number]) -> int: + """ + Returns the minimum allowed integer for a numpy array of given type. + + If dtype is of integer type, it's the minimum value. If it's a floating + type, it's the minimum integer that can be accurately represented. + E.g. for float32, only integers up to -2**24 can be represented (due to + the number of bits representing the mantissa (significand). + """ + if np.issubdtype(dtype, np.integer): + return np.iinfo(dtype).min + if np.issubdtype(dtype, np.floating): + mant = np.finfo(dtype).nmant + # the sign bit is separate so we have the full mantissa for value + return -(2**mant) + raise ValueError("datatype must be of integer or floating data type") + + +def get_data_converter( + src_dtype: Type[np.number], dest_dtype: Type[np.floating] +) -> Callable[[np.ndarray], np.ndarray]: + """ + Returns a function that can be called to convert one data-type to another, + scaling the data down as needed. + + If the maximum value supported by the input data-type is smaller than that + supported by destination data-type, the data will be scaled by the ratio + of maximum integer representable by the `output / input` data-types. + If the max is equal or less, it's simply converted to the target type. + + Parameters + ---------- + src_dtype : np.dtype + The data-type of the input data. + dest_dtype : np.dtype + The data-type of the returned data. Currently, it must be a floating + type and `np.float32` or `np.float64`. + + Returns + ------- + callable: function + A function that takes a single input data parameter and returns + the converted data. + """ + if not np.issubdtype(dest_dtype, np.float32) and not np.issubdtype( + dest_dtype, np.float64 + ): + raise ValueError( + f"Destination dtype must be a float32 or float64, " + f"but it is {dest_dtype}" + ) + + in_min = get_min_possible_int_value(src_dtype) + in_max = get_max_possible_int_value(src_dtype) + out_min = get_min_possible_int_value(dest_dtype) + out_max = get_max_possible_int_value(dest_dtype) + in_abs_max = max(in_max, abs(in_min)) + out_abs_max = max(out_max, abs(out_min)) + + def unchanged(data: np.ndarray) -> np.ndarray: + return np.asarray(data) + + def float_to_float_scale_down(data: np.ndarray) -> np.ndarray: + return ((np.asarray(data) / in_abs_max) * out_abs_max).astype( + dest_dtype + ) + + def int_to_float_scale_down(data: np.ndarray) -> np.ndarray: + # data must fit in float64 + data = np.asarray(data).astype(np.float64) + return ((data / in_abs_max) * out_abs_max).astype(dest_dtype) + + def to_float_unscaled(data: np.ndarray) -> np.ndarray: + return np.asarray(data).astype(dest_dtype) + + if src_dtype == dest_dtype: + return unchanged + + # out can hold the largest in values - just convert to float + if out_min <= in_min < in_max <= out_max: + return to_float_unscaled + + # need to scale down before converting to float + if np.issubdtype(src_dtype, np.integer): + # if going to float32 and it didn't fit input must fit in 64-bit float + # so we can temp store it there to scale. If going to float64, same. + if in_max > get_max_possible_int_value( + np.float64 + ) or in_min < get_min_possible_int_value(np.float64): + raise ValueError( + f"The input datatype {src_dtype} cannot fit in a " + f"64-bit float" + ) + return int_to_float_scale_down + + # for float input, however big it is, we can always scale it down in the + # input data type before changing type + return float_to_float_scale_down def union(a, b): diff --git a/pyproject.toml b/pyproject.toml index af57ef29..feb57af9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "scikit-image", "scikit-learn", "keras>=3.5.0", - "torch>=2.1.0", + "torch>=2.1.0,!=2.4.0", "tifffile", "tqdm", ] @@ -52,6 +52,7 @@ dev = [ "pytest-timeout", "pytest", "tox", + "pooch >= 1", ] napari = [ "brainglobe-napari-io", @@ -130,10 +131,12 @@ setenv = KERAS_BACKEND = torch passenv = NUMBA_DISABLE_JIT + PYTORCH_JIT CI GITHUB_ACTIONS DISPLAY XAUTHORITY NUMPY_EXPERIMENTAL_ARRAY_FUNCTION PYVISTA_OFF_SCREEN + BRAINGLOBE_TEST_DATA_DIR """ diff --git a/tests/core/conftest.py b/tests/core/conftest.py index d78e5cf5..eb4d1e8a 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,7 +1,9 @@ import os -from typing import Tuple +from pathlib import Path +from typing import List, Tuple import numpy as np +import pooch import pytest import torch.backends.mps from skimage.filters import gaussian @@ -27,6 +29,61 @@ def set_device_arm_macos_ci(): force_cpu() +def pytest_collection_modifyitems(session, config, items: List[pytest.Item]): + # this hook is called by pytest after test collection. Move the + # test_detection test to the end because if it's run in the middle we run + # into numba issue #9576 and the tests fail + # end_files are moved to the end, in the given order + end_files = [ + "test_connected_components_labelling.py", + "test_structure_detection.py", + ] + + items_new = [t for t in items if t.path.name not in end_files] + for name in end_files: + items_new.extend([t for t in items if t.path.name == name]) + + items[:] = items_new + + +@pytest.fixture +def test_data_registry(): + """ + Create a test data registry for BrainGlobe. + + Returns: + pooch.Pooch: The test data registry object. + + """ + registry = pooch.create( + path=pooch.os_cache("brainglobe_test_data"), + base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/cellfinder/", + env="BRAINGLOBE_TEST_DATA_DIR", + ) + + registry.load_registry( + Path(__file__).parent.parent / "data" / "pooch_registry.txt" + ) + return registry + + +def mark_sphere( + data_zyx: np.ndarray, center_xyz, radius: int, fill_value: int +) -> None: + shape_zyx = data_zyx.shape + + z, y, x = np.mgrid[ + 0 : shape_zyx[0] : 1, 0 : shape_zyx[1] : 1, 0 : shape_zyx[2] : 1 + ] + dist = np.sqrt( + (x - center_xyz[0]) ** 2 + + (y - center_xyz[1]) ** 2 + + (z - center_xyz[2]) ** 2 + ) + # 100 seems to be the right size so std is not too small for filters + data_zyx[dist <= radius] = fill_value + + @pytest.fixture(scope="session") def no_free_cpus() -> int: """ @@ -85,3 +142,70 @@ def synthetic_bright_spots() -> Tuple[np.ndarray, np.ndarray]: background_array = np.zeros_like(signal_array) return signal_array, background_array + + +@pytest.fixture(scope="session") +def synthetic_single_spot() -> ( + Tuple[np.ndarray, np.ndarray, Tuple[int, int, int]] +): + """ + Creates a synthetic signal array with a single spherical spot + in a 3d numpy array to be used for cell detection testing. + + The max value is 100 and min is zero. The array is a floating type. + You must convert it to the right data type for your tests. + Also, `n_sds_above_mean_thresh` must be 1 or larger. + """ + shape_zyx = 20, 50, 50 + c_xyz = 25, 25, 10 + + signal_array = np.zeros(shape_zyx) + background_array = np.zeros_like(signal_array) + mark_sphere(signal_array, center_xyz=c_xyz, radius=2, fill_value=100) + + # 1 std should be larger, so it can be considered bright + assert np.mean(signal_array) + np.std(signal_array) > 1 + + return signal_array, background_array, c_xyz + + +@pytest.fixture(scope="session") +def synthetic_spot_clusters() -> ( + Tuple[np.ndarray, np.ndarray, List[Tuple[int, int, int]]] +): + """ + Creates a synthetic signal array with a 4 overlapping spherical spots + in a 3d numpy array to be used for cell cluster splitting testing. + + The max value is 100 and min is zero. The array is a floating type. + You must convert it to the right data type for your tests. + Also, `n_sds_above_mean_thresh` must be 1 or larger. + """ + shape_zyx = 20, 100, 100 + radius = 5 + s = 50 - radius * 4 + centers_xyz = [ + (s, 50, 10), + (s + 2 * radius - 1, 50, 10), + (s + 4 * radius - 2, 50, 10), + (s + 6 * radius - 3, 50, 10), + ] + + signal_array = np.zeros(shape_zyx) + background_array = np.zeros_like(signal_array) + + for center in centers_xyz: + mark_sphere( + signal_array, center_xyz=center, radius=radius, fill_value=100 + ) + + return signal_array, background_array, centers_xyz + + +@pytest.fixture(scope="session") +def repo_data_path() -> Path: + """ + The root path where the data used during test is stored + """ + # todo: use mod relative paths to find data instead of depending on cwd + return Path(__file__).parent.parent / "data" diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index fc7bf2f3..63a5fc41 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -4,9 +4,13 @@ import brainglobe_utils.IO.cells as cell_io import numpy as np import pytest +import torch +from brainglobe_utils.cells.cells import Cell from brainglobe_utils.general.system import get_num_processes from brainglobe_utils.IO.image.load import read_with_dask +from cellfinder.core.detect.detect import main as detect_main +from cellfinder.core.detect.filters.volume.ball_filter import InvalidVolume from cellfinder.core.main import main data_dir = os.path.join( @@ -133,9 +137,14 @@ def detect_finished_callback(points): classify_callback=classify_callback, detect_finished_callback=detect_finished_callback, n_free_cpus=no_free_cpus, + ball_z_size=15, ) - np.testing.assert_equal(planes_done, np.arange(len(signal_array))) + skipped_planes = int(round(15 / voxel_sizes[0])) - 1 + skip_start = skipped_planes // 2 + skip_end = skipped_planes - skip_start + n = len(signal_array) - skip_end + np.testing.assert_equal(planes_done, np.arange(skip_start, n)) np.testing.assert_equal(batches_classified, [0]) ncalls = len(points_found) @@ -144,12 +153,6 @@ def detect_finished_callback(points): assert npoints == 120, f"Expected 120 points, found {npoints}" -def test_floating_point_error(signal_array, background_array): - signal_array = signal_array.astype(float) - with pytest.raises(ValueError, match="signal_array must be integer"): - main(signal_array, background_array, voxel_sizes) - - def test_synthetic_data(synthetic_bright_spots, no_free_cpus): signal_array, background_array = synthetic_bright_spots detected = main( @@ -174,3 +177,141 @@ def test_data_dimension_error(ndim): with pytest.raises(ValueError, match="Input data must be 3D"): main(signal_array, background_array, voxel_sizes) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize( + "dtype", + [ + np.uint8, + np.uint16, + np.uint32, + np.int8, + np.int16, + np.int32, + np.float32, + np.float64, + ], +) +def test_signal_data_types(synthetic_single_spot, no_free_cpus, dtype, device): + + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") + + signal_array, background_array, center = synthetic_single_spot + signal_array = signal_array.astype(dtype) + # for signed ints, make some data negative + if np.issubdtype(dtype, np.signedinteger): + # min of signal_array is zero + assert np.isclose(0, np.min(signal_array)) + shift = (np.max(signal_array) - np.min(signal_array)) // 2 + signal_array = signal_array - shift + + background_array = background_array.astype(dtype) + detected = main( + signal_array, + background_array, + n_sds_above_mean_thresh=1.0, + voxel_sizes=voxel_sizes, + n_free_cpus=no_free_cpus, + skip_classification=True, + classification_torch_device=device, + ) + + assert len(detected) == 1 + assert detected[0] == Cell(center, Cell.UNKNOWN) + + +@pytest.mark.parametrize("use_scipy", [True, False]) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_detection_scipy_torch( + synthetic_single_spot, no_free_cpus, use_scipy, device +): + + if device == "cuda" and not torch.cuda.is_available(): + pytest.xfail("Cuda is not available") + + signal_array, background_array, center = synthetic_single_spot + signal_array = signal_array.astype(np.float32) + + detected = detect_main( + signal_array, + n_sds_above_mean_thresh=1.0, + voxel_sizes=voxel_sizes, + n_free_cpus=no_free_cpus, + torch_device=device, + use_scipy=use_scipy, + ) + + assert len(detected) == 1 + assert detected[0] == Cell(center, Cell.UNKNOWN) + + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_detection_cluster_splitting( + synthetic_spot_clusters, no_free_cpus, device +): + """ + Test cluster splitting for overlapping cells. + + Test filtering/detection on cpu and cuda. Because splitting is only on cpu + so make sure if detection is on cuda, splitting still works. + """ + + if device == "cuda" and not torch.cuda.is_available(): + pytest.xfail("Cuda is not available") + + signal_array, background_array, centers_xyz = synthetic_spot_clusters + signal_array = signal_array.astype(np.float32) + + detected = detect_main( + signal_array, + n_sds_above_mean_thresh=1.0, + voxel_sizes=voxel_sizes, + n_free_cpus=no_free_cpus, + torch_device=device, + ) + + assert len(detected) == len(centers_xyz) + for cell, center in zip(detected, centers_xyz): + p = [cell.x, cell.y, cell.z] + d = np.sqrt(np.sum(np.square(np.subtract(center, p)))) + assert d <= 3 + assert cell.type == Cell.UNKNOWN + + +def test_detection_cell_too_large(synthetic_spot_clusters, no_free_cpus): + """ + Test we detect one big artifact if the signal has a too large foreground + structure. + """ + # max_cell_volume is volume of soma * spread sphere. For values below + # radius is 7 pixels. So volume is ~1500 pixels + signal_array = np.zeros((15, 100, 100), dtype=np.float32) + # set volume larger than max volume to bright + signal_array[6 : 6 + 6, 40 : 40 + 26, 40 : 40 + 26] = 1000 + + detected = detect_main( + signal_array, + n_sds_above_mean_thresh=1.0, + voxel_sizes=(5, 2, 2), + soma_diameter=20, + n_free_cpus=no_free_cpus, + max_cluster_size=2000 * 5 * 2 * 2, + ) + + assert len(detected) == 1 + # not sure why it subtracts one to center, but probably rounding + assert detected[0] == Cell([39 + 13, 39 + 13, 5 + 3], Cell.ARTIFACT) + + +@pytest.mark.parametrize("y,x", [(100, 30), (30, 100)]) +def test_detection_plane_too_small(synthetic_spot_clusters, y, x): + # plane smaller than ball filter kernel should cause error + with pytest.raises(InvalidVolume): + detect_main( + np.zeros((5, y, x)), + n_sds_above_mean_thresh=1.0, + voxel_sizes=(1, 1, 1), + ball_xy_size=50, + ) diff --git a/tests/core/test_integration/test_detection_structure_splitting.py b/tests/core/test_integration/test_detection_structure_splitting.py index 31f04672..5fe74a77 100644 --- a/tests/core/test_integration/test_detection_structure_splitting.py +++ b/tests/core/test_integration/test_detection_structure_splitting.py @@ -6,38 +6,45 @@ real life data. """ -import os - import numpy as np import pytest from brainglobe_utils.IO.image.load import read_with_dask +from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.structure_splitting import ( split_cells, ) from cellfinder.core.main import main -data_dir = os.path.join( - os.getcwd(), "tests", "data", "integration", "detection" -) -signal_data_path = os.path.join(data_dir, "structure_split_test", "signal") -background_data_path = os.path.join( - data_dir, "structure_split_test", "background" -) - voxel_sizes = [5, 2.31, 2.31] @pytest.fixture -def signal_array(): +def signal_array(repo_data_path): """A signal array that contains a structure that needs splitting.""" - return read_with_dask(signal_data_path) + return read_with_dask( + str( + repo_data_path + / "integration" + / "detection" + / "structure_split_test" + / "signal" + ) + ) @pytest.fixture -def background_array(): +def background_array(repo_data_path): """A background array that contains a structure that needs splitting.""" - return read_with_dask(background_data_path) + return read_with_dask( + str( + repo_data_path + / "integration" + / "detection" + / "structure_split_test" + / "background" + ) + ) def test_structure_splitting(signal_array, background_array): @@ -75,7 +82,16 @@ def test_underflow_issue_435(): bright_voxels[np.logical_or(inside1, inside2)] = True bright_indices = np.argwhere(bright_voxels) - centers = split_cells(bright_indices) + settings = DetectionSettings( + plane_shape=(100, 100), + plane_original_np_dtype=np.float32, + voxel_sizes=(1, 1, 1), + ball_xy_size_um=3, + ball_z_size_um=3, + ball_overlap_fraction=0.8, + soma_diameter_um=7, + ) + centers = split_cells(bright_indices, settings) # for some reason, same with pytorch, it's shifted by 1. Probably rounding expected = {(10, 11, 11), (20, 11, 11)} diff --git a/tests/core/test_unit/test_detect/test_detect.py b/tests/core/test_unit/test_detect/test_detect.py index b74d2ed9..0ad3291c 100644 --- a/tests/core/test_unit/test_detect/test_detect.py +++ b/tests/core/test_unit/test_detect/test_detect.py @@ -1,23 +1,123 @@ -import multiprocessing +from unittest.mock import MagicMock -from cellfinder.core.detect.detect import _map_with_locks +import numpy as np +import pytest +from pytest_mock.plugin import MockerFixture +from cellfinder.core.detect.detect import main -def add_one(a: int) -> int: - return a + 1 +@pytest.fixture +def mocked_main(mocker: MockerFixture): + from cellfinder.core.detect.filters.volume.volume_filter import ( + VolumeFilter, + ) -def test_map_with_locks(): - args = [1, 2, 3, 2, 10] + process = mocker.patch.object(VolumeFilter, "process", autospec=True) + get_results = mocker.patch.object( + VolumeFilter, "get_results", autospec=True + ) - with multiprocessing.Pool(2) as worker_pool: - result_queue, locks = _map_with_locks(add_one, args, worker_pool) + return process, get_results - async_results = [result_queue.get() for _ in range(len(args))] - assert len(async_results) == len(locks) == len(args) - for lock in locks: - lock.release() +def test_main_bad_signal_arg(mocked_main): + # should work + main(signal_array=np.empty((5, 50, 50))) - results = [res.get() for res in async_results] - assert results == [2, 3, 4, 3, 11] + with pytest.raises(ValueError): + main(signal_array=np.empty((1, 1, 1, 1))) + + with pytest.raises(ValueError): + main(signal_array=np.empty((1, 1))) + + with pytest.raises(TypeError): + main(signal_array=np.empty((5, 50, 50), dtype=np.str_)) + + with pytest.raises(TypeError): + main(signal_array=np.empty((5, 50, 50), dtype=np.uint64)) + + with pytest.raises(TypeError): + main(signal_array=np.empty((5, 50, 50), dtype=np.int64)) + + +@pytest.mark.parametrize( + "dtype", + [ + np.uint8, + np.uint16, + np.uint32, + np.int8, + np.int16, + np.int32, + np.float32, + np.float64, + ], +) +def test_main_good_signal_arg(mocked_main, dtype): + main(signal_array=np.empty((5, 50, 50))) + + +def test_main_bad_or_default_args(mocked_main): + process, get_results = mocked_main + main( + signal_array=np.empty((5, 8, 19), dtype=np.uint16), + end_plane=-1, + batch_size=None, + torch_device="cpu", + ) + process.assert_called() + get_results.assert_called() + + vol_filter, mp_tile_processor, signal_array = process.call_args.args + ( + _, + splitting_settings, + ) = get_results.call_args.args + settings = vol_filter.settings + + assert settings.plane_shape == (8, 19) + assert settings.plane_original_np_dtype == np.uint16 + # for uint16 input we should use float32 + assert settings.filtering_dtype == np.float32 + assert settings.detection_dtype == np.uint64 + assert settings.end_plane == 5 + assert settings.n_planes == 5 + assert settings.batch_size == 4 # cpu default is 4 + + assert splitting_settings.torch_device == "cpu" + + +def test_main_planes_size(mocked_main): + process, get_results = mocked_main + main(signal_array=np.empty((5, 8, 19)), end_plane=4, start_plane=1) + + process.assert_called() + vol_filter, mp_tile_processor, signal_array = process.call_args.args + settings = vol_filter.settings + + assert settings.plane_shape == (8, 19) + assert settings.end_plane == 4 + assert settings.n_planes == 3 + + +def test_main_splitting_cpu_cuda(mocker: MockerFixture): + # checks that even if main filtering runs on cuda, the structure splitting + # only runs on cpu + # patch anything that would do with cuda - in case there's no cuda + vol: MagicMock = mocker.patch( + "cellfinder.core.detect.detect.VolumeFilter", autospec=True + ) + mocker.patch("cellfinder.core.detect.detect.TileProcessor", autospec=True) + + main( + signal_array=np.empty((5, 8, 19)), batch_size=None, torch_device="cuda" + ) + + settings = vol.call_args.kwargs["settings"] + (splitting_settings,) = vol.return_value.get_results.call_args.args + + assert settings.torch_device == "cuda" + assert settings.batch_size == 1 # cuda default is 1 + + assert splitting_settings.torch_device == "cpu" diff --git a/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py b/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py new file mode 100644 index 00000000..181c912c --- /dev/null +++ b/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py @@ -0,0 +1,245 @@ +import numpy as np +import pytest +import torch +from brainglobe_utils.IO.image.load import read_with_dask + +from cellfinder.core.detect.filters.plane import TileProcessor +from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer +from cellfinder.core.detect.filters.plane.tile_walker import TileWalker +from cellfinder.core.detect.filters.setup_filters import DetectionSettings +from cellfinder.core.tools.IO import fetch_pooch_directory +from cellfinder.core.tools.tools import ( + get_max_possible_int_value, + inference_wrapper, +) + + +def load_pooch_dir(test_data_registry, path): + data_path = fetch_pooch_directory(test_data_registry, path) + return read_with_dask(data_path) + + +@pytest.mark.parametrize( + "signal,enhanced,soma_diameter", + [ + ("edge_cells_brain/signal", "edge_cells_brain/peak_enhanced", 16), + ("bright_brain/signal", "bright_brain/peak_enhanced", 30), + ], +) +@pytest.mark.parametrize( + "torch_device,use_scipy", [("cpu", False), ("cpu", True), ("cuda", False)] +) +@inference_wrapper +def test_2d_filtering_peak_enhance_parity( + signal, + enhanced, + soma_diameter, + torch_device, + use_scipy, + test_data_registry, +): + # test that the pure 2d plane filtering (median, gauss, laplacian) matches + # exactly. We use float64 in the test, like original code we compare to + # used + if torch_device == "cuda" and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") + + # check input data size/type is as expected + data = np.asarray(load_pooch_dir(test_data_registry, signal)) + enhanced = np.asarray(load_pooch_dir(test_data_registry, enhanced)) + assert data.dtype == np.uint16 + assert enhanced.dtype == np.uint32 + assert data.shape == enhanced.shape + + # convert to working type and send to cpu/cuda, use float64, the type + # used originally + data = data.astype(np.float64) + data = torch.from_numpy(data) + data = data.to(torch_device) + + # first check that the raw filters produce the same output + clip = get_max_possible_int_value(np.uint32) + enhancer = PeakEnhancer( + torch_device=torch_device, + dtype=torch.float64, + clipping_value=clip, + laplace_gaussian_sigma=soma_diameter * 0.2, + use_scipy=use_scipy, + ) + enhanced_our = enhancer.enhance_peaks(data) + enhanced_our = enhanced_our.cpu().numpy().astype(np.uint32) + + assert enhanced_our.shape == enhanced.shape + # the number of pixels per plane that are different + different = np.sum( + np.sum(np.logical_not(np.isclose(enhanced_our, enhanced)), axis=2), + axis=1, + ) + assert np.all(different == 0) + + +@pytest.mark.parametrize( + "signal,filtered,tiles,soma_diameter,max_different", + [ + ( + "edge_cells_brain/signal", + "edge_cells_brain/2d_filter", + "edge_cells_brain/tiles", + 16, + 1, + ), + ( + "bright_brain/signal", + "bright_brain/2d_filter", + "bright_brain/tiles", + 30, + 2, + ), + ], +) +@pytest.mark.parametrize( + "torch_device,use_scipy", [("cpu", False), ("cpu", True), ("cuda", False)] +) +@inference_wrapper +def test_2d_filtering_parity( + signal, + filtered, + tiles, + soma_diameter, + max_different, + torch_device, + use_scipy, + test_data_registry, +): + # test that the pixels marked as bright after 2d plane filtering matches + # now we don't always use float64, but the bright pixels will still stay + # the same. Unlike test_2d_filtering_peak_enhance, we use float32 if the + # input data fits in it, like we do in the codebase. So we want to be sure + # that the number of bright pixels doesn't change much. Because running at + # full float64 is expensive + if torch_device == "cuda" and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") + + # check input data size/type is as expected + data = np.asarray(load_pooch_dir(test_data_registry, signal)) + filtered = np.asarray(load_pooch_dir(test_data_registry, filtered)) + tiles = np.asarray(load_pooch_dir(test_data_registry, tiles)) + assert data.dtype == np.uint16 + assert filtered.dtype == np.uint16 + assert data.shape == filtered.shape + + settings = DetectionSettings(plane_original_np_dtype=np.uint16) + # convert to working type and send to cpu/cuda + data = torch.from_numpy(settings.filter_data_converter_func(data)) + data = data.to(torch_device) + + tile_processor = TileProcessor( + plane_shape=data[0, :, :].shape, + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + soma_diameter=soma_diameter, + log_sigma_size=0.2, + n_sds_above_mean_thresh=10, + torch_device=torch_device, + dtype=settings.filtering_dtype.__name__, + use_scipy=use_scipy, + ) + + # apply filter and get data back + filtered_our, tiles_our = tile_processor.get_tile_mask(data) + filtered_our = filtered_our.cpu().numpy().astype(np.uint16) + tiles_our = tiles_our.cpu().numpy() + + assert filtered_our.shape == filtered.shape + # we don't care about exact pixel values, only which pixels are marked + # bright and which aren't. Bright per plane + bright = np.sum( + np.sum(filtered == settings.threshold_value, axis=2), axis=1 + ) + bright_our = np.sum( + np.sum(filtered_our == settings.threshold_value, axis=2), axis=1 + ) + # the number of pixels different should be less than 2! + assert np.all(np.less(np.abs(bright - bright_our), max_different + 1)) + + # the in/out of brain tiles though should be identical + assert tiles_our.shape == tiles.shape + assert tiles_our.dtype == tiles.dtype + assert np.array_equal(tiles_our, tiles) + + +@pytest.mark.parametrize( + "plane_size", + [(1, 2), (2, 1), (2, 2), (2, 3), (3, 3), (2, 5), (22, 33), (200, 200)], +) +@inference_wrapper +def test_2d_filter_padding(plane_size): + # check that filter padding works correctly for different sized inputs - + # even if the input is smaller than filter sizes + settings = DetectionSettings(plane_original_np_dtype=np.uint16) + data = np.random.randint(0, 500, size=(1, *plane_size)) + data = data.astype(settings.filtering_dtype) + + tile_processor = TileProcessor( + plane_shape=plane_size, + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + soma_diameter=16, + log_sigma_size=0.2, + n_sds_above_mean_thresh=10, + torch_device="cpu", + dtype=settings.filtering_dtype.__name__, + use_scipy=False, + ) + + filtered, _ = tile_processor.get_tile_mask(torch.from_numpy(data)) + filtered = filtered.numpy() + assert filtered.shape == data.shape + + +@inference_wrapper +def test_even_filter_kernel(): + with pytest.raises(ValueError): + try: + n = PeakEnhancer.median_filter_size + PeakEnhancer.median_filter_size = 4 + PeakEnhancer( + "cpu", + torch.float32, + clipping_value=5, + laplace_gaussian_sigma=3.0, + use_scipy=False, + ) + finally: + PeakEnhancer.median_filter_size = n + + enhancer = PeakEnhancer( + "cpu", + torch.float32, + clipping_value=5, + laplace_gaussian_sigma=3.0, + use_scipy=False, + ) + + assert enhancer.gaussian_filter_size % 2 + + _, _, x, y = enhancer.lap_kernel.shape + assert x % 2, "Should be odd" + assert y % 2, "Should be odd" + assert x == y + + +@pytest.mark.parametrize( + "sizes", + [((1, 1), (1, 1)), ((1, 2), (1, 1)), ((2, 1), (1, 1)), ((22, 33), (3, 4))], +) +@inference_wrapper +def test_tile_walker_size(sizes, soma_diameter=5): + plane_size, tile_size = sizes + walker = TileWalker(plane_size, soma_diameter=soma_diameter) + assert walker.tile_height == 10 + assert walker.tile_width == 10 + + data = torch.rand((1, *plane_size), dtype=torch.float32) + tiles = walker.get_bright_tiles(data) + assert tiles.shape == (1, *tile_size) diff --git a/tests/core/test_unit/test_detect/test_filters/test_setup_filters.py b/tests/core/test_unit/test_detect/test_filters/test_setup_filters.py new file mode 100644 index 00000000..12c2ce1d --- /dev/null +++ b/tests/core/test_unit/test_detect/test_filters/test_setup_filters.py @@ -0,0 +1,131 @@ +import pickle + +import numpy as np +import pytest + +import cellfinder.core.tools.tools as tools +from cellfinder.core.detect.filters.setup_filters import DetectionSettings + + +@pytest.mark.parametrize( + "in_dtype,filter_dtype", + [ + (np.uint8, np.float32), + (np.uint16, np.float32), + (np.uint32, np.float64), + (np.int8, np.float32), + (np.int16, np.float32), + (np.int32, np.float64), + (np.float32, np.float32), + (np.float64, np.float64), + ], +) +@pytest.mark.parametrize( + "detect_dtype", + [ + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.int8, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + ], +) +def test_good_input_dtype(in_dtype, filter_dtype, detect_dtype): + """ + These input and filter types doesn't require any conversion because the + filter type we use for the given input type is large enough to not need + scaling. So data should be identical after conversion to filtering type. + """ + settings = DetectionSettings( + plane_original_np_dtype=in_dtype, detection_dtype=detect_dtype + ) + converter = settings.filter_data_converter_func + detection_converter = settings.detection_data_converter_func + + assert settings.filtering_dtype == filter_dtype + assert settings.detection_dtype == detect_dtype + + # min input value can be converted to filter/detection type + src = np.full( + 5, tools.get_min_possible_int_value(in_dtype), dtype=in_dtype + ) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == filter_dtype + # it is safe to do this conversion for any data type because it only uses + # the soma_centre_value, and ignores everything else (e.g. outside range) + assert detection_converter(dest).dtype == detect_dtype + + # typical input value can be converted to filter/detection type + src = np.full(5, 3, dtype=in_dtype) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == filter_dtype + assert detection_converter(dest).dtype == detect_dtype + + # max input value can be converted to filter/detection type + src = np.full( + 5, tools.get_max_possible_int_value(in_dtype), dtype=in_dtype + ) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == filter_dtype + assert detection_converter(dest).dtype == detect_dtype + + # soma_centre_value can be converted to filter/detection type + src = np.full(5, settings.soma_centre_value, dtype=in_dtype) + dest = converter(src) + # data type is larger - so value is unchanged + assert np.array_equal(src, dest) + assert dest.dtype == filter_dtype + # for detect, we convert soma_centre_value to detection_soma_centre_value + detect = detection_converter(dest) + assert detect.dtype == detect_dtype + assert np.all(detect == settings.detection_soma_centre_value) + + +@pytest.mark.parametrize("in_dtype", [np.uint64, np.int64]) +def test_bad_input_dtype(in_dtype): + """ + For this input type, to be able to fit it into our largest filtering + type - float64 we'd need to scale the data. Although `converter` can do + it, we don't support it right now (maybe as an option?) + """ + settings = DetectionSettings(plane_original_np_dtype=in_dtype) + + with pytest.raises(TypeError): + # do assert to quiet linter complaints + assert settings.filter_data_converter_func + + with pytest.raises(TypeError): + assert settings.filtering_dtype + + # detection type should be available + assert settings.detection_dtype == np.uint64 + + +def test_pickle_settings(): + settings = DetectionSettings() + + # get some properties, both cached and not cached + assert settings.filter_data_converter_func is not None + assert settings.filtering_dtype is not None + assert settings.detection_dtype is not None + assert settings.threshold_value is not None + assert settings.plane_shape is not None + + # make sure pickle works + s = pickle.dumps(settings) + assert s + + +def test_bad_ball_z_size(): + settings = DetectionSettings(ball_z_size_um=0) + with pytest.raises(ValueError): + # do something with value to quiet linter + assert settings.ball_z_size diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py new file mode 100644 index 00000000..2841af0b --- /dev/null +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py @@ -0,0 +1,42 @@ +import pytest + +from cellfinder.core.detect.filters.volume.ball_filter import BallFilter + +bf_kwargs = { + "plane_height": 50, + "plane_width": 50, + "ball_xy_size": 3, + "ball_z_size": 3, + "overlap_fraction": 0.5, + "threshold_value": 1, + "soma_centre_value": 1, + "tile_height": 10, + "tile_width": 10, + "dtype": "float32", +} + + +def test_filter_not_ready(): + bf = BallFilter(**bf_kwargs) + assert not bf.ready + + with pytest.raises(TypeError): + bf.get_processed_planes() + + with pytest.raises(TypeError): + bf.walk() + + +@pytest.mark.parametrize( + "sizes", [(1, 0, 0), (2, 1, 0), (3, 1, 1), (4, 2, 1), (5, 2, 2), (6, 3, 2)] +) +def test_filter_unprocessed_planes(sizes): + kernel_size, start_offset, remaining = sizes + assert kernel_size == start_offset + 1 + remaining + + kwargs = bf_kwargs.copy() + kwargs["ball_z_size"] = kernel_size + bf = BallFilter(**kwargs) + + assert bf.first_valid_plane == start_offset + assert bf.remaining_planes == remaining diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_connected_components_labelling.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_connected_components_labelling.py index a8028ede..31fd4295 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_connected_components_labelling.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_connected_components_labelling.py @@ -1,9 +1,12 @@ +from typing import Type + import numpy as np import pytest from cellfinder.core.detect.filters.volume.structure_detection import ( CellDetector, ) +from cellfinder.core.tools.tools import get_max_possible_int_value @pytest.mark.parametrize( @@ -21,7 +24,9 @@ (258, np.uint32), ], ) -def test_connect_four_limits(linear_size, datatype): +def test_connect_four_limits( + linear_size: int, datatype: Type[np.number] +) -> None: """ Test for `connect_four` with a rectangular plane (2-to-1 length ratio) containing a checkerboard of pixels marked as cells ("structures"). @@ -36,12 +41,15 @@ def test_connect_four_limits(linear_size, datatype): * there is exactly one structure with the maximum id... * ...and that structure is in the expected place (top-right pixel) """ - SOMA_CENTRE_VALUE = np.iinfo(datatype).max - checkerboard = np.zeros((linear_size * 2, linear_size), dtype=datatype) - for i in range(linear_size * 2): - for j in range(linear_size): - if (i + j) % 2 == 0: - checkerboard[i, j] = SOMA_CENTRE_VALUE + height = linear_size * 2 + width = linear_size + # use a very large value - similar to how it is normally used + soma_centre_value = get_max_possible_int_value(datatype) + + checkerboard = np.zeros((height, width), dtype=datatype) + i = np.arange(height)[:, np.newaxis] # rows + j = np.arange(width)[np.newaxis, :] # cols + checkerboard[(i + j) % 2 == 0] = soma_centre_value actual_nonzeros = np.count_nonzero(checkerboard) expected_nonzeros = linear_size**2 @@ -49,11 +57,14 @@ def test_connect_four_limits(linear_size, datatype): actual_nonzeros == expected_nonzeros ), "Checkerboard didn't have the expected number of non-zeros" - cell_detector = CellDetector(linear_size * 2, linear_size, 0) + cell_detector = CellDetector(height, width, 0, soma_centre_value) labelled_plane = cell_detector.connect_four(checkerboard, None) one_count = np.count_nonzero(labelled_plane == 1) assert one_count == 1, "There was not exactly one pixel with label 1." assert ( - labelled_plane[linear_size * 2 - 1, linear_size - 1] == actual_nonzeros + labelled_plane[height - 1, width - 1] == actual_nonzeros ), "The last labelled pixel did not have the maximum struct id." + assert np.all( + (labelled_plane != 0) == (checkerboard != 0) + ), "Structures should be exactly where centers were marked" diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py index d1e1af7a..9623fedd 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py @@ -1,6 +1,9 @@ +from typing import Dict, List, Tuple, Type + import numpy as np import pytest +from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.structure_detection import ( CellDetector, Point, @@ -9,7 +12,9 @@ ) -def coords_to_points(coords_arrays): +def coords_to_points( + coords_arrays: Dict[int, np.ndarray] +) -> Dict[int, List[Point]]: # Convert from arrays to dicts coords = {} for sid in coords_arrays: @@ -26,15 +31,17 @@ def coords_to_points(coords_arrays): (np.uint32, 2**32 - 1), (np.uint16, 2**16 - 1), (np.uint8, 2**8 - 1), + (np.float32, 2**23), # mantissa determine whole max int representable + (np.float64, 2**52), ], ) -def test_get_non_zero_dtype_min(dtype, expected): +def test_get_non_zero_dtype_min(dtype: Type[np.number], expected: int) -> None: assert get_non_zero_dtype_min(np.arange(10, dtype=dtype)) == 1 assert get_non_zero_dtype_min(np.zeros(10, dtype=dtype)) == expected @pytest.fixture() -def three_d_cross(): +def three_d_cross() -> np.ndarray: return np.array( [ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], @@ -45,12 +52,12 @@ def three_d_cross(): @pytest.fixture() -def structure(three_d_cross) -> np.ndarray: +def structure(three_d_cross: np.ndarray) -> np.ndarray: coords = np.array(np.where(three_d_cross)).transpose() return coords -def test_get_structure_centre(structure): +def test_get_structure_centre(structure: np.ndarray) -> None: result_point = get_structure_centre(structure) assert (result_point[0], result_point[1], result_point[2]) == ( 1, @@ -65,7 +72,7 @@ def test_get_structure_centre(structure): # Each item in the test data contains: # -# - A list of indices to mark as structure pixels (ordering: [x, z, y]) +# - A list of indices to mark as structure pixels (ordering: [z, y, x]) # - A dict of expected structure coordinates test_data = [ ( @@ -75,12 +82,12 @@ def test_get_structure_centre(structure): ), ( # Two pixels connected in a single structure along x - [(0, 0, 0), (0, 1, 0)], + [(0, 0, 0), (0, 0, 1)], {1: [Point(0, 0, 0), Point(1, 0, 0)]}, ), ( # Two pixels connected in a single structure along y - [(0, 0, 0), (0, 0, 1)], + [(0, 0, 0), (0, 1, 0)], {1: [Point(0, 0, 0), Point(0, 1, 0)]}, ), ( @@ -90,13 +97,13 @@ def test_get_structure_centre(structure): ), ( # Four pixels all connected and spread across x-y-z - [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 0, 1)], + [(0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)], {1: [Point(0, 0, 0), Point(0, 0, 1), Point(1, 0, 1), Point(0, 1, 1)]}, ), ( # three initially disconnected pixels that then get merged # by a fourth pixel - [(1, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)], + [(0, 1, 1), (1, 0, 1), (1, 1, 0), (1, 1, 1)], { 1: [ Point(1, 1, 0), @@ -108,7 +115,7 @@ def test_get_structure_centre(structure): ), ( # Three pixels in x-y plane that require structure merging - [(1, 0, 0), (0, 1, 0), (1, 1, 0)], + [(0, 0, 1), (1, 0, 0), (1, 0, 1)], { 1: [ Point(1, 0, 0), @@ -119,7 +126,7 @@ def test_get_structure_centre(structure): ), ( # Two disconnected single-pixel structures - [(0, 0, 0), (0, 2, 0)], + [(0, 0, 0), (0, 0, 2)], {1: [Point(0, 0, 0)], 2: [Point(2, 0, 0)]}, ), ( @@ -130,16 +137,77 @@ def test_get_structure_centre(structure): ] -@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64]) +# Due to https://github.com/numba/numba/issues/9576 we need to run np.uint64 +# before smaller sizes +# we don't use floats in cell detection, but it works +@pytest.mark.parametrize( + "dtype", + [ + np.uint64, + np.int64, + np.uint8, + np.int8, + np.uint16, + np.int16, + np.uint32, + np.int32, + np.float32, + np.float64, + ], +) @pytest.mark.parametrize("pixels,expected_coords", test_data) -def test_detection(dtype, pixels, expected_coords): - data = np.zeros((depth, width, height)).astype(dtype) - detector = CellDetector(width, height, start_z=0) +@pytest.mark.parametrize( + "detect_dtype", + [ + np.uint64, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.int8, + np.int16, + np.int32, + np.float32, + np.float64, + ], +) +def test_detection( + dtype: Type[np.number], + pixels: List[Tuple[int, int, int]], + expected_coords: Dict[int, List[Point]], + detect_dtype: Type[np.number], +) -> None: + # original dtype is the dtype of the original data. filtering_dtype + # is the data type used during ball filtering. Currently, it can be at most + # float64. So original dtype cannot support 64-bit ints because it won't + # fit in float64. + # detection_dtype is the type that must be used during detection to fit a + # count of the number of cells present + settings = DetectionSettings( + plane_original_np_dtype=dtype, detection_dtype=detect_dtype + ) + + # should raise error for (u)int64 - too big for float64 so can't filter + if dtype in (np.uint64, np.int64): + with pytest.raises(TypeError): + filtering_dtype = settings.filtering_dtype + # do something with so linter doesn't complain + assert filtering_dtype + return + # pretend we got the intensity data from filtering + data = np.zeros((depth, height, width), dtype=settings.filtering_dtype) # This is the value used by BallFilter to mark pixels - max_poss_value = np.iinfo(dtype).max for pix in pixels: - data[pix] = max_poss_value + data[pix] = settings.soma_centre_value + + # convert intensity data to values expected by detector + data = settings.detection_data_converter_func(data) + + detector = CellDetector(height, width, 0, 0) + # similar to numba issue #9576 we can't pass to init a large value once + # a 32 bit type was used for detector. So pass it with custom method + detector._set_soma(settings.detection_soma_centre_value) previous_plane = None for plane in data: @@ -147,3 +215,28 @@ def test_detection(dtype, pixels, expected_coords): coords = detector.get_structures() assert coords_to_points(coords) == expected_coords + + +def test_add_point(): + detector = CellDetector(50, 50, 0, 0) + detector.add_point(0, (5, 5, 5)) + detector.add_point(0, (6, 5, 5)) + detector.add_point(1, (7, 5, 5)) + + +def test_add_points(): + detector = CellDetector(50, 50, 0, 0) + + points = np.array([(5, 5, 5), (6, 6, 6)], dtype=np.uint32) + points2 = np.array([(7, 5, 5), (8, 6, 6)], dtype=np.uint32) + points3 = np.array([(8, 5, 5), (8, 6, 6)], dtype=np.uint32) + detector.add_points(0, points) + detector.add_points(0, points2) + detector.add_points(1, points3) + + +def test_change_plane_size(): + # check that changing plane size errors out + detector = CellDetector(50, 50, 0, 5000) + with pytest.raises(ValueError): + detector.process(np.zeros((100, 50), dtype=np.uint32), None) diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_volume_filter.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_volume_filter.py new file mode 100644 index 00000000..2d1ff70f --- /dev/null +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_volume_filter.py @@ -0,0 +1,279 @@ +import numpy as np +import pytest +import torch +from brainglobe_utils.IO.cells import get_cells_xml +from brainglobe_utils.IO.image.load import read_with_dask +from pytest_mock.plugin import MockerFixture + +from cellfinder.core.detect.detect import main +from cellfinder.core.tools.IO import fetch_pooch_directory +from cellfinder.core.tools.threading import ExecutionFailure +from cellfinder.core.tools.tools import get_max_possible_int_value + +# even though we are testing volume filter as unit test, we are running through +# main and mocking VolumeFilter because that's the easiest way to instantiate +# it with proper args and run it + + +class ExceptionTest(Exception): + pass + + +def load_pooch_dir(test_data_registry, path): + data_path = fetch_pooch_directory(test_data_registry, path) + return read_with_dask(data_path) + + +def raise_exception(*args, **kwargs): + raise ExceptionTest("Bad times") + + +def run_main_assert_exception(): + # run volume filter - it should raise the ExceptionTest via + # ExecutionFailure from thread/process + try: + # must be on cpu b/c only on cpu do we do 2d filtering in subprocess + # lots of planes so it doesn't end naturally quickly + main( + signal_array=np.zeros((500, 500, 500), dtype=np.uint16), + torch_device="cpu", + ) + assert False, "should have raised exception" + except ExecutionFailure as e: + e2 = e.__cause__ + assert type(e2) is ExceptionTest and e2.args == ("Bad times",) + + +def test_2d_filter_process_exception(mocker: MockerFixture): + # check sub-process that does 2d filter. That exception ends things clean + mocker.patch( + "cellfinder.core.detect.filters.volume.volume_filter._plane_filter", + new=raise_exception, + ) + run_main_assert_exception() + + +def test_2d_filter_feeder_thread_exception(mocker: MockerFixture): + # check data feeder thread. That exception ends things clean + from cellfinder.core.detect.filters.volume.volume_filter import ( + VolumeFilter, + ) + + mocker.patch.object( + VolumeFilter, "_feed_signal_batches", new=raise_exception + ) + run_main_assert_exception() + + +def test_2d_filter_cell_detection_thread_exception(mocker: MockerFixture): + # check cell detection thread. That exception ends things clean + from cellfinder.core.detect.filters.volume.volume_filter import ( + VolumeFilter, + ) + + mocker.patch.object( + VolumeFilter, "_run_filter_thread", new=raise_exception + ) + run_main_assert_exception() + + +def test_3d_filter_main_thread_exception(mocker: MockerFixture): + # raises exception in the _process method in the main thread - after the + # subprocess and secondary threads were spun up. This makes sure that those + # subprocess and threads don't get stuck if main thread crashes + from cellfinder.core.detect.filters.volume.volume_filter import ( + VolumeFilter, + ) + + mocker.patch.object(VolumeFilter, "_process", new=raise_exception) + with pytest.raises(ExceptionTest): + main(signal_array=np.zeros((500, 500, 500)), torch_device="cpu") + + +@pytest.mark.parametrize("batch_size", [1, 2, 3, 4]) +def test_feeder_thread_batch(batch_size: int): + # checks various batch sizes to see if there are issues + # this also tests a batch size of 3 but 5 planes. So the feeder thread + # will feed us a batch of 3 and a batch of 2. It tests that filters can + # handle unequal batch sizes + planes = [] + + def callback(z): + planes.append(z) + + main( + signal_array=np.zeros((5, 50, 50)), + torch_device="cpu", + batch_size=batch_size, + callback=callback, + ) + + assert planes == list(range(1, 4)) + + +def test_not_enough_planes(): + # checks that even if there are not enough planes for volume filtering, it + # doesn't raise errors or gets stuck + planes = [] + + def callback(z): + planes.append(z) + + main( + signal_array=np.zeros((2, 50, 50)), + torch_device="cpu", + callback=callback, + ) + + assert not planes + + +def test_filtered_plane_range(mocker: MockerFixture): + # check that even if input data is negative, filtered data is non-negative + detector = mocker.patch( + "cellfinder.core.detect.filters.volume.volume_filter.CellDetector", + autospec=True, + ) + + # input data in range (-500, 500) + data = ((np.random.random((6, 50, 50)) - 0.5) * 1000).astype(np.float32) + data[1:3, 25:30, 25:30] = 5000 + main(signal_array=data) + + calls = detector.return_value.process.call_args_list + assert len(calls) + for call in calls: + plane, *_ = call.args + # should have either zero or soma value or both + assert len(np.unique(plane)) in (1, 2) + assert np.min(plane) >= 0 + + +def test_saving_filtered_planes(tmp_path): + # check that we can save filtered planes + path = tmp_path / "save_planes" + path.mkdir() + + main( + signal_array=np.zeros((6, 50, 50)), + save_planes=True, + plane_directory=str(path), + ) + + files = [p.name for p in path.iterdir() if p.is_file()] + # we're skipping first and last plane that isn't filtered due to kernel + assert len(files) == 4 + assert set(files) == { + "plane_0002.tif", + "plane_0003.tif", + "plane_0004.tif", + "plane_0005.tif", + } + + +def test_saving_filtered_planes_no_dir(): + # asked to save but didn't provide directory + with pytest.raises(ExecutionFailure) as exc_info: + main( + signal_array=np.zeros((6, 50, 50)), + save_planes=True, + plane_directory=None, + ) + assert type(exc_info.value.__cause__) is ValueError + + +@pytest.mark.parametrize( + "signal,filtered,cells,soma_diameter,voxel_sizes,cell_tol", + [ + ( + "edge_cells_brain/signal", + "edge_cells_brain/3d_filter", + "edge_cells_brain/detected_cells.xml", + 16, + (5, 2, 2), + 0, + ), + ( + "bright_brain/signal", + "bright_brain/3d_filter", + "bright_brain/detected_cells.xml", + 30, + (5.06, 4.5, 4.5), + 1, + ), + ], +) +@pytest.mark.parametrize( + "torch_device,use_scipy", [("cpu", False), ("cpu", True), ("cuda", False)] +) +def test_3d_filtering( + signal, + filtered, + cells, + soma_diameter, + voxel_sizes, + cell_tol, + torch_device, + use_scipy, + no_free_cpus, + tmp_path, + test_data_registry, +): + # test that the full 2d/3d matches the saved data + if torch_device == "cuda" and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") + + # check input data size/type is as expected + data = np.asarray(load_pooch_dir(test_data_registry, signal)) + filtered = np.asarray(load_pooch_dir(test_data_registry, filtered)) + cells = get_cells_xml(test_data_registry.fetch(cells)) + assert data.dtype == np.uint16 + assert filtered.dtype == np.uint32 + assert data.shape == (filtered.shape[0] + 2, *filtered.shape[1:]) + + path = tmp_path / "3d_filter" + path.mkdir() + cells_our = main( + signal_array=data, + voxel_sizes=voxel_sizes, + soma_diameter=soma_diameter, + max_cluster_size=100000, + ball_xy_size=6, + ball_z_size=15, + ball_overlap_fraction=0.6, + soma_spread_factor=1.4, + n_free_cpus=no_free_cpus, + log_sigma_size=0.2, + n_sds_above_mean_thresh=10, + save_planes=True, + plane_directory=str(path), + batch_size=1, + ) + + filtered_our = np.asarray(read_with_dask(str(path))) + assert filtered_our.shape == filtered.shape + assert filtered_our.dtype == np.uint16 + # we need to rescale our data because the original data saved to uint32 + # (even though it fit in uint16), so rescale the max value the soma was + # saved as, to make comparison better + filtered_our = filtered_our.astype(np.uint32) + max16 = get_max_possible_int_value(np.uint16) + max32 = get_max_possible_int_value(np.uint32) + filtered_our[filtered_our == max16] = max32 + + # we only care about the soma value as only that is used in the next step + # in cell detection, so set everything else to zero + filtered_our[filtered_our != max16] = 0 + filtered[filtered_our != max32] = 0 + + # the number of pixels per plane that are different (marked bright/not) + diff = np.sum(np.sum(filtered_our != filtered, axis=2), axis=1) + # 100% same + assert np.all(diff == 0) + + # check that the resulting cells are the same. We expect them to be the + # same, cells at different pos count as different + cells_our = set(cells_our) + cells = set(cells) + diff = len(cells_our - cells) + len(cells - cells_our) + assert diff <= cell_tol diff --git a/tests/core/test_unit/test_tools/test_threading.py b/tests/core/test_unit/test_tools/test_threading.py new file mode 100644 index 00000000..0a341068 --- /dev/null +++ b/tests/core/test_unit/test_tools/test_threading.py @@ -0,0 +1,141 @@ +import pytest + +from cellfinder.core.tools.threading import ( + EOFSignal, + ExceptionWithQueueMixIn, + ExecutionFailure, + ProcessWithException, + ThreadWithException, +) + +cls_to_test = [ThreadWithException, ProcessWithException] + + +class ExceptionTest(Exception): + pass + + +def raise_exc(*args): + raise ExceptionTest("I'm a test") + + +def raise_exc_about_self(*args): + arg_msg = "No arg" + if args and args[-1] == "7": + arg_msg = "Got 7" + + if args and isinstance(args[0], ExceptionWithQueueMixIn): + raise ExceptionTest("Got self" + arg_msg) + raise ExceptionTest("No self" + arg_msg) + + +def do_nothing(*args): + pass + + +def send_back_msg(thread: ExceptionWithQueueMixIn): + # do this single op and exit + thread.send_msg_to_mainthread(("back", thread.get_msg_from_mainthread())) + + +def send_multiple_msgs(thread: ExceptionWithQueueMixIn): + thread.send_msg_to_mainthread("hello") + thread.send_msg_to_mainthread("to") + thread.send_msg_to_mainthread("you") + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_reraise_exception_in_main_thread_from_thread(cls): + # exception in thread will show up in main thread + thread = cls(target=raise_exc, args=(1, "4")) + thread.start() + + with pytest.raises(ExecutionFailure) as exc_info: + thread.get_msg_from_thread() + assert type(exc_info.value.__cause__) is ExceptionTest + assert exc_info.value.__cause__.args[0] == "I'm a test" + # thread will have exited + thread.join() + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_get_eof_in_main_thread_from_thread(cls): + # we should get eof when thread exits + thread = cls(target=do_nothing, args=(1, "4")) + thread.start() + + assert thread.get_msg_from_thread() is EOFSignal + # thread will have exited + thread.join() + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_get_eof_in_thread_from_main_thread(cls): + # thread should get eof when we send it + thread = cls(target=send_back_msg, pass_self=True) + thread.start() + + thread.notify_to_end_thread() + assert thread.get_msg_from_thread() == ("back", EOFSignal) + # thread will have exited + thread.join() + + +@pytest.mark.parametrize("args", [(), (55, "7")]) +@pytest.mark.parametrize("pass_self", [(True, "Got self"), (False, "No self")]) +@pytest.mark.parametrize("cls", cls_to_test) +def test_pass_self_arg_to_func(cls, pass_self, args): + # check that passing self to the function works with/without args + # without passing self, there's no way for the thread to respond + # other than by the text of an error + thread = cls( + target=raise_exc_about_self, pass_self=pass_self[0], args=args + ) + thread.start() + + arg_msg = "No arg" + if args: + arg_msg = "Got 7" + + with pytest.raises(ExecutionFailure) as exc_info: + thread.get_msg_from_thread() + assert type(exc_info.value.__cause__) is ExceptionTest + assert exc_info.value.__cause__.args[0] == pass_self[1] + arg_msg + thread.join() + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_send_to_and_recv_from_thread(cls): + # tests sending to the thread and receiving a message from the thread + thread = cls(target=send_back_msg, pass_self=True) + thread.start() + + msg = thread.get_msg_from_thread(thread.send_msg_to_thread("hello")) + assert msg == ("back", "hello") + thread.join() + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_get_multiple_messages(cls): + # tests getting multiple msgs from thread + thread = cls(target=send_multiple_msgs, pass_self=True) + thread.start() + + assert thread.get_msg_from_thread() == "hello" + assert thread.get_msg_from_thread() == "to" + assert thread.get_msg_from_thread() == "you" + assert thread.get_msg_from_thread() == EOFSignal + thread.join() + + +@pytest.mark.parametrize("cls", cls_to_test) +def test_skip_until_eof(cls): + # tests skipping reading everything queued until we get eof + thread = cls(target=send_multiple_msgs, pass_self=True) + thread.start() + + thread.clear_remaining() + # it knows that there are no further messages because the thread sent an + # eof to main-thread, which is the last thing thread does before exiting + assert thread._saw_eof + thread.join() diff --git a/tests/core/test_unit/test_tools/test_tools_general.py b/tests/core/test_unit/test_tools/test_tools_general.py index d3609cb4..9a3c573e 100644 --- a/tests/core/test_unit/test_tools/test_tools_general.py +++ b/tests/core/test_unit/test_tools/test_tools_general.py @@ -1,5 +1,3 @@ -import random - import numpy as np import pytest @@ -18,16 +16,195 @@ ) -def test_get_max_possible_value(): - num = random.randint(0, 100) - assert 255 == tools.get_max_possible_value(np.array(num, dtype=np.uint8)) - assert 65535 == tools.get_max_possible_value( - np.array(num, dtype=np.uint16) - ) - with pytest.raises( - ValueError, match="must be a numpy array of integer data type" - ): - tools.get_max_possible_value(np.array(num, dtype=np.float32)) +def test_inference_wrapper(): + did_run = False + + @tools.inference_wrapper + def my_func(val, other=None): + import torch + + nonlocal did_run + + assert torch.is_inference_mode_enabled() + assert val == 1 + assert other == 5 + did_run = True + + my_func(1, other=5) + assert did_run + + +# for float, the values come from the mantissa and it's the largest int value +# representable without losing significant digits +@pytest.mark.parametrize( + "dtype,value", + [ + (np.uint8, 2**8 - 1), + (np.uint16, 2**16 - 1), + (np.uint32, 2**32 - 1), + (np.uint64, 2**64 - 1), + (np.int8, 2**7 - 1), + (np.int16, 2**15 - 1), + (np.int32, 2**31 - 1), + (np.int64, 2**63 - 1), + (np.float32, 2**23), + (np.float64, 2**52), + ], +) +def test_get_max_possible_int_value(dtype, value): + assert tools.get_max_possible_int_value(dtype) == value + + +def test_get_max_possible_int_value_bad_dtype(): + with pytest.raises(ValueError): + tools.get_max_possible_int_value(np.str_) + + +# for float, the values come from the mantissa and it's the largest int value +# representable without losing significant digits +@pytest.mark.parametrize( + "dtype,value", + [ + (np.uint8, 0), + (np.uint16, 0), + (np.uint32, 0), + (np.uint64, 0), + (np.int8, -(2**7)), + (np.int16, -(2**15)), + (np.int32, -(2**31)), + (np.int64, -(2**63)), + (np.float32, -(2**23)), + (np.float64, -(2**52)), + ], +) +def test_get_min_possible_int_value(dtype, value): + assert tools.get_min_possible_int_value(dtype) == value + + +def test_get_min_possible_int_value_bad_dtype(): + with pytest.raises(ValueError): + tools.get_min_possible_int_value(np.str_) + + +@pytest.mark.parametrize( + "src_dtype", + [ + np.uint8, + np.int8, + np.uint16, + np.int16, + np.uint32, + np.int32, + np.uint64, + np.int64, + np.float32, + np.float64, + ], +) +@pytest.mark.parametrize( + "dest_dtype", [np.uint8, np.uint16, np.uint32, np.uint64] +) +def test_get_data_converter_bad_dtype_target(src_dtype, dest_dtype): + with pytest.raises(ValueError): + tools.get_data_converter(src_dtype, dest_dtype) + + +@pytest.mark.parametrize( + "src_dtype,dest_dtype", + [ + (np.uint8, np.float32), + (np.int8, np.float32), + (np.uint16, np.float32), + (np.int16, np.float32), + (np.float32, np.float32), + (np.uint8, np.float64), + (np.int8, np.float64), + (np.uint16, np.float64), + (np.int16, np.float64), + (np.uint32, np.float64), + (np.int32, np.float64), + (np.float32, np.float64), + (np.float64, np.float64), + ], +) +def test_get_data_converter_no_scaling(src_dtype, dest_dtype): + # for these, the source is smaller than dest so no need to scale because + # it'll fit directly into the dest dtype + converter = tools.get_data_converter(src_dtype, dest_dtype) + + # min value + if np.issubdtype(src_dtype, np.integer): + src_min_val = np.iinfo(src_dtype).min + else: + assert np.issubdtype(src_dtype, np.floating) + src_min_val = np.finfo(src_dtype).min + src = np.full(5, src_min_val, dtype=src_dtype) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == dest_dtype + + # other value + src = np.full(5, 10, dtype=src_dtype) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == dest_dtype + + # max value + src_max_val = tools.get_max_possible_int_value(src_dtype) + src = np.full(3, src_max_val, dtype=src_dtype) + dest = converter(src) + assert np.array_equal(src, dest) + assert dest.dtype == dest_dtype + + +@pytest.mark.parametrize( + "src_dtype,dest_dtype,divisor", + [ + (np.uint32, np.float32, (2**32 - 1) / 2**23), + (np.int32, np.float32, (2**31) / 2**23), + (np.float64, np.float32, (2**52) / 2**23), + ], +) +def test_get_data_converter_with_scaling(src_dtype, dest_dtype, divisor): + # for these, the source is larger than dest type so we need to scale by max + # value of each type, so it'll fit into the dest dtype + converter = tools.get_data_converter(src_dtype, dest_dtype) + + # min value + src_min_val = tools.get_min_possible_int_value(src_dtype) + src = np.full(5, src_min_val, dtype=src_dtype) + dest = converter(src) + assert np.allclose(src / divisor, dest) + assert dest.dtype == dest_dtype + + # other value + src = np.full(5, 10, dtype=src_dtype) + dest = converter(src) + assert np.allclose(src / divisor, dest) + assert dest.dtype == dest_dtype + + # max value + src_max_val = tools.get_max_possible_int_value(src_dtype) + src = np.full(3, src_max_val, dtype=src_dtype) + dest = converter(src) + assert np.allclose(src / divisor, dest) + assert dest.dtype == dest_dtype + + +@pytest.mark.parametrize( + "src_dtype,dest_dtype", + [ + (np.uint64, np.float32), + (np.int64, np.float32), + (np.uint64, np.float64), + (np.int64, np.float64), + ], +) +def test_get_data_converter_with_bad_scaling(src_dtype, dest_dtype): + # for these, to scale we'd need to have a type that is at least 64 bits, + # so we can scale down. But float64 is too small + with pytest.raises(ValueError): + tools.get_data_converter(src_dtype, dest_dtype) def test_union(): diff --git a/tests/data/pooch_registry.txt b/tests/data/pooch_registry.txt new file mode 100644 index 00000000..63a04222 --- /dev/null +++ b/tests/data/pooch_registry.txt @@ -0,0 +1,129 @@ +bright_brain/2d_filter/signal0001.tif b224baa0c0a84dfbe180f6579c918ef3becabe73cae0b309ca1b55b448cf0ad1 +bright_brain/2d_filter/signal0002.tif 31d930f449131237ef606b506bc4bca20a1a3c44f0081ebd9d8a5437e5e471c8 +bright_brain/2d_filter/signal0003.tif 026bd60fc0d8823df6531f2e87266bc565a65ea64249c055a01c708e45e4345d +bright_brain/2d_filter/signal0004.tif 38180851830b0a560b3c8d1209abc6675de0975693d3e9ab1038b421db92e889 +bright_brain/2d_filter/signal0005.tif 3f4bbd3407121d89d1dc120bce3d444855d444c38e8e41949c7e76f3755fdafb +bright_brain/3d_filter/plane_0002.tif dd5a8f7799f7854ed8626a5b35cc7c9611ebf52c350b0474cbebe46b1c6f7433 +bright_brain/3d_filter/plane_0003.tif bf595ac9e760d27d1a7bf4da264a65d551f7d0918a3f2b9fba80bcef8a17a064 +bright_brain/3d_filter/plane_0004.tif 64ec4f9c131545753b9ddac667b2033602ec594a9fbb10e83299de91599067ac +bright_brain/detected_cells.xml 6f8e6a316576e4b38d3b5c04d58442288403b8a470011c16d3303c12fe60e7b3 +bright_brain/peak_enhanced/signal0001.tif 6a25266b695dfd8c3cf858bdb7db93d1ddaaf91cbcb8f82dd67b3243b1fce471 +bright_brain/peak_enhanced/signal0002.tif 34d5a8457987811101d921c2553ab8ba60a731b3d3caf81d1a6b5d1fbc42bb93 +bright_brain/peak_enhanced/signal0003.tif cca9ac6e038fe608f0cb57a9a0ef990290a9c1463b1ae358290af60427e7c11b +bright_brain/peak_enhanced/signal0004.tif 8e5a35a51e8120702eade5cd1bb9d409117f81ad22f95f4d54bc24d859ed25fc +bright_brain/peak_enhanced/signal0005.tif 9e749bc6e761799a63ff77094ff0a8c1d796e665991e1108b6f9c7442aff2004 +bright_brain/signal/signal0001.tif 3fe85dd8e154447b7296daa24f5e0b5d00310f58d40726939792bb8aa9f7f4fc +bright_brain/signal/signal0002.tif 9091e4c8ebcd5a4d1d7f3719983b33a5e804aff40b11714a66f16ecc385ffe04 +bright_brain/signal/signal0003.tif 08f6d28bc6ede525c5bc708405ae7227e0897c2808ae13c3526a52e9d75daec5 +bright_brain/signal/signal0004.tif 7927520de0ab2cbd255ff9a4721edbfa20a8135706ebd863ee9ddecc1c17eb29 +bright_brain/signal/signal0005.tif b908df40dc08899b1731ee6d0ffa7274f5faae88377594a1aba997a835d20dd9 +bright_brain/tiles/signal0001.tif bf3a5478b8a0a0e7d0625641340fa7a36ea95eceb77d6cad08ccd652e2565ada +bright_brain/tiles/signal0002.tif bf3a5478b8a0a0e7d0625641340fa7a36ea95eceb77d6cad08ccd652e2565ada +bright_brain/tiles/signal0003.tif bf3a5478b8a0a0e7d0625641340fa7a36ea95eceb77d6cad08ccd652e2565ada +bright_brain/tiles/signal0004.tif bf3a5478b8a0a0e7d0625641340fa7a36ea95eceb77d6cad08ccd652e2565ada +bright_brain/tiles/signal0005.tif bf3a5478b8a0a0e7d0625641340fa7a36ea95eceb77d6cad08ccd652e2565ada +cellfinder-test-data.zip b0ef53b1530e4fa3128fcc0a752d0751909eab129d701f384fc0ea5f138c5914 +edge_cells_brain/2d_filter/signal0000.tif 4803df36346838928c2cc9afa5041cdc7f57fcac21eda55e1c3bd2f70761bd8d +edge_cells_brain/2d_filter/signal0001.tif 39a008a23e3b91e10b0dd327a229550c85260cf57fd627a758877c668dd65cba +edge_cells_brain/2d_filter/signal0002.tif 501a706b007872353b9a638b952c4499e4b7e41a421092d33b99f3ad37432eff +edge_cells_brain/2d_filter/signal0003.tif 8c3f3ada6f6d7c108f778e8a0805a9d93dbd396d28b1e8a935a701b35e3797ef +edge_cells_brain/2d_filter/signal0004.tif 81d203c47e2e1752620cc80d79047683a5a3947973b840f4afef56ec77f5e61d +edge_cells_brain/2d_filter/signal0005.tif be0759a6601155d09f1698a67cb8b1c590a1deb51c9ece2d2fe52e9ff1eb2194 +edge_cells_brain/2d_filter/signal0006.tif 41acc77bbe3b4eb3379746939f659d6ec5b7b36f5cc9970427a9cc0cd59c2e25 +edge_cells_brain/2d_filter/signal0007.tif d8516e2a5c59b950536166525e9cefe2be81ed54527f92e9bf269431f471532a +edge_cells_brain/2d_filter/signal0008.tif d59a4ac8275882fbb8b1947a07ae45716616302e35cdeb731529d1a8f521c298 +edge_cells_brain/2d_filter/signal0009.tif 2715c53774793b395cd1ee340eb3438617414d769c5f1e293fd5698588125960 +edge_cells_brain/2d_filter/signal0010.tif 525b30be7f86ae957c8e2edae1ecc958c869d21979c8452455356020df112523 +edge_cells_brain/2d_filter/signal0011.tif 83a808ac1be98329a3767213f3a4396721a33aa13cef252bfa8906b496173101 +edge_cells_brain/2d_filter/signal0012.tif 110042dfedbc8f0189ff1e9a04a5e01ae1cc4d5a0d593dc4884bdd4a7c1dbdaa +edge_cells_brain/2d_filter/signal0013.tif c71930b8e6ed5d52b0c3699553faf75da728e343f766da174ee15057ff531e43 +edge_cells_brain/2d_filter/signal0014.tif d4b20626c940a35f5995be0b5ad36557634df3af947db10a24843ff325f94d1b +edge_cells_brain/2d_filter/signal0015.tif 1f6ed840d97320421cb9cb9e48487126e428cca4135036642292f9b6572d98b8 +edge_cells_brain/2d_filter/signal0016.tif 6e975a52d7eb282ecffa7e0ed90d6c814347b6be2fde77842f80da97d1ede01c +edge_cells_brain/2d_filter/signal0017.tif 4f3ef278746323db7cf585b4c40fead0925ce7a193af6a10ed84e330a9fd16a0 +edge_cells_brain/2d_filter/signal0018.tif f07233fe5134dff9f9935f3335b847aa1bd83e44ab379e45a48906fa4c1c7403 +edge_cells_brain/2d_filter/signal0019.tif c581dc2533e01f1e1c8d937c364f37c67af83ab3f85efb02eb9b1150a2e591dd +edge_cells_brain/2d_filter/signal0020.tif 25224c7965908fa3a6229f15ff0f79db4e76e4e50d50b3813830f844a11fa529 +edge_cells_brain/3d_filter/plane_0002.tif 26a15013be4b4120a045da6209f384dc9551a11b2726c9f702cbf6ef1046f3d5 +edge_cells_brain/3d_filter/plane_0003.tif 0fd764081ea03d6a24a9ff167c6f5c648612914c4f9451a8d5be246704d9dcce +edge_cells_brain/3d_filter/plane_0004.tif 6d09059ac4b4b67d02143223cb88b25925433a9bf9feea01826979cc697555a8 +edge_cells_brain/3d_filter/plane_0005.tif 3f71392ca5ee4c4f1164f8220ede6ffb1f6c15fa611e7ab258825929d9574601 +edge_cells_brain/3d_filter/plane_0006.tif 518320cd8b56a9f88adb647614672e8d02824125e89355facc49e6a98353bd8c +edge_cells_brain/3d_filter/plane_0007.tif e0184c6791393856a6bb2c4955891992db2f66d6c82ceaf0b1ba63511363a195 +edge_cells_brain/3d_filter/plane_0008.tif 05eca0b171879e6e837325d299b98d6cac7c2134cd456b85dbc7ba64e51d5ee7 +edge_cells_brain/3d_filter/plane_0009.tif 25c398550efc467b82e89ecca8e8188082c0dbfca3e630d47b4e7a8f9cf77141 +edge_cells_brain/3d_filter/plane_0010.tif 9b02d5a99544245c5367e5f54a0daafd3f55be43bb66220903c6b6766144070f +edge_cells_brain/3d_filter/plane_0011.tif 820b6064971a73ab5c346dda0710cf665ad87d9caf75a4d00440da56a0390e93 +edge_cells_brain/3d_filter/plane_0012.tif de1633ab4807bf61370d4ef7ff13d6cd73c8241030396013f977371a9291ade3 +edge_cells_brain/3d_filter/plane_0013.tif 819003634068868c17e6b9cca71a82607047e0e7aab01c4c61fa63df7355d7db +edge_cells_brain/3d_filter/plane_0014.tif deae555248b8b03c6b7b1cd2b02213f13255645e83737decf9a920bae30d1508 +edge_cells_brain/3d_filter/plane_0015.tif 063e8370dfdf80202a052d370d86357b6f07e263139f1f9cb115968d96ca1e76 +edge_cells_brain/3d_filter/plane_0016.tif dead984c00ef4d14582b125797d2743ce774eec46242e855a58342820e47063e +edge_cells_brain/3d_filter/plane_0017.tif d890497d0349c51379add2b188bded46ef528c56e016131da257fead119681ab +edge_cells_brain/3d_filter/plane_0018.tif d75f0a2e767dfa46f0d2cc3f63c367391eac47b3076427bce688b39305d20387 +edge_cells_brain/3d_filter/plane_0019.tif e7ee64795b3f07c4cff4f8a2700ba0c1055604ce50039fe0590c900c6ece43b8 +edge_cells_brain/3d_filter/plane_0020.tif 7f94a93aa45877be81732c5be4f4d1d67bb1d4059bb6de833d60c146be621efc +edge_cells_brain/detected_cells.xml 5519a25a60bfd0d15af86eab2d5edc3c249c0afd2c6d83a9a46838ffb6fe9543 +edge_cells_brain/peak_enhanced/signal0000.tif 11dc924b3f30def59f6ca0b31a428a8cb454c163f4e6417d620659028ac544d9 +edge_cells_brain/peak_enhanced/signal0001.tif ac3605624094838c229f2d4ca2af9df9df17ba742b11d7ca684c42ce8e49557e +edge_cells_brain/peak_enhanced/signal0002.tif a7759610b69fd0e9d023231488038482d8e810b9f4c096374ea267d2b7e2c5e7 +edge_cells_brain/peak_enhanced/signal0003.tif 25d712b49e7d97368c9f924876c23b132c0b0902f899ad5150c0b4bfe24d591b +edge_cells_brain/peak_enhanced/signal0004.tif e8607d850649ccd68d8bcd0af12c9cfb98f191bb5c092b81ed9b6c87837b986e +edge_cells_brain/peak_enhanced/signal0005.tif eff9cefc56c34f4307b9c94af90bcd6e41ef03d45855018eef9e40ab0bbdf116 +edge_cells_brain/peak_enhanced/signal0006.tif e5a058f759640b198069526772f972b9db0fb77d5502b5473a573d2f1a8fb746 +edge_cells_brain/peak_enhanced/signal0007.tif 44f7a1658c8c61f8e5b680ebb03d1f50779f0f22cf0066fb1b938764cd7353c2 +edge_cells_brain/peak_enhanced/signal0008.tif 5439cfeb6a6f17e6231bb532791a4d37c9fa00ea1c32b78184e8708938adf233 +edge_cells_brain/peak_enhanced/signal0009.tif 46c460f41e2ead3d694c9646ed4c21e6db57cae4fc3fd1fd4d411385877fdec7 +edge_cells_brain/peak_enhanced/signal0010.tif 0e74e27ef728f1c7b6db97172df8961917d9dba8c87cef045b7f2f523eec4da2 +edge_cells_brain/peak_enhanced/signal0011.tif 97495a61cfabec5df5a144305b394c306d56063307a911c7024d1b8e40089663 +edge_cells_brain/peak_enhanced/signal0012.tif 9809f14393cae19d7815c16370265e318ab13ea5188e075b73295a4968f071b5 +edge_cells_brain/peak_enhanced/signal0013.tif 136b1dd0b90be087fe87c1cf0caef02f136d54dc642aa91fc03a96944a113a0b +edge_cells_brain/peak_enhanced/signal0014.tif a74e9bed032952209bdffdd9c351dff8f0bb46b4272999dd14822c254e824f2c +edge_cells_brain/peak_enhanced/signal0015.tif eb675b9cabf241b2209473e62d84011349a997677a7e32b1d4a59b9f95229364 +edge_cells_brain/peak_enhanced/signal0016.tif 0df886c3fb7cc9a9dcd0c5803629ccd57125e5c9c574bca7445222cbb14d9be7 +edge_cells_brain/peak_enhanced/signal0017.tif 675125f17b0a4784aa5f56240abad9c155831f8fa3be0c6be13bd29cd1323a90 +edge_cells_brain/peak_enhanced/signal0018.tif be3df5edb2280eb69e05e94d7864d36635e929744667ae10e608f33b02f4363d +edge_cells_brain/peak_enhanced/signal0019.tif 90ea965719aa7ad0f4a6480ab1e860d58ddc1a6bd6441365f6edf6f1e76edfab +edge_cells_brain/peak_enhanced/signal0020.tif 7ba5bae6b3b8845cbbe0726da8d28138702e681167f6a99d050b063c4a271a23 +edge_cells_brain/signal/signal0000.tif 8214694de0e41dc6f9a456cc335cff60091129ebe700a1fe6f3a38ebbb02cc28 +edge_cells_brain/signal/signal0001.tif 6b7dec437f06c7a66f70d2c16dc8d5ec0f039a11b68a458e216574d281238d88 +edge_cells_brain/signal/signal0002.tif 4ebf6e8b96a970a64816210f9270fb19f5055cfe7fe6b2c3debd78485853c4b3 +edge_cells_brain/signal/signal0003.tif 920e5b95386d01766b94b357eb7e2906d2c6f1d27eee4e4765edb7455adcf91c +edge_cells_brain/signal/signal0004.tif 90b84ea67c92d086f993588649dd6570a2baa46de6ceaf5aa64213673c922bbd +edge_cells_brain/signal/signal0005.tif a6072c68f0636f4f7a34ada3ff3fad43105d2c7dbcf831aca6662a8b358a4d90 +edge_cells_brain/signal/signal0006.tif 87773e1f7bbbfad1324e5f83214e7339119b2e9b244a34d18f5f4211fd25978e +edge_cells_brain/signal/signal0007.tif 2a91b8ba0128314753542a7890438b0805b039c3b74de80f8d967e4aa45f31e3 +edge_cells_brain/signal/signal0008.tif e2883d175bebc48f1cf7a82690f709d379715e471ccab740205040a203c3e88e +edge_cells_brain/signal/signal0009.tif 854398ddc2de9f1f43256f935bb4bb2d3b36726ed615fedc045050adb618dda2 +edge_cells_brain/signal/signal0010.tif 8ccb209331784fe4aa8f43eb3174e68fee8abdc0d29acbb388ddf9b8236d8a31 +edge_cells_brain/signal/signal0011.tif 883722f089d5318881a2767ca4717857e52ccf0005e4535e09ed24de9ab9f73b +edge_cells_brain/signal/signal0012.tif 127d14bb8e1f7a721d23f02f05d949345d78a1471e41270c8ad0cf012454fb95 +edge_cells_brain/signal/signal0013.tif b4cd368115d17048c81934dcb310e43663cc43ea14535516e75496e4622cee1d +edge_cells_brain/signal/signal0014.tif 28c9d9dc3db609d48c69271f3755bcc5fbe3b870a8939b2c4685e7d8f3dfa8dd +edge_cells_brain/signal/signal0015.tif 2388045546ed504281a190eddbe1a362d0f4f5803cc4cd3bac92a8e10ee5c7d6 +edge_cells_brain/signal/signal0016.tif 96f834737b8be786db9aca7edbdbae04d11c8f6c4deea51616a8567f39d1d03b +edge_cells_brain/signal/signal0017.tif e79d979915c85598d1fc64e9d19c5a22f575a8c0ab7ae402f9da070027a19af9 +edge_cells_brain/signal/signal0018.tif af5d97c6365694fcff26bb04a37bab2c008b6d3ccacaaedf17ce4d003b289aec +edge_cells_brain/signal/signal0019.tif bb798c254f2627e26ec53b110ad572e0494eddb06ce79d48de7267cb6e629f89 +edge_cells_brain/signal/signal0020.tif edcc27ab82449b3d11202a8866096e885bacab2dd63471fdb8e142293a9a99e8 +edge_cells_brain/tiles/signal0000.tif ea8bd86e6c8ee62e7de84a1f6792ba7d2174d3dffe9d1b578fbbde9c71dacc8c +edge_cells_brain/tiles/signal0001.tif aa757c580996cd69264617ce3b2b7ac0a1173bac78fe5c5bdfbd7e61f633a764 +edge_cells_brain/tiles/signal0002.tif ea8bd86e6c8ee62e7de84a1f6792ba7d2174d3dffe9d1b578fbbde9c71dacc8c +edge_cells_brain/tiles/signal0003.tif aa757c580996cd69264617ce3b2b7ac0a1173bac78fe5c5bdfbd7e61f633a764 +edge_cells_brain/tiles/signal0004.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0005.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0006.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0007.tif f382330f73c2625e84bc36a3faee169a074ed8e59db836fcbbd236061f3888b4 +edge_cells_brain/tiles/signal0008.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0009.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0010.tif d3560fff32b42f8d7baa32b2746535137675ade1bb3e726d8ef890e321daf8cc +edge_cells_brain/tiles/signal0011.tif d3560fff32b42f8d7baa32b2746535137675ade1bb3e726d8ef890e321daf8cc +edge_cells_brain/tiles/signal0012.tif d3560fff32b42f8d7baa32b2746535137675ade1bb3e726d8ef890e321daf8cc +edge_cells_brain/tiles/signal0013.tif d3560fff32b42f8d7baa32b2746535137675ade1bb3e726d8ef890e321daf8cc +edge_cells_brain/tiles/signal0014.tif d3560fff32b42f8d7baa32b2746535137675ade1bb3e726d8ef890e321daf8cc +edge_cells_brain/tiles/signal0015.tif e0aaa58adcf0f13402494ce34112f7b937bf9421a1e0064856f57946e0267a02 +edge_cells_brain/tiles/signal0016.tif bc89f955fb25b3f536246912eaac757c7c3db4d9eddfdd5cb9b546800b782c2e +edge_cells_brain/tiles/signal0017.tif b0a9789d16458fcfc92166c4363fe48859311247939587acb1d04d2885a36476 +edge_cells_brain/tiles/signal0018.tif 5307038034e5d4ca98c33557bc5ed9ef5ccb12672934e5387ff7e1d236f5cac7 +edge_cells_brain/tiles/signal0019.tif 40c4474b4ec25a8c68ac20f6649e7296af058e1bd854585ef7d71facb11a37e0 +edge_cells_brain/tiles/signal0020.tif b0a9789d16458fcfc92166c4363fe48859311247939587acb1d04d2885a36476