Cuda OOM error in gather_topk_anchors #1499
Comments
Currently, lowering the batch size seems to be the only quick solution. As you correctly pointed out, OOM happens when one image has many GT boxes, which, because of the batched anchor-assignment procedure, forces unnecessary padding onto the other samples in the batch. The solution could be to disable this batched assignment and do it on a per-sample basis. This would be somewhat slower but should fix the OOM issue. I imagine this could be a loss argument that one enables if needed, or a try/catch inside the loss that falls back to per-sample processing when OOM happens (see the sketch below). I'm not sure I can give you an estimate of when we may get resources to work on this improvement. If someone wants to contribute, we would be happy to provide guidance here.
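A minimal sketch of the try/catch fallback idea, assuming hypothetical `batched_assign` and `per_sample_assign` callables that stand in for the real anchor-assignment routines (this is not super-gradients' API, just the pattern):

```python
import torch

def assign_with_oom_fallback(batched_assign, per_sample_assign, *batch_tensors):
    """batched_assign / per_sample_assign are placeholders for the real
    anchor-assignment routines; only the fallback pattern is shown."""
    try:
        return batched_assign(*batch_tensors)
    except torch.cuda.OutOfMemoryError:  # PyTorch >= 1.13; otherwise catch RuntimeError and inspect the message
        torch.cuda.empty_cache()  # release the partially allocated buffers
        # Per-sample processing avoids padding every sample to the largest
        # ground-truth count in the batch, which is what blows up memory.
        outputs = [
            per_sample_assign(*(t[i:i + 1] for t in batch_tensors))
            for i in range(batch_tensors[0].shape[0])
        ]
        # Re-assemble the per-sample results into batched tensors.
        return tuple(torch.cat(parts, dim=0) for parts in zip(*outputs))
```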
Would a quick and dirty solution be to add a try/catch that lets you skip such batches? I realise this would mean essentially never training on images with lots of GTs, but that might be acceptable. Alternatively, you could pre-filter your dataset to remove such images, but then you would need to know up front the number of GTs at which this problem starts occurring, which isn't obvious (a rough way to check is sketched below).
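For the pre-filtering route, a rough sketch of how one might inspect the GT-count distribution up front, assuming a COCO-style annotation file (the path is illustrative):

```python
import json
from collections import Counter

# Count ground-truth boxes per image to see where the heavy tail starts
# before choosing a filtering threshold.
with open("annotations/instances_train.json") as f:  # assumed path
    coco = json.load(f)

boxes_per_image = Counter(ann["image_id"] for ann in coco["annotations"])
counts = sorted(boxes_per_image.values(), reverse=True)
print("max GT boxes in one image:", counts[0])
print("approx. 99th percentile:", counts[len(counts) // 100])
```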
Starting from the 3.4.0 release we now have this feature: #1582.
💡 Your Question
I am trying to train a YOLO-NAS model and sometimes get out-of-memory errors at random points mid-way through training, at the line
`is_in_topk = torch.nn.functional.one_hot(topk_idxs, num_anchors).sum(dim=-2).type_as(metrics)`
in the function `gather_topk_anchors`.
This seems to happen when a batch happens to contain a very large number of ground-truth objects. I can avoid it by lowering the batch size, but doing so means I'm not taking full advantage of my GPU memory and lose a fair bit of performance.
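For a rough sense of scale (the sizes below are assumptions, not measurements): `one_hot` materialises a `(batch, max_num_gt, top_k, num_anchors)` int64 tensor before the `sum` reduces it, and `max_num_gt` is padded to the largest GT count in the batch, so a single crowded image inflates every sample:

```python
# Back-of-envelope estimate only; all sizes here are assumed.
batch_size, max_num_gt, top_k, num_anchors = 32, 300, 13, 8400
elems = batch_size * max_num_gt * top_k * num_anchors   # ~1.05e9 elements
print(f"{elems * 8 / 1024**3:.1f} GiB")                 # int64 one_hot output: ~7.8 GiB
```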
Any ideas how I can alleviate this?
Versions
No response