Skip to content
This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Permalink
Merge pull request #121 from brainglobe/benchmarks
Browse files Browse the repository at this point in the history
Merge benchmarking and numba work into main
  • Loading branch information
dstansby authored Mar 31, 2023
2 parents 5b25c8e + 3c0bfb2 commit c246be9
Show file tree
Hide file tree
Showing 15 changed files with 859 additions and 764 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
tags:
- '*'
pull_request:
branches: [ '*' ]

jobs:
linting:
Expand Down
27 changes: 27 additions & 0 deletions benchmarks/filter_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
from pyinstrument import Profiler

from cellfinder_core.detect.filters.plane import TileProcessor
from cellfinder_core.detect.filters.setup_filters import setup_tile_filtering

# Use random 16-bit integer data for signal plane
shape = (10000, 10000)

signal_array_plane = np.random.randint(
low=0, high=65536, size=shape, dtype=np.uint16
)

clipping_value, threshold_value = setup_tile_filtering(signal_array_plane)
tile_processor = TileProcessor(
clipping_value=clipping_value,
threshold_value=threshold_value,
soma_diameter=16,
log_sigma_size=0.2,
n_sds_above_mean_thresh=10,
)
if __name__ == "__main__":
profiler = Profiler()
profiler.start()
plane, tiles = tile_processor.get_tile_mask(signal_array_plane)
profiler.stop()
profiler.print(show_all=True)
50 changes: 50 additions & 0 deletions benchmarks/filter_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
from pyinstrument import Profiler

from cellfinder_core.detect.filters.volume.volume_filter import VolumeFilter

# Use random data for signal data
ball_z_size = 3


def gen_signal_array(ny, nx):
shape = (ball_z_size, ny, nx)
return np.random.randint(low=0, high=65536, size=shape, dtype=np.uint16)


signal_array = gen_signal_array(667, 510)

soma_diameter = 8
setup_params = (
signal_array[0, :, :].T,
soma_diameter,
3, # ball_xy_size,
ball_z_size,
0.6, # ball_overlap_fraction,
0, # start_plane,
)

mp_3d_filter = VolumeFilter(
soma_diameter=soma_diameter,
setup_params=setup_params,
planes_paths_range=signal_array,
)

# Use random data for mask data
mask = np.random.randint(low=0, high=2, size=(42, 32), dtype=np.uint8)

# Fill up the 3D filter with planes
for plane in signal_array:
mp_3d_filter.ball_filter.append(plane, mask)
if mp_3d_filter.ball_filter.ready:
break

# Run the 3D filter
profiler = Profiler()
profiler.start()
for i in range(10):
# Repeat same filter 10 times to increase runtime
mp_3d_filter._run_filter()

profiler.stop()
profiler.print(show_all=True)
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"fancylog>=0.0.7",
"imlib>=0.0.26",
"natsort",
"numba",
"numpy",
"scikit-image",
"scikit-learn",
Expand All @@ -44,6 +45,7 @@ dev = [
"black",
"gitpython",
"pre-commit",
"pyinstrument",
"pytest",
"pytest-cov",
"pytest-timeout",
Expand All @@ -59,7 +61,6 @@ requires = [
"setuptools>=45",
"wheel",
"setuptools_scm[toml]>=6.2",
"cython",
]
build-backend = 'setuptools.build_meta'

Expand All @@ -80,13 +81,15 @@ filterwarnings = [
"error",
# Raised by tensorflow; should be removed when tensorflow 2.12.0 is released. Fix is:
# https://github.com/tensorflow/tensorflow/commit/b23c5750c9f35a87872793eef7c56e74ec55d4a7
"ignore:`np.bool8` is a deprecated alias for `np.bool_`"
"ignore:`np.bool8` is a deprecated alias for `np.bool_`",
# See https://github.com/numba/numba/issues/8676
"ignore:.*:numba.core.errors.NumbaTypeSafetyWarning"
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"serial",
]
log_level = "DEBUG"
log_level = "WARNING"

[tool.black]
target-version = ['py38', 'py39', 'py310']
Expand Down
32 changes: 0 additions & 32 deletions setup.py

This file was deleted.

130 changes: 130 additions & 0 deletions src/cellfinder_core/detect/filters/plane/base_tile_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Dict

import numpy as np
from numba import jit


def get_biggest_structure(sizes):
result = 0
for val in sizes:
if val > result:
result = val
return result


class BaseTileFilter:
def __init__(self, out_of_brain_intensity_threshold=100):
"""
:param int out_of_brain_intensity_threshold: Set to 0 to disable
"""
self.out_of_brain_intensity_threshold = (
out_of_brain_intensity_threshold
)
self.current_threshold = -1
self.keep = True
self.size_analyser = SizeAnalyser()

def set_tile(self, tile):
raise NotImplementedError

def get_tile(self):
raise NotImplementedError

def get_structures(self):
struct_sizes = []
self.size_analyser.process(self._tile, self.current_threshold)
struct_sizes = self.size_analyser.get_sizes()
return get_biggest_structure(struct_sizes), struct_sizes.size()


@jit
def is_low_average(tile: np.ndarray, threshold: float) -> bool:
"""
Return `True` if the average value of *tile* is below *threshold*.
"""
avg = np.mean(tile)
return avg < threshold


class OutOfBrainTileFilter(BaseTileFilter):
def set_tile(self, tile):
self._tile = tile

def get_tile(self):
return self._tile


class SizeAnalyser:
obsolete_ids: Dict[int, int] = {}
struct_sizes: Dict[int, int] = {}

def process(self, tile, threshold):
tile = tile.copy()
self.clear_maps()

last_structure_id = 1

for y in range(tile.shape[1]):
for x in range(tile.shape[0]):
# default struct_id to 0 so that it is not counted as
# structure in next iterations
id_west = id_north = struct_id = 0
if tile[x, y] >= threshold:
# If in bounds look to neighbours
if x > 0:
id_west = tile[x - 1, y]
if y > 0:
id_north = tile[x, y - 1]

id_west = self.sanitise_id(id_west)
id_north = self.sanitise_id(id_north)

if id_west != 0:
if id_north != 0 and id_north != id_west:
struct_id = self.merge_structures(
id_west, id_north
)
else:
struct_id = id_west
elif id_north != 0:
struct_id = id_north
else: # no neighbours, create new structure
struct_id = last_structure_id
self.struct_sizes[last_structure_id] = 0
last_structure_id += 1

self.struct_sizes[struct_id] += 1

tile[x, y] = struct_id

def get_sizes(self):
for iterator_pair in self.struct_sizes:
self.struct_sizes.push_back(iterator_pair.second)
return self.struct_sizes

def clear_maps(self):
self.obsolete_ids.clear()
self.struct_sizes.clear()

def sanitise_id(self, s_id):
while self.obsolete_ids.count(
s_id
): # walk up the chain of obsolescence
s_id = self.obsolete_ids[s_id]
return s_id

# id1 and id2 must be valid struct IDs (>0)!
def merge_structures(self, id1, id2):
# ensure id1 is the smaller of the two values
if id2 < id1:
tmp_id = id1 # swap
id1 = id2
id2 = tmp_id

self.struct_sizes[id1] += self.struct_sizes[id2]
self.struct_sizes.erase(id2)

self.obsolete_ids[id2] = id1

return id1
Loading

0 comments on commit c246be9

Please sign in to comment.