Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move detection (2d/3d filtering, structure splitting) to PyTorch #440

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ flags:
numba:
paths:
- cellfinder/core/detect/filters/plane/tile_walker.py
- cellfinder/core/detect/filters/plane/classical_filter.py
- cellfinder/core/detect/filters/plane/plane_filter.py
- cellfinder/core/detect/filters/volume/ball_filter.py
- cellfinder/core/detect/filters/volume/structure_detection.py
carryforward: true
30 changes: 26 additions & 4 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ jobs:
test:
needs: [linting, manifest]
name: Run package tests
timeout-minutes: 60
timeout-minutes: 120
runs-on: ${{ matrix.os }}
env:
KERAS_BACKEND: torch
CELLFINDER_TEST_DEVICE: cpu
# pooch cache dir
BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"

strategy:
matrix:
# Run all supported Python versions on linux
Expand All @@ -53,6 +56,14 @@ jobs:
python-version: "3.12"

steps:
- uses: actions/checkout@v4
- name: Cache pooch data
uses: actions/cache@v4
with:
path: "~/.pooch_cache"
# hash on conftest in case url changes
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pooch_registry.txt') }}
# Cache the tensorflow model so we don't have to remake it every time
- name: Cache brainglobe directory
uses: actions/cache@v3
with:
Expand All @@ -75,19 +86,30 @@ jobs:
test_numba_disabled:
needs: [linting, manifest]
name: Run tests with numba disabled
timeout-minutes: 60
timeout-minutes: 120
runs-on: ubuntu-latest
env:
NUMBA_DISABLE_JIT: "1"
NUMBA_DISABLE_JIT: "1"
PYTORCH_JIT: "0"
# pooch cache dir
BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache"

steps:
- uses: actions/checkout@v4
- name: Cache brainglobe directory
uses: actions/cache@v3
with:
path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
~/.brainglobe
!~/.brainglobe/atlas.tar.gz
key: brainglobe

- name: Cache pooch data
uses: actions/cache@v4
with:
path: "~/.pooch_cache"
key: ${{ runner.os }}-3.10-${{ hashFiles('**/pooch_registry.txt') }}

# Setup pyqt libraries
- name: Setup qtpy libraries
uses: tlambert03/setup-qt-libs@v1
Expand All @@ -105,7 +127,7 @@ jobs:
test_brainmapper_cli:
needs: [linting, manifest]
name: Run brainmapper tests to check for breakages
timeout-minutes: 60
timeout-minutes: 120
runs-on: ubuntu-latest
env:
KERAS_BACKEND: torch
Expand Down
86 changes: 86 additions & 0 deletions benchmarks/benchmark_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from pathlib import Path

import pooch
import torch
from torch.profiler import ProfilerActivity, profile
from torch.utils.benchmark import Compare, Timer

from cellfinder.core.tools.IO import fetch_pooch_directory


def get_test_data_path(path):
"""
Create a test data registry for BrainGlobe.

Returns:
pooch.Pooch: The test data registry object.

"""
registry = pooch.create(
path=pooch.os_cache("brainglobe_test_data"),
base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/cellfinder/",
env="BRAINGLOBE_TEST_DATA_DIR",
)

registry.load_registry(
Path(__file__).parent.parent / "tests" / "data" / "pooch_registry.txt"
)

return fetch_pooch_directory(registry, path)


def time_filters(repeat, run, run_args, label):
timer = Timer(
stmt="run(*args)",
globals={"run": run, "args": run_args},
label=label,
num_threads=4,
description="", # must be not None due to pytorch bug
)
return timer.timeit(number=repeat)


def compare_results(*results):
# prints the results of all the timed tests
compare = Compare(results)
compare.trim_significant_figures()
compare.colorize()
compare.print()


def profile_cpu(repeat, run, run_args):
with profile(
activities=[ProfilerActivity.CPU],
record_shapes=True,
profile_memory=True,
with_stack=True,
with_modules=True,
) as prof:
for _ in range(repeat):
run(*run_args)

print(
prof.key_averages(group_by_stack_n=1).table(
sort_by="self_cpu_time_total", row_limit=20
)
)


