This repository has been archived by the owner on Jan 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #121 from brainglobe/benchmarks
Merge benchmarking and numba work into main
- Loading branch information
Showing
15 changed files
with
859 additions
and
764 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ on: | |
tags: | ||
- '*' | ||
pull_request: | ||
branches: [ '*' ] | ||
|
||
jobs: | ||
linting: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import numpy as np | ||
from pyinstrument import Profiler | ||
|
||
from cellfinder_core.detect.filters.plane import TileProcessor | ||
from cellfinder_core.detect.filters.setup_filters import setup_tile_filtering | ||
|
||
# Use random 16-bit integer data for signal plane | ||
shape = (10000, 10000) | ||
|
||
signal_array_plane = np.random.randint( | ||
low=0, high=65536, size=shape, dtype=np.uint16 | ||
) | ||
|
||
clipping_value, threshold_value = setup_tile_filtering(signal_array_plane) | ||
tile_processor = TileProcessor( | ||
clipping_value=clipping_value, | ||
threshold_value=threshold_value, | ||
soma_diameter=16, | ||
log_sigma_size=0.2, | ||
n_sds_above_mean_thresh=10, | ||
) | ||
if __name__ == "__main__": | ||
profiler = Profiler() | ||
profiler.start() | ||
plane, tiles = tile_processor.get_tile_mask(signal_array_plane) | ||
profiler.stop() | ||
profiler.print(show_all=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
from pyinstrument import Profiler | ||
|
||
from cellfinder_core.detect.filters.volume.volume_filter import VolumeFilter | ||
|
||
# Use random data for signal data | ||
ball_z_size = 3 | ||
|
||
|
||
def gen_signal_array(ny, nx): | ||
shape = (ball_z_size, ny, nx) | ||
return np.random.randint(low=0, high=65536, size=shape, dtype=np.uint16) | ||
|
||
|
||
signal_array = gen_signal_array(667, 510) | ||
|
||
soma_diameter = 8 | ||
setup_params = ( | ||
signal_array[0, :, :].T, | ||
soma_diameter, | ||
3, # ball_xy_size, | ||
ball_z_size, | ||
0.6, # ball_overlap_fraction, | ||
0, # start_plane, | ||
) | ||
|
||
mp_3d_filter = VolumeFilter( | ||
soma_diameter=soma_diameter, | ||
setup_params=setup_params, | ||
planes_paths_range=signal_array, | ||
) | ||
|
||
# Use random data for mask data | ||
mask = np.random.randint(low=0, high=2, size=(42, 32), dtype=np.uint8) | ||
|
||
# Fill up the 3D filter with planes | ||
for plane in signal_array: | ||
mp_3d_filter.ball_filter.append(plane, mask) | ||
if mp_3d_filter.ball_filter.ready: | ||
break | ||
|
||
# Run the 3D filter | ||
profiler = Profiler() | ||
profiler.start() | ||
for i in range(10): | ||
# Repeat same filter 10 times to increase runtime | ||
mp_3d_filter._run_filter() | ||
|
||
profiler.stop() | ||
profiler.print(show_all=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
src/cellfinder_core/detect/filters/plane/base_tile_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from typing import Dict | ||
|
||
import numpy as np | ||
from numba import jit | ||
|
||
|
||
def get_biggest_structure(sizes): | ||
result = 0 | ||
for val in sizes: | ||
if val > result: | ||
result = val | ||
return result | ||
|
||
|
||
class BaseTileFilter: | ||
def __init__(self, out_of_brain_intensity_threshold=100): | ||
""" | ||
:param int out_of_brain_intensity_threshold: Set to 0 to disable | ||
""" | ||
self.out_of_brain_intensity_threshold = ( | ||
out_of_brain_intensity_threshold | ||
) | ||
self.current_threshold = -1 | ||
self.keep = True | ||
self.size_analyser = SizeAnalyser() | ||
|
||
def set_tile(self, tile): | ||
raise NotImplementedError | ||
|
||
def get_tile(self): | ||
raise NotImplementedError | ||
|
||
def get_structures(self): | ||
struct_sizes = [] | ||
self.size_analyser.process(self._tile, self.current_threshold) | ||
struct_sizes = self.size_analyser.get_sizes() | ||
return get_biggest_structure(struct_sizes), struct_sizes.size() | ||
|
||
|
||
@jit | ||
def is_low_average(tile: np.ndarray, threshold: float) -> bool: | ||
""" | ||
Return `True` if the average value of *tile* is below *threshold*. | ||
""" | ||
avg = np.mean(tile) | ||
return avg < threshold | ||
|
||
|
||
class OutOfBrainTileFilter(BaseTileFilter): | ||
def set_tile(self, tile): | ||
self._tile = tile | ||
|
||
def get_tile(self): | ||
return self._tile | ||
|
||
|
||
class SizeAnalyser: | ||
obsolete_ids: Dict[int, int] = {} | ||
struct_sizes: Dict[int, int] = {} | ||
|
||
def process(self, tile, threshold): | ||
tile = tile.copy() | ||
self.clear_maps() | ||
|
||
last_structure_id = 1 | ||
|
||
for y in range(tile.shape[1]): | ||
for x in range(tile.shape[0]): | ||
# default struct_id to 0 so that it is not counted as | ||
# structure in next iterations | ||
id_west = id_north = struct_id = 0 | ||
if tile[x, y] >= threshold: | ||
# If in bounds look to neighbours | ||
if x > 0: | ||
id_west = tile[x - 1, y] | ||
if y > 0: | ||
id_north = tile[x, y - 1] | ||
|
||
id_west = self.sanitise_id(id_west) | ||
id_north = self.sanitise_id(id_north) | ||
|
||
if id_west != 0: | ||
if id_north != 0 and id_north != id_west: | ||
struct_id = self.merge_structures( | ||
id_west, id_north | ||
) | ||
else: | ||
struct_id = id_west | ||
elif id_north != 0: | ||
struct_id = id_north | ||
else: # no neighbours, create new structure | ||
struct_id = last_structure_id | ||
self.struct_sizes[last_structure_id] = 0 | ||
last_structure_id += 1 | ||
|
||
self.struct_sizes[struct_id] += 1 | ||
|
||
tile[x, y] = struct_id | ||
|
||
def get_sizes(self): | ||
for iterator_pair in self.struct_sizes: | ||
self.struct_sizes.push_back(iterator_pair.second) | ||
return self.struct_sizes | ||
|
||
def clear_maps(self): | ||
self.obsolete_ids.clear() | ||
self.struct_sizes.clear() | ||
|
||
def sanitise_id(self, s_id): | ||
while self.obsolete_ids.count( | ||
s_id | ||
): # walk up the chain of obsolescence | ||
s_id = self.obsolete_ids[s_id] | ||
return s_id | ||
|
||
# id1 and id2 must be valid struct IDs (>0)! | ||
def merge_structures(self, id1, id2): | ||
# ensure id1 is the smaller of the two values | ||
if id2 < id1: | ||
tmp_id = id1 # swap | ||
id1 = id2 | ||
id2 = tmp_id | ||
|
||
self.struct_sizes[id1] += self.struct_sizes[id2] | ||
self.struct_sizes.erase(id2) | ||
|
||
self.obsolete_ids[id2] = id1 | ||
|
||
return id1 |
Oops, something went wrong.