[Feature request] Make IoU computation more memory efficient #18
Hi, there is a difference in how we interpret the IMS_PER_BATCH parameter in our codebase: in Detectron it is per GPU, while in our implementation it is a global batch size that gets divided over the number of GPUs you are using. So in your case, you are probably training with a batch size of 16 on a single GPU. To fix your memory issues, you'll need to adapt IMS_PER_BATCH, as well as the number of iterations / schedule / learning rate, according to the Detectron scaling rules. The reason we changed the meaning of IMS_PER_BATCH compared to Detectron was indeed to simplify experimentation: all the parameters I mentioned are fixed given a global batch size, but they need to be adjusted whenever the global batch size changes, which previously happened every time you changed the number of GPUs. Let me know if this is clear.
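For concreteness, here is a minimal sketch of that linear scaling rule; the helper name and the 1x defaults used in the example (16 images, LR 0.02, 90k iterations) are illustrative assumptions, not taken verbatim from the repo's configs:

```python
# Hypothetical helper illustrating the scaling rule described above: shrinking
# the global batch size by a factor k means dividing the learning rate by k
# and multiplying the iteration schedule by k.
def scale_schedule(base_lr, max_iter, steps, base_batch=16, new_batch=2):
    k = base_batch / new_batch
    return base_lr / k, int(max_iter * k), tuple(int(s * k) for s in steps)

# Example: adapting an assumed 16-image 1x schedule to 2 images per batch
print(scale_schedule(0.02, 90_000, (60_000, 80_000)))
# -> (0.0025, 720000, (480000, 640000))
```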
But that makes me think that it is a good idea to add a note about this in the README. Would you be willing to do it?
@steve-goley I've improved the README in #35 with more details on how to perform experiments with a single GPU.
Thanks @fmassa for the follow-up, I'll keep investigating this. Thanks for the clarification on the IMS_PER_BATCH parameter, I was also confused about it. I believe I changed it to 1 but still ran into the memory error. I'm able to train for a while with stable memory usage, only to hit an OOM error hundreds or thousands of iterations in. I'm trying (or am going to try) a couple of workarounds. For example, moving the IoU computation to the CPU:

```python
iou = inter.cpu() / (area1[:, None].cpu() + area2.cpu() - inter.cpu())
return iou.cuda()
```

which allowed me to get to 4000+ iterations. However, I eventually errored out here.
If the other changes don't work, I'll try moving more of the boxlist_ops to the CPU. My problem set has some cases of extremely dense GT boxes, >500 in one image. My hypothesis is that this is causing the issue. Does that make sense?
Our implementation performs bounding-box assignment on the GPU, so having an extremely large number of GT boxes might be one of the reasons. I did some quick computations: for a batch size of 1, with default parameters for the FPN, there are 242991 anchors. So in those cases, I think there are a few options: 1) perform the IoU computation on the CPU, 2) write a dedicated CUDA kernel for it, or 3) compute the IoU matrix in batches.
I'd start with the first solution, i.e. calling .cpu() before the IoU computation. Let me know what you think.
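To put those numbers in perspective, a rough back-of-the-envelope calculation, assuming float32 tensors, the 242991 FPN anchors quoted above, and the ~500-GT-box case mentioned earlier:

```python
# Rough memory arithmetic for the dense IoU computation (assumptions: float32,
# 242991 anchors, ~500 GT boxes in one dense image).
anchors, gt = 242991, 500
dense_nm  = anchors * gt * 4       # one [N, M] float32 matrix, in bytes
dense_nm2 = dense_nm * 2           # one [N, M, 2] intermediate (e.g. lt, rb, wh)
print(dense_nm / 2**30, dense_nm2 / 2**30)   # ≈ 0.45 GiB and ≈ 0.91 GiB each
```

Since a straightforward implementation keeps a few such intermediates alive at once, this is consistent with the ~1.5GB for ~200 GT boxes reported further down.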
@fmassa Thanks for your diligence! It sounds like that is indeed my issue. I can't say that I completely understand your second alternative there. I'll see what speed hit I take from moving it to the CPU, or perhaps do so conditionally. If it's too drastic then I will batch it up.
Let me know if you have issues implementing the batched-up version. About point 2, I was mentioning writing a dedicated CUDA kernel for computing the IoU matrix. I'll think about implementing it.
@fmassa I did some brief debugging and found that about 200 GT boxes used about 1.5GB of GPU RAM, roughly in line with your calculations. I'm now conditionally using the CPU for that block when M*N is quite large (>20000000). Looking at the code, there might be a more memory-efficient Python implementation as well: in-place operations could save an MxNx2 allocation. That would raise the threshold but still have its limits. Sorry, I was slow on the kernel uptake; I thought you meant casting it as a convolutional kernel, which seemed odd. A more memory-efficient kernel would be great for my current use case, e.g. overhead imagery. There are other workarounds (cropping), so it likely shouldn't be the highest item on your list. I'm using 800x800 images, but junk yards and parking lots can pack in a lot of GT targets. Feel free to close the issue and maybe open it as an enhancement?
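A minimal sketch of that conditional CPU fallback, written as a generic wrapper over whatever dense IoU routine is in use; the wrapper and its name are hypothetical, and the 20,000,000 threshold is the one quoted above:

```python
def with_cpu_fallback(iou_fn, cpu_threshold=20_000_000):
    """Wrap a dense pairwise-IoU routine (operating on raw [N, 4] tensors) so
    that very large N*M problems are computed on the CPU, then moved back to
    the original device."""
    def wrapped(boxes1, boxes2):
        device = boxes1.device
        if boxes1.shape[0] * boxes2.shape[0] > cpu_threshold:
            boxes1, boxes2 = boxes1.cpu(), boxes2.cpu()
        return iou_fn(boxes1, boxes2).to(device)
    return wrapped

# usage (hypothetical): safe_iou = with_cpu_fallback(my_dense_iou)
```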
I've changed the title of the issue, let's keep it open. I think we can use some in-place operations there, and it will bring some savings, but I'm not sure by how much.
The IoU matrix is very often extremely sparse, especially if you immediately remove box matches with IoU below a predefined threshold (which might be 0.05, 0.3, or something else). Would it be a good idea to make the IoU computation return a sparse matrix (or at least add that as an option)?
There is currently limited support for sparse matrices in PyTorch, so it might not be ideal for now. But once we have better support for sparse reductions in PyTorch, this could be worth revisiting.
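For illustration only, and assuming the dense matrix has already been computed, here is how a thresholded result could be stored as a COO sparse tensor; as noted above, this does not avoid the dense computation itself, so the savings are limited to what is kept afterwards:

```python
import torch

def sparsify_iou(iou_dense, threshold=0.05):
    # Keep only entries at or above the threshold and store them as a COO
    # sparse tensor; the dense [N, M] matrix is still materialized first.
    keep = iou_dense >= threshold
    indices = keep.nonzero(as_tuple=False).t()     # [2, K] index matrix
    values = iou_dense[keep]                       # [K] surviving IoU values
    return torch.sparse_coo_tensor(indices, values, iou_dense.shape)
```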
hi @fmassa
You can do:

```python
device = bbox1.device
bbox1 = bbox1.cpu()
...
iou = iou.to(device)
```
Thanks @fmassa. Following your instructions, I changed the code as below:
But when I run the experiment, I run into this problem:
I tried to set the … Do you have any suggestions to solve it?
I don't think this is a good choice. I computed the IoU in four ways: a NumPy version, a Torch version, a Cython version, and a GPU version. The GPU version is indeed the fastest, but it costs a lot of memory. The NumPy version is close to the Torch version but much slower than the Cython version. So I suggest using the Cython version (you can refer to Detectron.pytorch).
You can also use the torch.jit.script approach to reduce the peak memory.
Thanks. But I'm not familiar with torch.jit.script. Is directly wrapping the function OK?
Try something like this instead. Note that you'll need to unwrap the tensors from the BoxList first.
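Since the snippet itself wasn't preserved above, here is a rough sketch of what a scripted raw-tensor IoU could look like; the (x1, y1, x2, y2) box layout is assumed, and this is not the repo's exact boxlist_iou formula:

```python
import torch

# A sketch only: scripting the raw-tensor IoU so that, per the explanation
# below, the JIT can avoid materializing every large [N, M, 2] intermediate.
@torch.jit.script
def box_iou(boxes1, boxes2):
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N]
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M]
    lt = torch.max(boxes1[:, :2].unsqueeze(1), boxes2[:, :2])  # [N, M, 2]
    rb = torch.min(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:])  # [N, M, 2]
    wh = (rb - lt).clamp(min=0)                                # [N, M, 2]
    inter = wh[:, :, 0] * wh[:, :, 1]                          # [N, M]
    return inter / (area1.unsqueeze(1) + area2 - inter)
```

How much memory this actually saves depends on how well the fuser in the PyTorch version being used handles these intermediates.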
@fmassa Why does torch.jit.script save memory? Why is it not used in the master code, when it seems like a very good improvement? Is there any downside?
@yxchng no downsides. It's not in master because it makes things slightly less readable. It saves memory because it doesn't materialize the intermediate results into large tensors.
I was running into the same issue for both this repo and detectron2. I ended up solving it with chunking. Here is some code that I modified:
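(The modified snippet itself was not preserved in this thread; below is a sketch of the chunking idea using the chunk size of 20 mentioned, written against plain [N, 4] tensors rather than detectron2's Boxes objects, with illustrative helper names.)

```python
import torch

def box_area(boxes):
    # Boxes assumed to be [K, 4] in (x1, y1, x2, y2) format.
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

def pairwise_iou_chunked(boxes1, boxes2, chunk_size=20):
    """Compute the [N, M] IoU matrix block by block, so only a [chunk, M]
    slab of intermediates is alive at any time."""
    area1 = box_area(boxes1)  # [N]
    area2 = box_area(boxes2)  # [M]
    ious = []
    for i in range(0, boxes1.shape[0], chunk_size):
        b1 = boxes1[i:i + chunk_size]
        lt = torch.max(b1[:, None, :2], boxes2[:, :2])  # [chunk, M, 2]
        rb = torch.min(b1[:, None, 2:], boxes2[:, 2:])  # [chunk, M, 2]
        wh = (rb - lt).clamp(min=0)
        inter = wh[:, :, 0] * wh[:, :, 1]               # [chunk, M]
        ious.append(inter / (area1[i:i + chunk_size, None] + area2 - inter))
    return torch.cat(ious, dim=0)  # [N, M]
```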
The original code can be found at https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/boxes.py#L235. It's very similar to maskrcnn_benchmark, and can be adapted to it. I broke it into chunks of size 20. Now at least I can train on my custom dataset with a lot of instances per image.
@ethanweber Interesting solution. Did you get a chance to compare it against the torch.jit.script version?
I didn't compare it with the torch.jit.script version.
❓ Questions and Help
I'm experiencing high GPU memory usage. I made my own COCO-style dataset and started training with three separate models: e2e_faster_rcnn_R_50_FPN_1x.yaml, e2e_faster_rcnn_R_101_FPN_1x.yaml, and e2e_faster_rcnn_X_101_32x8d_FPN_1x.yaml. I changed the number of GPUs to 1 and ran the single-GPU training command.
I get well into training, hundreds or thousands of iterations, and then receive the CUDA OOM message. The reported memory usage is around 7GB, though nvidia-smi reports about 9.7GB for the ResNeXt model.
I'm running on a 1080 Ti with 11GB of memory, so it should be able to handle this amount. It seems as though there are periodic peaks in memory usage.
The error message for R_50_FPN looks like this:
Note, I've trained with this dataset on Detectron.pytorch. Any suggestions?
Based on where the error occurs, is it possible that one of my images contains too many targets (potentially hundreds) and the IoU calculation blows up?
Steve