def profile_cuda(repeat, run, run_args):
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
with_modules=True,
) as prof:
for _ in range(repeat):
run(*run_args)
# make sure it's fully done filtering
torch.cuda.synchronize("cuda")

print(
prof.key_averages(group_by_stack_n=1).table(
sort_by="self_cuda_time_total", row_limit=20
)
)
144 changes: 124 additions & 20 deletions benchmarks/filter_2d.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,131 @@
import os
import sys

sys.path.append(os.path.dirname(__file__))

import numpy as np
from pyinstrument import Profiler
import torch
from benchmark_tools import (
compare_results,
get_test_data_path,
profile_cpu,
profile_cuda,
time_filters,
)
from brainglobe_utils.IO.image.load import read_with_dask

from cellfinder.core.detect.filters.plane import TileProcessor
from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering
from cellfinder.core.detect.filters.setup_filters import DetectionSettings

# Use random 16-bit integer data for signal plane
shape = (10000, 10000)

signal_array_plane = np.random.randint(
low=0, high=65536, size=shape, dtype=np.uint16
)
def setup_filter(
signal_path,
batch_size=1,
num_z=None,
torch_device="cpu",
dtype=np.uint16,
use_scipy=False,
):
signal_array = read_with_dask(signal_path)
num_z = num_z or len(signal_array)
signal_array = np.asarray(signal_array[:num_z]).astype(dtype)
shape = signal_array.shape

settings = DetectionSettings(
plane_original_np_dtype=dtype,
plane_shape=shape[1:],
torch_device=torch_device,
voxel_sizes=(5.06, 4.5, 4.5),
soma_diameter_um=30,
ball_xy_size_um=6,
ball_z_size_um=15,
)
signal_array = settings.filter_data_converter_func(signal_array)
signal_array = torch.from_numpy(signal_array).to(torch_device)

tile_processor = TileProcessor(
plane_shape=shape[1:],
clipping_value=settings.clipping_value,
threshold_value=settings.threshold_value,
soma_diameter=settings.soma_diameter,
log_sigma_size=settings.log_sigma_size,
n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
torch_device=torch_device,
dtype=settings.filtering_dtype.__name__,
use_scipy=use_scipy,
)

return tile_processor, signal_array, batch_size


def run_filter(tile_processor, signal_array, batch_size):
for i in range(0, len(signal_array), batch_size):
tile_processor.get_tile_mask(signal_array[i : i + batch_size])


clipping_value, threshold_value = setup_tile_filtering(signal_array_plane)
tile_processor = TileProcessor(
clipping_value=clipping_value,
threshold_value=threshold_value,
soma_diameter=16,
log_sigma_size=0.2,
n_sds_above_mean_thresh=10,
)
if __name__ == "__main__":
profiler = Profiler()
profiler.start()
plane, tiles = tile_processor.get_tile_mask(signal_array_plane)
profiler.stop()
profiler.print(show_all=True)
with torch.inference_mode(True):
n = 5
batch_size = 2
signal_path = get_test_data_path("bright_brain/signal")

compare_results(
time_filters(
n,
run_filter,
setup_filter(
signal_path,
batch_size=batch_size,
torch_device="cpu",
use_scipy=False,
),
"cpu-no_scipy",
),
time_filters(
n,
run_filter,
setup_filter(
signal_path,
batch_size=batch_size,
torch_device="cpu",
use_scipy=True,
),
"cpu-scipy",
),
time_filters(
n,
run_filter,
setup_filter(
signal_path, batch_size=batch_size, torch_device="cuda"
),
"cuda",
),
)

profile_cpu(
n,
run_filter,
setup_filter(
signal_path,
batch_size=batch_size,
torch_device="cpu",
use_scipy=False,
),
)
profile_cpu(
n,
run_filter,
setup_filter(
signal_path,
batch_size=batch_size,
torch_device="cpu",
use_scipy=True,
),
)
profile_cuda(
n,
run_filter,
setup_filter(
signal_path, batch_size=batch_size, torch_device="cuda"
),
)
Loading
Loading