Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move detection (2d/3d filtering, structure splitting) to PyTorch #440

Merged
merged 18 commits into from
Oct 31, 2024

Conversation

matham
Copy link
Contributor

@matham matham commented Jun 13, 2024

Description

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?

When I was testing out cell detection, it was too slow for our usecase where ideally we play with detection parameters until all cells are detected and then run it through classification. I was used to Imaris where cell detection took like a minute (probably simple thresholding?) So I was looking to speed it up so detection was faster.

What does this PR do?

This PR moves the 2d and 3d filtering used in detection and structure splitting to PyTorch. The benefits is that it's a lot faster on GPU, and a bit faster on CPU. While these are significant changes, I tested with a couple of brains, including one that is very bright and noisy, and achieved parity with main. Those tests are included in the PR - there is generally very thorough testing with close to 100% coverage.

While there are a lot of changes, there is a few core sources of change that propagate; unifying options that were hardcoded/in various places, translating the filtering code to pytorch, and changing the threading/sub-processing and data loading to be more suited for pytorch. If you follow these core concerns, the changes should make sense.

I realize it's a giant PR, but I'll help with review in whatever way you need - I couldn't split it up because it all seems to goe together, especially to get the proper performance from pytorch.

Performance

First some performance numbers. I measured this on 1555x3222x3848 brain. On CPU and GPU. Each row is a computer identified by the number of CPU cores (e.g. 72) and RTX6000 is the GPU on that computer.

