Move detection (2d/3d filtering, structure splitting) to PyTorch #440
Conversation
Add initial cuda detection. Convert plane filters to torch. Add batch sizes. Add batching. Fix threshold values for float vs uint32. Save alternate way of tiling. Move all filters to pytorch. Don't raise exception, return instead. Conv filter on the fastest dim. Turn ON inference mode. Refactor out the detection settings. Handle input data conversions. Add Wrappers for threads. Add support for single plane batches. Add dtype for detection. To pass soma value without knowing dtype, set largest dtype. Switch splitting to torch. Return the filtered planes... Remove multiprocessing. Use correct z axis. Use as large a batch as possible for splitting. Add back multiprocessing for splitting. Ensure volume is valid. Make tiling optional. Cleanup and docs. Limit cores to mostly prevent contention. Fix division by zero. We only need one version of ball filter now. Parallelize 2d filtering again for CPU. Add kornia as dependency. Use better multiprocessing for cpu plane filters. Allow using scipy for plane filter. Pass buffers ahead to process to prevent torch queue buffer issues. Queue must come from same ctx as process. Reduce max cores. Don't pin memory on cpu. Fix tests and add more. Add more tests and fixes. More tests and include int inputs. Fix coverage multiprocessing. More tests. Add more tests. More docs/tests. Use modules for 2d filters so we can use reflect padding and add 2d filter tests. With correct thresholds, detection dtype can be input dtype size. Add test for tiles generated during 2d filtering. Add testing for 3d filtering. Clean up filtering to detection conversion. Brainmapper passes in str not float. Fix numba cast warning. Pad 2d filter data enough to not require padding for each filter. Add threading test/docs. We must process a plane at a time for parity with previous algo. Add test comparing generated cells. Ensure full parity with scipy for 2d filtering. Fix numba warning. Don't count values at threshold in planes - brings 2d filtering to scipy parity. Include 3d filter top/bottom padded planes in progress bar. Move more into jit script. Get test data from pooch and use it in benchmarks. Add test for splitting underflow.
# input data in range (-500, 500)
data = ((np.random.random((6, 50, 50)) - 0.5) * 1000).astype(np.float32)

SonarCloud code scanning notice: numpy.random.Generator should be preferred to numpy.random.RandomState (low, test code).
# check that filter padding works correctly for different sized inputs -
# even if the input is smaller than filter sizes
settings = DetectionSettings(plane_original_np_dtype=np.uint16)
data = np.random.randint(0, 500, size=(1, *plane_size))

SonarCloud code scanning notice: numpy.random.Generator should be preferred to numpy.random.RandomState (low, test code).
Thanks a lot @matham - I'll have an in-depth look next week!
The reason the tests are failing is that they time out because they take too long (60 min). As mentioned in the todos, there are a lot of tests and I can't imagine they are fast on CI. So perhaps we will need to figure out what to do about it - perhaps marking the slower tests optional and only running them nightly, or some other approach!? Also, getting the test data from g-node is quite slow - surprisingly so, given it's supposed to host datasets. I assumed it'd be a bit faster some time after upload, once it had a chance to cache or whatever, and especially on GitHub, but I think it's just as slow as locally. Although there's no clear indication of how long the download takes, as pytest eats all the logs. On my computer it was well under 1 MB/s (1 Mbit/s?). If the tests weren't so slow themselves, it wouldn't be an issue thanks to the cache. But it's still a bit concerning!?
Yea, not sure what to do about this, but good point - I'll have a think 🤔 suggestions welcome!
No worries at all! I responded about g-node here: #439. About skipping slow tests: I think we can classify tests as slow (and not as essential) or normal using a mark, and pytest lets you exclude tests by mark. By default it'd run all tests, but on CI we'd disable the slow ones, except for nightly runs, which run once a day. And we can increase the timeout for them (to 6 hours...). GitHub also has larger runners which may (probably will) be faster, but they are not free...
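A minimal sketch of the mark-based approach described above (the mark name `slow` and the test name are illustrative, not from this PR):

```python
import pytest

# The "slow" marker would normally be registered in pyproject.toml or
# pytest.ini so that pytest doesn't warn about an unknown mark.
@pytest.mark.slow
def test_full_brain_detection():
    # long-running end-to-end detection test
    ...

# Regular CI could then run:  pytest -m "not slow"
# while a nightly job runs:   pytest
```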
I'd like to make sure we run all tests if we can. A test we skip isn't really a test at all. I mentioned on the other issue about reducing test data sizes and caching to speed things up, but can we also parallelise the tests across runners? I'd like to avoid using larger runners if we can. It may be cheap now, but may not scale well.
It now uses a smaller brain for the bright brain. With the cache it seems to take around 12 min, and without the cache up to 30 min. Perhaps increasing the timeout to 2 hours to account for this rare situation when the cache changes is a good idea. About parallelization across runners, I'm not sure how to do it easily in pytest. There's pytest-xdist, but it only works across cores on a single machine. But perhaps 12 min is fast enough?
I think that's ok. If there's any way we can speed it up, then that's always welcome, but I think 12 mins is doable.
I think that's a good idea.
@matham THANK YOU for this (and thank you for your patience waiting on this review)
The documentation and the tests are particularly nice, and the code is a lot more organised now!
I have two very minor comments that need consideration, and tiny suggestions for the docs. Mostly a lot of me asking questions to double-check my understanding.
Then I think this is ready to go (we might wait a little bit to ensure that cellfinder 1.3.x is stable enough before we release this though - just in case we need to fall back... I doubt it though)
thread.get_msg_from_thread()
assert type(exc_info.value.__cause__) is ExceptionTest
assert exc_info.value.__cause__.args[0] == pass_self[1] + arg_msg
thread.join()
For my understanding, the calls at the ends of tests to `thread.join` are to ensure `thread` has terminated before we move to the next test?
Correct. Just so that we don't leave resources hanging - even though it'd eventually terminate either way.
Also, my ulterior motive is that it's another test. Because if the code is wrong and the thread never exits it may hang here. While we wouldn't be able to tell what is wrong, at least we'd know something is wrong. Although hopefully of course that should never happen...
thread = cls(target=send_multiple_msgs, pass_self=True)
thread.start()

thread.clear_remaining()
Presumably we could assert something here to check that the thread is actually out of messages?
filtered_planes: torch.Tensor,
clipping_value: float,
flip: bool,
upscale: bool,
Could we document these?
I am also wondering whether this signature would be clearer?
Current signature:

    filtered_planes: torch.Tensor,
    clipping_value: float,
    flip: bool,
    upscale: bool,

Suggested signature:

    filtered_planes: torch.Tensor,
    max_value: float = 1.0,
    flip: bool,
The fact that we scale to [0,1] by default would be clearer from the function signature, and there would be one fewer argument.
# We go from ZCYX -> ZCYX, C=1 to C=9 with C containing the elements around
# each Z,X,Y voxel over which we compute the median
# Zero padding is ok here
filtered_planes = F.conv2d(filtered_planes, med_kernel, padding="same")
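For context, a minimal standalone sketch of this unfold-into-channels trick (shapes and names are illustrative, not the PR's actual code):

```python
import torch
import torch.nn.functional as F

# Build a 9-channel kernel where each output channel copies one element of
# the 3x3 neighbourhood of every pixel.
med_kernel = torch.zeros(9, 1, 3, 3)
for i in range(9):
    med_kernel[i, 0, i // 3, i % 3] = 1.0

planes = torch.rand(4, 1, 64, 64)  # ZCYX with C=1
# ZCYX -> ZCYX with C=9: each channel now holds one neighbour of each voxel.
neighbours = F.conv2d(planes, med_kernel, padding="same")
# Taking the median over the channel dimension gives a 3x3 median filter.
median_filtered = neighbours.median(dim=1, keepdim=True).values
```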
The docs for `conv2d` flag that this can be non-deterministic - are we happy to take that risk?
I am leaning towards "yes" because presumably the non-deterministic errors would be in the machine-precision range, but might be worth discussing/digging into deeper?
Good observation. I think if there's a section in the user-facing docs on how to make things deterministic (seed etc.), we should add a line listing everything the pytorch docs mention for making computation deterministic, and note the potential performance issues. Or it may be good to add such a section to the docs.
But I think by default it should be ok to have the non-determinism, especially on GPU where I think there would be significant performance penalties. The reason is that overall we have large-ish blobs and we want to know whether the overall intensity/size of the blob is large enough to be a cell (binary yes/no). We don't quite care about the exact intensity values of the blob, just its overall statistics - especially when there are lots of cells with varying intensity present. At least in the case I'm thinking of. So slight variation in computation due to parallelism should not affect the result much.
The place I can see it making a difference is for a very noisy image where cells and background are at similar intensity and cells are tiny - then small intensity differences may result in different labeling. Assuming this non-determinism is that significant.
What do you think?
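For reference, the kinds of switches the PyTorch randomness docs list (a hedged sketch for user docs, not code from this PR; see https://pytorch.org/docs/stable/notes/randomness.html):

```python
import torch

torch.manual_seed(0)                      # seed the RNG
torch.use_deterministic_algorithms(True)  # error on non-deterministic ops
torch.backends.cudnn.benchmark = False    # disable non-deterministic autotuning
# On CUDA, cuBLAS may additionally need the environment variable
#   CUBLAS_WORKSPACE_CONFIG=:4096:8
# and all of this can carry a noticeable performance cost, especially on GPU.
```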
Yep, I'll make a note to link to the pytorch docs for cellfinder users. Happy with this explanation otherwise.
corner_intensity = torch.mean(corner64, dim=1).type(planes.dtype)
# for parity with past when we used np.std, which defaults to ddof=0
corner_sd = torch.std(corner64, dim=1, correction=0).type(planes.dtype)
# add 1 to ensure not 0, as disables
What does it disable? Zero-division error?
This comment preceded me. Maybe @adamltyson knows.
Setting this threshold to 0 prevents the tile from being processed, so we add 1 here. TBH not sure if we need this though.
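As an aside, the ddof parity mentioned in the snippet above can be checked in isolation - a tiny illustrative check, not part of the PR:

```python
import numpy as np
import torch

x = np.random.random(1000)
# np.std defaults to ddof=0, while torch.std defaults to the unbiased
# estimator (correction=1), so correction=0 is needed for parity.
np_sd = np.std(x)
torch_sd = torch.std(torch.from_numpy(x), correction=0).item()
assert np.isclose(np_sd, torch_sd)
```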
Thank you for taking the time to review the giant PR 🙂 I'll do the other requested code changes ASAP.
The reason it fails on Windows on CI is pytorch/pytorch#131958 - an issue with pytorch 2.4.0. I think we may need to pin pytorch (max) versions so this kind of thing doesn't happen, and bump the pytorch version periodically (e.g. when new python versions are added yearly, e.g. 3.13) after testing. This would be needed for both the detection and classification parts, as both use pytorch. Although hopefully such highly critical bugs are rare in pytorch...
I'm hesitant to routinely pin max versions. Across the whole BrainGlobe ecosystem (and other tools likely to be installed at the same time, like napari plugins), pinning versions causes lots of problems down the line. Of course we should avoid pytorch 2.4.0 for now. I've raised #451 to track updating this once the pytorch fix filters through to the latest release.
I think I am happy with this PR now. I'll make sure the user docs are up-to-date in the meantime.
# todo: first z should account for middle plane not being start plane
Should this TODO be addressed as part of this PR, @matham ?
IIUC this is taken care of by the settings' `start_plane` and we can just remove this comment?
Correct, it was already fixed but I left it there to make sure to test it. The tests already check that now.
with torch.inference_mode(True):
    return func(*args, **kwargs)

return inner_function
(You can tell I am writing docs and noticing more things as I go along 😂 )
If I read the pytorch docs correctly, we could use `torch.inference_mode` as a decorator directly instead of redefining it here?
Yes, but I remember I had issues using it as a decorator, where it didn't seem to work. It's possible something else was at fault though.
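For reference, the decorator form being discussed looks like this (a minimal sketch; the function name is illustrative):

```python
import torch

# torch.inference_mode can be applied directly as a decorator, which is
# equivalent to wrapping the whole body in `with torch.inference_mode():`.
@torch.inference_mode()
def run_filters(planes: torch.Tensor) -> torch.Tensor:
    return planes * 2.0
```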
This pull request has been mentioned on Image.sc Forum. There might be relevant details there:
Hey @matham, I am tempted to merge this and have written issues #464 and brainglobe/brainglobe.github.io#254 for the outstanding items. Let me know if you have any objections. I am at a conference next week, but hope to continue drafting user and dev docs the week of 28th October.
I don't have any objections. I am mainly worried about people trying to use it without the full API being in place. But I guess people who use
Description
What is this PR
Why is this PR needed?
When I was testing out cell detection, it was too slow for our use case, where ideally we play with detection parameters until all cells are detected and then run it through classification. I was used to Imaris, where cell detection took about a minute (probably simple thresholding?). So I was looking to speed up detection.
What does this PR do?
This PR moves the 2d and 3d filtering used in detection and structure splitting to PyTorch. The benefit is that it's a lot faster on GPU, and a bit faster on CPU. While these are significant changes, I tested with a couple of brains, including one that is very bright and noisy, and achieved parity with main. Those tests are included in the PR - there is generally very thorough testing with close to 100% coverage.

While there are a lot of changes, there are a few core sources of change that propagate: unifying options that were hardcoded or spread across various places, translating the filtering code to pytorch, and changing the threading/sub-processing and data loading to be better suited to pytorch. If you follow these core concerns, the changes should make sense.
I realize it's a giant PR, but I'll help with the review in whatever way you need - I couldn't split it up because it all seems to go together, especially to get the proper performance from pytorch.
Performance
First, some performance numbers. I measured this on a 1555x3222x3848 brain, on CPU and GPU. Each row is a computer identified by its number of CPU cores (e.g. 72); RTX6000 is the GPU on that computer. The columns are: the current main branch (before this fix #435); CPU, meaning pytorch on CPU (batch size of 4); CPU+scipy, meaning scipy (see later) for the 2d filtering and pytorch on CPU for the 3d filtering (batch=4); and CUDA, meaning pytorch on the GPU, with different batch sizes. Time is in seconds and includes just detection, not structure splitting. * marks the default configuration out of the box with these changes.

I also looked at RAM usage: pytorch with a batch size of 4 used about the same as main, ~7GB.
These are the computer specs for the three desktop systems I tested on:
References
None.
How has this PR been tested?
There's thorough testing of the new code (close to 100% coverage) - you have to run it with NUMBA_DISABLE_JIT=1 PYTORCH_JIT=0 pytest though. I also tested it on two brains, which are included in the pooch test files, checking the image output of the 2d and 3d filters and the detected cells. One brain is quite bright, which makes it a good test.

Is this a breaking change?
In terms of the API, it adds new parameters to main, with default values that work out of the box. I have not added support for passing parameters through the CLI; ideally that would happen in a separate PR. The default values for the parameters used should remain the same - they were simply extracted into a unified settings class.

In terms of the quality of the detection, there's some difference, as explained below. PyTorch needs the data as floats. We use float32 if the input data fits, otherwise float64. If we use float64, there's 100% parity of output with main, but float64 is much slower. Float32 is not complete parity, but for our 1555x3222x3848 very noisy brain with 252k detected cells, only 7.2k were not an exact match. For clean data I expect it to be the same or close enough to be negligible.

Does this PR require an update to the documentation?
We need to decide how (and which) new and previously hardcoded parameters to pass through the CLI, and the docs should be updated for that - including how to properly install PyTorch for GPU support (mostly just linking to https://pytorch.org/get-started/locally/).
Checklist:
Changes
Here are the changes, at a higher level
Settings

- The detection settings were consolidated into setup_filters.py. That file has extensive config options and docs.
- Some default values differ from main for splitting; they originate from the default values of the splitting functions, because the same parameter defaulted to different values for the original filtering and for structure splitting.
- The new parameters are not exposed in main, so they cannot be passed in from the CLI. Part of it is that I'm not quite sure what should be passable from the CLI and how they would be named. E.g. classification probably has a batch_size, and now detection also has one, so we need multiple.

Data types

- In main the input data was always converted to uint32 and float64 for some parts of the filtering. The settings can now handle all input data sizes.
- After thresholding, results can differ by one, but at most two, pixels. IMO this tradeoff is worth it. And we test for this parity in the pytests.

Translated to pytorch

CPU/GPU/SciPy

Threading

- Thread and process wrappers were added (in tools/threading.py) to support all this seamlessly.
- Conversion to Cell happens, like before, in sub-processes.

Tests

Results comparison

Here is how each of the code paths does on cell detection.

First, pytorch on CPU vs main (that includes the fix from #435). The differences are due to using float32 instead of float64, as explained above.

This compares pytorch on CPU vs pytorch on CPU + scipy, as explained above. It's quite similar, and has to do with scipy using float64 internally in various places while we use float32 for these uint16 data.

pytorch on CPU vs pytorch on GPU - the detected cells were identical, so no graph.

Todo

Some of this can hopefully be addressed in another PR.

Docs and parameters

- For split_ball_xy_size: why were these specific values originally hardcoded? Should they be in um like the other params? I left it in voxels and not um, unlike the other parameters, because that's how they were hardcoded.
- The parameters are documented in setup_filters.py, in main both in root and just detect, and they also need to be shown in napari. So surely we don't want to copy-paste the docs 4 times. Also, even in napari they are vague and not everything is even offered as an option.
- Whether to use float32 or float64 for filtering: by default it'll use the smallest that fits the input data. But perhaps we want to allow the user to specify this if they want to pay the performance cost of float64 (I'm not sure why they would, though).

Tests

External issues

- One issue involves _set_soma.
- The reason for the scipy option (use_scipy): filtering on CPU is slower in pytorch than in scipy - pytorch/pytorch#126115. I suspect it's only an issue with older computers (i.e. more than 2 years old...) as pytorch CPU may not be as optimized for them.

Some minutiae

Pytorch does its own multithreading when running on CPU. It supports intra-threading (the number of threads used for e.g. convolution). For 2d filtering, increasing the number of threads per plane above about 4 doesn't help. Which means, for 2d filtering, we still benefit from multiprocess parallelization. So, on CPU, a batch is split up among parallel processes.

This leads to us being able to tune the following parameters:

- B: the batch size. Larger batches utilize the GPU more, but too large won't fit or will slow down (for shared memory configurations).
- The number of sub-processes to use for 2d filtering. E.g. if it's 4, we do 2d filtering for 4 planes at once, each in its own sub-process.
- T2d: for 2d filtering, each plane in the batch is processed in its own sub-process. We can set the number of threads in this sub-process.
- T3d: for 3d filtering and other pytorch work in the main process we can set the number of threads.

Based on the following tables, other tuning, and the pytorch docs I arrived at the following values:

- B: only B is settable as an input parameter. By default, if run on CPU, B is 4; on cuda it's 1.
- T2d: it's hardcoded to at most 4.
- T3d: it's set to 12, minus the threads used for data loading/cell detection.

The following tables show the duration in seconds for cell detection of the 1555x3222x3848 brain. The computer is identified by its number of cores. There have been changes to the code, so the numbers are not directly comparable to the table above.

This tests the number of pytorch threads during 2d filtering in pure pytorch on CPU. C36 means it was limited to 36, but potentially less depending on the machine's number of threads. 4 for T2d was the best option.

This tests the number of pytorch threads during 3d volume (main process) filtering on CPU. C36 means it was limited to 36, but potentially less depending on the machine's number of threads. C12 for T3d was what I ended up with.

This tests the number of pytorch threads during 3d volume (main process) filtering, as well as the batch size, on CPU. For this, we used scipy for the 2d filtering. C36(C12) means it was limited to 36(12), but potentially less depending on the machine's number of threads. B is the batch size. B defaults to 4 on CPU now.

This compares B on GPU. We default to 1 on CUDA, because we can't know how many planes will fit in GPU memory.
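As a concrete illustration of the thread knobs discussed above, this is roughly how intra-op thread counts can be limited in PyTorch (a generic sketch, not the PR's actual code; the value 4 simply mirrors the T2d default mentioned above):

```python
import torch

# Limit the number of intra-op threads PyTorch uses for ops like convolution
# (the per-process thread count referred to above, e.g. T2d or T3d).
torch.set_num_threads(4)

# Inter-op parallelism (between independent ops) is controlled separately and
# must be set before any parallel work has started.
torch.set_num_interop_threads(1)

print(torch.get_num_threads(), torch.get_num_interop_threads())
```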