The columns are the current main branch (before this fix #435), CPU means using pytorch on CPU (batch size of 4), CPU+scipy means scipy (see later) for the 2d filtering and pytorch CPU for 3d filtering (batch=4), and CUDA means pytorch on the GPU, with different batch sizes.

Time is in seconds and includes just detection, not structure splitting.

Computer Main branch CPU CPU+scipy* CUDA - CUDA -
Time Time Time Batch Time Batch Time
72 (RTX6000) 963 723 582 1 271 8 243
36 (RTX3090) 660 949 662 1 387 6 274
16 (RTX2080) 1087 1350 682 1 410 6 302

* means the default configuration out of the box with these changes.

I also looked at the RAM and it used about the same for main as for pytorch ~7GB for batch size of 4 and main.

These are the computer specs for the three desktop systems I tested on:

Ubuntu
Processor	Intel(R) Xeon(R) Gold 6154 CPU @ 3.00GHz   2.99 GHz  (2 processors)
Installed RAM   384 GB (383 GB usable)
Core: 72

Windows
Processor	Intel(R) Core(TM) i9-9980XE CPU @ 3.00GHz   3.00 GHz
Installed RAM	128 GB (128 GB usable)
Core: 36

Windows
Processor	Intel(R) Core(TM) i7-7820X CPU @ 3.60GHz   3.60 GHz
Installed RAM	128 GB (128 GB usable)
Core: 16

References

None.

How has this PR been tested?

There's thorough testing of the new code (close to 100% coverage) - you have to run it with NUMBA_DISABLE_JIT=1 PYTORCH_JIT=0 pytest though. I also tested it on two brains, which is included in the pooch test files. Testing the image output of the 2d and 3d filters and detected cells. One brain is quite bright, which makes it a good test.

Is this a breaking change?

In terms of the API, it adds new parameters to main, with default values that works out of the box. I have not added support for passing parameters through CLI, ideally that would happen in a separate PR. The default values for parameters used should remain the same, they were simply extracted into a unified settings class.

In terms of the quality of the detection, there's some difference as explained below. But PyTorch needs the data as floats. We use float32 if the input data fits, otherwise float64. If we use float64, there's 100% parity of output with main. But float64 is much slower. Float32 is not complete parity, but, for our 1555x3222x3848 very noisy brain with 252k detected cells, only 7.2k were not an exact match. For clean data I expect it to be the same or close enough to be negligible.

Does this PR require an update to the documentation?

We need to decide on how/which to pass new (and previously hardcoded parameters) through CLI and the docs should be updated for that. Including how to properly install PyTorch for GPU support (mostly just linking to https://pytorch.org/get-started/locally/).

Checklist:

  • The code has been tested locally
  • Tests have been added to cover all new functionality (unit & integration)
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

Changes

Here are the changes, at a higher level

Settings

  • To translate/test in pytorch, I had to keep track of input parameters. But many were hardcoded as default parameters in sub-functions, or encoded locally as data types used. So I extracted them and unified them in setup_filters.py. That files has extensive config options and docs.
  • One example is how some of the ball filtering parameters in main for splitting, originate from the default values of the splitting functions. Because the same parameter defaulted to different values for original filtering and structure splitting.
  • I did not change brainmapper or anything above main - when running from CLI I manually changed the default parameters values in main. So new parameters cannot be passed in from CLI. Part of it is that I'm not quite sure what should be passable from CLI and how they would be named. E.g. classification probably has a batch_size, and now detection also has it so we need multiple.

Data types

  • pytorch needs the data to be as floats during filtering, and we prefer float32 because it's faster than float64. Since we need to convert anyway, the settings can now handle input data of any data type (as long as it fits in upto float64), it will get it into a float of the smallest size it can fit for filtering, and then back to the original data type size, if saving filtered planes, and finally into the data type used by the cell detector code (uint64 by default). In main it was always converted to uint32 and float64 for some parts of the filtering. So settings can now handle all input data sizes.
  • When using only float64, we have 100% parity with main in terms of what the 2d filter outputs for each pixel. I.e. the values after the laplace filter are the same. But, float32 is significantly faster. So we use that if the data fits into it otherwise we use float64 (e.g for uint32 data). Then we have parity of the output bright areas, but not 100%. I.e. after the laplace stage we use the filtered values to mark some pixels as bright based on their value above stdev. For the brains I tested (3222x3848 planes), the number of bright pixels per plane were off by sometimes one but at most two pixels after thresholding. IMO this tradeoff is worth it. And we test for this parity in the pytests.

Translated to pytorch

  • As part of translating I made the (in/out of brain) tile creation and filters proper classes, where it wasn't, so you don't re-create the class on every plane.
  • The classical 2d filters, tiling, and 3d ball filter was all translated to pytorch.
  • Input data now follows this path: input data -> convert to float by finding the smallest float type that input data will fit (save the max value of the original type - data will be kept below this value) -> get it into pytorch -> pass it through 2d filters and get in/out-of-brain tiling masks -> pass it through 3d filters -> if selected, save the result as images in the original data type -> back to numpy/numba - convert the data to detection data type (uint64 by default, but any data type that can fit the # of potential cells will do) -> detect cell structures, like before.
    • Extensively tested that all input data types work as long as it's float64 or less.
    • For cell detection similarly we test all data types since it's agnostic to type. As long as the previous step tells us the (large) value used to mark voxels as cells. But by default we just use uint64, but I made sure it supports all data type, including floats. Because the data only needs a max "cell value" and room for cell IDs for each detected cell. So any data type can work. There's no downside to uint64, but left support for smaller sizes as there's no real cost.

CPU/GPU/SciPy

  • The pytorch pipeline above can run fully either on on GPU (CUDA) or CPU. However, on some CPUs, pytorch 2d filtering is significantly slower than scipy 2d filtering as documented in the issue and below (filtering on CPU slower in pytorch than scipy pytorch/pytorch#126115). So we support a mix of scipy for parts of 2d filtering and torch for the rest (on CPU only). There's no cost for numpy/pytorch to share the same data. This is selected by default for CPU (which is also the default) as that's the fastest on CPU for now.
  • The test brains show there's parity between scipy and torch filtering for float64, and within tolerance for float32. Both are verified in the added pytests.
  • Below I compare the cell detection for all these configuration and it's very similar.

Threading

  • Previously, data loading and 2d filtering was split among process, with everything else running in the main process. This doesn't quite work for pytorch, especially when running on GPU because you have to be able to upload to the GPU everything (data/instructions) to do and get out of the way until the batch is done. Also, on CPU, pytorch will automatically multhithread certain operations (e.g. convolution).
  • Instead we now have the following pattern (there is a new module at tools/threading.py to support all this seamlessly):
    • There's a data feeder thread that loads the data, converts the data to the right type and into pytorch, if using GPU it also uploads the data to the GPU asynchronously, and sends it on by queue to the next step:
    • On GPU:
      • The main thread gets the torch data reference from queue.
      • Passes it to the 2d and 3d filtering on the device. All the intermediate filtering instructions happens asynchronously (we just tell the GPU to do it, but it doesn't block there).
      • We read out the filtered data, at this point it blocks the main thread until data is available. We add the filtered data to a queue for the cell detection thread.
    • On CPU:
      • We spun up as many processes as planes in a batch to parallelize the batch. By queue, from the original data feeder thread we send each process a ref to the input data and the plane to 2d filter. pytorch memory shares tensors across processes so no data copies is needed here.
      • Each process does the 2d filtering for its assigned plane in the batch. Either using pytorch only or using scipy for parts if selected (see above/later).
      • When done the sub-process, by queue, tells the main thread it's done. And finally when all planes in the batch is done, the main thread does the 3d filtering directly.
      • After 3d filtering it sends the data on by queue to the cell detection thread.
    • Cell detection. This thread takes the filtered data, converts it to the data type used for cell detection and does detection. It also saves the planes to disk if selected.
  • When it's all done we do structure splitting and converting into Cell, like before in sub-processes.

Tests

  • There's thorough testing of all the code. As well as end-to-end and components using two brains in the pooch data.
  • Benchmarks was updated to use these pooch brains as well and to profile using pytorch appropriately.

Results comparison

Here is how each of the code paths does on cell detection.

First, pytorch on CPU vs main (that includes the fix from #435). The differences are due to using float32 instead of float64, as explained above.

main_vs_torch_cpu

This compares pytorch on CPU vs pytorch on CPU + scipy as explained above. It's quite similar and has to do with scipy using float64 internally in various places while we use float32 for these uint16 data.

torch_cpu_vs_torch_cpu_scipy

pytorch on CPU vs pytorch on GPU - the detected cells were identical so no graph.

Todo

Some of this can hopefully be address in another PR.

Docs and parameters

  • Understand the default parameter values added to main, e.g. split_ball_xy_size, why were these specific values originally hardcoded? Should they be in um like the other params? I left it in voxels and not um unlike the other parameters because that's how they were hardcoded.
  • Decide which parameters should be passable from brainmapper CLI and the appropriate names.
  • Update the docs with a more intuitive explanation for what each of these parameters do - right now it's often vague.
  • All these parameter docs should probably only be listed in a single place? They are listed in setup_filters.py, in main both in root and just detect, and they also need to be shown in napari. So surely we don't want to copy paste the docs 4 times. Also, even in napari they are vague and not everything is even offered as an option.
  • Clearly document how to install pytorch with GPU support (really just follow https://pytorch.org/get-started/locally/).
  • Document about batch sizes and threading. Related to the performance tables below and above.
  • About using float32 or float64 for filtering. By default it'll use the smallest that fits the input data. But perhaps we want to allow the user to specify this if they want to pay the performance cost with float64 (I'm not sure why they would, though).
    • We should perhaps also document how we handle the data i.e. if it fits in float32 we use that, otherwise float64 so people understand what data types we support.
    • And maybe tell them to reduce their data and convert it to float32 if they have uint32 data but don't want to pay the float64 performance cost. Or maybe that's pointless.
    • It's not clear to me how common uint32 or above data is. Presumably if you have uint32 data you don't want to lose that bit accuracy? Or maybe it doesn't matter for cell detection so we could allow them to scale the data to float32 at the loss of resolution? I'm not sure how to handle all this from CLI. And maybe that can be left for another time if someone has that data.
    • Right now we don't downscale data, but we could (we'd need to know the max value of the dataset, or go by max possible value of type?).
    • Do we need to discuss about using scipy/pytorch only (for 2d filtering)? In case people want to use CPU (+scipy) and CUDA (no scipy) but the results may then be slightly different. So users may need to know. Or perhaps it doesn't matter for cleaner data.
  • CPU batch size defaults to 4 - most people should have at least 5 independent processes capability. GPU defaults to one because it depends on plane size and memory, which is hard to predict. This should be a tunable option. But I think this is a good default!?

Tests

  • There's very thorough testings. But it's a bit slow. Decide whether to mark some tests as slow and skip them in PRs/merges and run only nightly? But if someone messes with the filtering code then they should run all tests!? Or is the wait ok to run everything.
  • Many tests test CPU, CPU+scipy, CUDA. CI may not have GPU, so figure out how to test it long term. Otherwise it may only be tested when someone runs stuff locally and we won't discover GPU issues until then. Is GPU runner already available on GH? Would that be a paid service?
  • Because tests and benchmarks are not proper packages, we can't reuse code. So I had to do some path trickery in the benchmarks. Maybe benchmarks should be under cellfinder?

External issues

Some minutiae

Pytorch does its own multithreading when running on CPU. It supports intra-threading (# of threads used for e.g. convolution). For 2d filtering, increasing the number of threads per plane above like 4 doesn't help. Which means, for 2d filtering, we still benefit from multiproces parallelization. So, on CPU, a batch is split up among parallel processes.

This leads to use being able to tune the following parameters:

  • On GPU,
    • B: the batch size. Larger batches utilizes the GPU more, but too large won't fit or will slow down (for shared memory configuration).
  • On CPU,
    • B: The number of sub-processes to use for 2d filtering. E.g. if it's 4, we do 2d filtering for 4 planes at once, each in its sub-process.
    • T2d: For 2d filtering, each plane in the batch is processed in its own sub-process. We can set the number of threads in this sub-process.
    • T3d: For 3d filtering and other pytorch stuff in the main process we can set the number of threads.

Based on the following tables, other tuning, and pytorch docs I arrived at the following values.

  • B: Only B is settable as an input parameter. Be default, if run on CPU B is 4, on cuda it's 1.
  • T2d: it's hardcoded to at most 4.
  • T3d: it's set to 12, minus threads used for data loading/cell detection.

Following tables show the duration in seconds for cell detection of the 1555x3222x3848 brain. The computer is ID by number of cores. There has been changes to code so the numbers are not directly comparable to the table above.

This tests the number of pytorch threads during 2d filtering in pure pytorch on CPU. C36 means that it was limited to 36, but potentially less depending on the machine number of threads. 4 for T2d was the best option.

Computer T2d=C36 T2d=3 T2d=5
Time Time Time
72
36 1021 1169 1024
16 1398

This tests the number of pytorch threads during 3d volume (main process) filtering on CPU. C36 means that it was limited to 36, but potentially less depending on the machine number of threads.

C12 for T3d was what I ended up with.

Computer T3d=C36 T3d=4
Time Time
72 724 641
36 1035 988
16 1303 1323

This tests the number of pytorch threads during 3d volume (main process) filtering, as well the batch size on CPU. For this, we used scipy for the 2d filtering. C36(C12) means that it was limited to 36(12), but potentially less depending on the machine number of threads. B is the batch size.

B defaults to 4 on CPU now.

Computer T3d=C12;B=4 T3d=C36;B=4 T3d=4;B=8 T3d=C12;B=8 T3d=C36;B=8 T3d=C36;B=20
Time Time Time Time Time Time
72 583 614 328 326 410
36 655 738 579 434 437 411
16 743 743 656 657 653

This compares B on GPU. We default to 1 on CUDA, because we can't know how many planes will fit in GPU memory.

Computer CUDA CUDA CUDA
Time Batch Time Batch Time Batch
72 (RTX3090) 238 1 218 4 216 8
36 (GTX1060) 583 1
16 (RTX2080) 317 1 312 2 305 5
72 (RTX6000) 240 1 203 4 176 42

matham and others added 3 commits June 12, 2024 15:47
Add initial cuda detection.

Convert plane filters to torch.

Add batch sizes.

Add batching.

Fix threshold values for float vs uint32.

Save alternate way of tiling.

Move all filters to pytorch.

Don't raise exception, return instead.

Conv filter on the fastest dim.

Turn ON infrence mode.

Refactor out the detection settings.

Handle input data conversions.

Add Wrappers for threads.

Add support for single plane batches.

Add dtype for detection.

To pass soma value without knowing dtype, set largest dtype.

Switch splitting to torch.

Return the filtered plains...

Remove multiprocessing.

Use correct z axis.

Use as large a batch as possible for splitting.

Add back multiprocessing for splitting.

Ensure volume is valid.

Make tiling optional.

Cleanup and docs.

Limit cores to most prevent contension.

Fix division by zero.

We only need one version of ball filter now.

Parallelize again 2d filtering for CPU.

Add kornia as dependency.

Use better multiprocessing for cpu plane filters.

Allow using scipy for plane filter.

Pass buffers ahead to process to prevent torch queue buffer issues.

Queue must come from same ctx as process.

Reduce max cores.

Don't pin memory on cpu.

Fix tests and add more.

Add more tests and fixes.

More tests and include int inputs.

Fix coverage multiprocessing.

More tests.

Add more tests.

More docs/tests.

Use modules for 2d filters so we can use reflect padding and add 2d filter tests.

With correct thresholds, detection dtype can be input dtype size.

Add test for tiles generated during 2d filtering.

Add testing for 3d filtering.

Clean up filtering to detection conversion.

Brainmapper passes in str not float.

Fix numba cast warning.

Pad 2d filter data enough to not require padding for each filter.

Add threading test/docs.

We must process a plane at a time for parity with previous algo.

Add test comparing generated cells.

Ensure full parity with scipy for 2dfiltering.

Fix numba warning.

Don't count values at threshold in planes - brings 2d filtering to scipy parity.

Include 3d filter top/bottom padded planes in progress bar.

Move more into jit scipt.

Get test data from pooch and use it in benchmarks.

Add test for splitting underflow.
)

# input data in range (-500, 500)
data = ((np.random.random((6, 50, 50)) - 0.5) * 1000).astype(np.float32)

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low test

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
# check that filter padding works correctly for different sized inputs -
# even if the input is smaller than filter sizes
settings = DetectionSettings(plane_original_np_dtype=np.uint16)
data = np.random.randint(0, 500, size=(1, *plane_size))

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low test

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
@alessandrofelder alessandrofelder self-requested a review June 13, 2024 09:13
@alessandrofelder
Copy link
Member

Thanks a lot @matham - I'll have an in-depth look next week!

@matham
Copy link
Contributor Author

matham commented Jun 19, 2024

The reason for the tests failing is that it times out because it takes too long (60 min). As mentioned in the todos, there are a lot of tests and I can't imagine they are fast on CI. So perhaps we will need to figure out what to do about it - perhaps marking the slower tests optional and only running them nightly or some other approach!?

Also, getting the test data from g-node is quite slow. Surprisingly so given it's supposed to host datasets. I assumed it'd be faster a bit after upload once it had a chance to cache or whatever, and especially on github, but it's just as slow as locally I think. Although there's no clear indication how long the download takes as pytest eats all the logs. On my computer it was well under 1MB/s (1MBits/s?). But if the tests weren't so slow themselves, it wouldn't be an issue due to the cache. But still it's a bit concerning!?

@alessandrofelder
Copy link
Member

Yea, not sure what to do about this, but good point - I'll have a think 🤔 suggestions welcome!
I have started reviewing (looking good!) but will likely need a bit more time, sorry.
Lots of things going on at the moment :)

@matham
Copy link
Contributor Author

matham commented Jun 19, 2024

No worries at all!

I responded about g-node here: #439.

About skipping slow tests. I think we can classify slow (and not as essential) and normal tests using a mark. And then pytest lets you exclude these tests by mark. And then by default it'd run all tests. But on CI we disable these tests, except for nightly runs, which runs once a day. And we can increase the timeout time for them (to 6 hours...).

Github also has larger runners which may (will probably) be faster. But they are not free...

@adamltyson
Copy link
Member

I'd like to make sure we run all tests if we can. A test we skip isn't really a test at all.

I mentioned on the other issue about reducing test data sizes and caching to speed things up, but can we also parallelise the tests across runners?

I'd like to avoid using larger runners if we can. It may be cheap now, but may not scale well.

@matham
Copy link
Contributor Author

matham commented Jul 3, 2024

It now uses a smaller brain for the bright brain. With cache it seems to take around 12 min and without cache upto 30min. Perhaps increasing to timeout to 2 hours to account for this rate situation when the cache changes is a good idea.

About parallelization across runners, I'm not sure how to do it in pytest easily. There's pytest-xdist, but it only works across cores on a single machine. But perhaps 12min is fast enough?

@adamltyson
Copy link
Member

It now uses a smaller brain for the bright brain. With cache it seems to take around 12 min and without cache up to 30min.

I think that's ok. If there's anyway we can speed it up, then that's always welcome, but I think 12 mins is doable.

Perhaps increasing to timeout to 2 hours to account for this rate situation when the cache changes is a good idea.

I think that's a good idea.

Copy link
Member

@alessandrofelder alessandrofelder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matham THANK YOU for this (and thank you for your patience waiting on this review)

The documentation and the tests are particularly nice, and the code is a lot more organised now!

I have two very minor comments that need consideration, and tiny suggestions for the docs. Mostly a lot of me asking questions to double-check my understanding.

Then I think this is ready to go (we might wait a little bit to ensure that cellfinder 1.3.x is stable enough before we release this though - just in case we need to fallback... I doubt it though)

cellfinder/core/tools/threading.py Outdated Show resolved Hide resolved
cellfinder/core/tools/threading.py Show resolved Hide resolved
thread.get_msg_from_thread()
assert type(exc_info.value.__cause__) is ExceptionTest
assert exc_info.value.__cause__.args[0] == pass_self[1] + arg_msg
thread.join()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, the calls at the ends of tests to thread.join are to ensure thread has terminated before we move to the next test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Just so that we don't leave resources hanging - even though it'd eventually terminate either way.

Also, my ulterior motive is that it's another test. Because if the code is wrong and the thread never exits it may hang here. While we wouldn't be able to tell what is wrong, at least we'd know something is wrong. Although hopefully of course that should never happen...

thread = cls(target=send_multiple_msgs, pass_self=True)
thread.start()

thread.clear_remaining()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably we could assert something here to check that the thread is actually out of messages?

tests/core/test_integration/test_detection.py Outdated Show resolved Hide resolved
Comment on lines 10 to 13
filtered_planes: torch.Tensor,
clipping_value: float,
flip: bool,
upscale: bool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we document these?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also wondering whether this signature would be clearer?

Suggested change
filtered_planes: torch.Tensor,
clipping_value: float,
flip: bool,
upscale: bool,
filtered_planes: torch.Tensor,
max_value: float = 1.0,
flip: bool,

The fact that we scale to [0,1] by default would be clearer from the function signature, and there would be one fewer argument.

cellfinder/core/detect/filters/plane/classical_filter.py Outdated Show resolved Hide resolved
# We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
# each Z,X,Y voxel over which we compute the median
# Zero padding is ok here
filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs for conv2d flag that this can be non-deterministic - are we happy to take that risk?

I am leaning towards "yes" because presumably the non-deterministic errors would be in the machine-precision range, but might be worth discussing/digging into deeper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good observation. I think if there's a section in user facing docs that documents how to make stuff deterministic (seed etc) then we should add a line with all the stuff listed in pytorch docs on how to make it deterministic and about potential performance issues. Or it may be good to add such a section in the docs.

But I think by default it should be ok to have the non-determinism. Especially on GPU where I think there would be significant performance penalties. The reason is that overall we have large-ish blobs and we want to know, is the overall intensity/size of the blob large enough to be a cell (binary yes/no). We don't quite care about the exact intensity values of the blob, just its overall statistics. Especially when there are lots of cells with varying intensity present. At least in the case I'm thinking off. So slight variation in computation due to parallelism should not effect the result much.

The place I can see making a difference is for a very noise image where cells and background are at similar intensity and cells are tiny - then small intensity differences may result in different labeling. Assuming this non-determinism is that significant.

What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep I'll make a note to link to the pytorch docs for cellfinder users. Happy with this explanation otherwise.

cellfinder/core/detect/filters/plane/plane_filter.py Outdated Show resolved Hide resolved
corner_intensity = torch.mean(corner64, dim=1).type(planes.dtype)
# for parity with past when we used np.std, which defaults to ddof=0
corner_sd = torch.std(corner64, dim=1, correction=0).type(planes.dtype)
# add 1 to ensure not 0, as disables
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it disable? Zero-division error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment preceded me. Maybe @adamltyson knows.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting this threshold to 0 prevents the tile from being processed, so we add 1 here. TBH not sure if we need this though.

@matham
Copy link
Contributor Author

matham commented Aug 13, 2024

Thank you for taking the time reviewing the giant PR 🙂 The other requested code changes I'll do ASAP.

@matham
Copy link
Contributor Author

matham commented Aug 16, 2024

The reason why it fails on Windows on the CI is due to pytorch/pytorch#131958.

It's an issue with pytorch 2.4.0. I think we may need to pin pytorch (max) versions so this kind of thing doesn't happen. And bump pytorch versions periodically (e.g. when new python versions are added yearly, e.g. 3.13) after testing. This should be necessary both for detection and classification parts as both use pytorch.

Although hopefully such high critical bugs should be rare in pytorch...

@adamltyson
Copy link
Member

I'm hesitant to routinely pin max versions. Across the whole BrainGlobe ecosystem (and other tools likely to be installed at the same time like napari plugins), pinning versions causes lots of problems down the line.

Of course we should avoid pytorch 2.4.0 for now. I've raised #451 to track updating this once the pytorch fix filters through to the latest release.

@alessandrofelder
Copy link
Member

alessandrofelder commented Aug 23, 2024

I think I am happy with this PR now.

Looks like pytorch is on it about fixing the bug that forces us to pin a max version.
I suggest we wait for that version to be fully released, pin to latest pytorch and then merge and release this.

I'll make sure the user docs are up-to-date in the meantime.
(Dev docs updates may take a bit longer)


# todo: first z should account for middle plane not being start plane
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this TODO be addressed as part of this PR, @matham ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this is taken care of by the settings' start_plane and we can just remove this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, it was already fixed but I left it there to make sure to test it. The tests already check that now.

with torch.inference_mode(True):
return func(*args, **kwargs)

return inner_function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(You can tell I am writing docs and noticing more things as I go along 😂 )
If I read the pytorch docs correctly, we could use torch.inference_mode as a decorator directly instead of redefining it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I remember I had issues using it as a decorator, where it didn't seem to work. It's possible something else was at fault though.

@imagesc-bot
Copy link

This pull request has been mentioned on Image.sc Forum. There might be relevant details there:

https://forum.image.sc/t/python-programs-to-detect-fluorescent-cells-in-3d-images-of-mouse-brain/98418/15

@alessandrofelder
Copy link
Member

Hey @matham I am tempted to merge this and have written issues #464 and brainglobe/brainglobe.github.io#254 for outstanding items. Let me know if you have any objections.

I am at a conference next week, but hope to continue drafting user and dev docs week of 28th October

@matham
Copy link
Contributor Author

matham commented Oct 18, 2024

I don't have any objections.

I am mainly worried about people trying to use it without the full API being in place. But I guess people who use main without a release can be expected to have to figure it out. Because there's also some risk of main diverging if we don't merge. So I'm ok with it either way!

@alessandrofelder alessandrofelder merged commit 4691987 into brainglobe:main Oct 31, 2024
16 of 17 checks passed
@matham matham deleted the pytorch branch October 31, 2024 21:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Structure splitting unsigned underflow
4 participants