This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Memory usage is higher than other PyTorch implementations? #182

Open
XudongWang12Sigma opened this issue Nov 20, 2018 · 23 comments
Labels
question Further information is requested

Comments

@XudongWang12Sigma

XudongWang12Sigma commented Nov 20, 2018

❓ Questions and Help

I tried to run Faster R-CNN with a ResNet-50 backbone on the COCO dataset, and the memory usage for your implementation seems to be about 9.4 GB for 2 ims/gpu (scales=800). For the implementation at https://github.com/jwyang/faster-rcnn.pytorch, Faster R-CNN with ResNet-50 only needs around 6 GB. Your memory usage is also higher than official Detectron (7.2 GB with FPN). I am not sure whether I made a mistake, or whether your implementation really does use more memory? Thanks

@fmassa
Contributor

fmassa commented Nov 20, 2018

Hi,

Here are the memory requirements that we have for Faster R-CNN R-50 on COCO.
For FPN, we use 4.4GB of training memory.

One thing to keep in mind is that the meaning of IMS_PER_BATCH in the config is different from the Detectron implementation: we use the global mini-batch size (over all GPUs), while Detectron uses the mini-batch size for a single GPU; see here for more explanation.
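
A rough sketch of the bookkeeping this implies (the linear-scaling numbers below follow the comments in the configs quoted later in this thread; treat them as an assumption rather than an official statement):

num_gpus = 2
ims_per_gpu = 2                          # what Detectron's IMS_PER_BATCH means
ims_per_batch = num_gpus * ims_per_gpu   # what this repo's IMS_PER_BATCH means (global)

# Linear scaling of the reference 8-GPU schedule (IMS_PER_BATCH 16, BASE_LR 0.02, 90k iterations):
base_lr = 0.02 * num_gpus / 8.0          # -> 0.005 for 2 GPUs
max_iter = int(90000 * 8 / num_gpus)     # -> 360000 for 2 GPUs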

Maybe that answers your question?

@fmassa fmassa added the question Further information is requested label Nov 20, 2018
@jario-jin
Contributor

jario-jin commented Nov 21, 2018

My environment:
ubuntu 16.04 x64
Driver Version: 390.59
cuda: 9.1
cudnn: 7.0.5
pytorch: torch-1.0.0a0+5d0ef34-py2.7.egg-info
maskrcnn-benchmark: newest
2 GPUs

Training parameters:
use:
e2e_faster_rcnn_R_50_FPN_1x.yaml
modify:
SOLVER:
  BASE_LR: 0.005  # 0.0025 * Num-of-GPUs (e.g. 8 GPUs 0.02)
  WEIGHT_DECAY: 0.0001
  STEPS: (6000, 8000)  # (480000, 640000) / Num-of-GPUs (e.g. 8 GPUs (60000, 80000))
  MAX_ITER: 9000  # 720000 / Num-of-GPUs (e.g. 8 GPUs 90000)
  IMS_PER_BATCH: 4  # 2 * Num-of-GPUs (e.g. 8 GPUs 16)
TEST:
  IMS_PER_BATCH: 4
other parameters are as defaults
training on coco dataset

GPU Memory Usage:
| 0 GeForce GTX 108... Off | 00000000:01:00.0 On | N/A |
| 50% 69C P2 185W / 250W | 8045MiB / 11177MiB | 100% Default |
+-------------------------------+----------------------+----------------------+
| 1 TITAN X (Pascal) Off | 00000000:71:00.0 Off | N/A |
| 68% 85C P2 136W / 250W | 8243MiB / 12196MiB | 95% Default |
+-------------------------------+----------------------+----------------------+

really high..

@fmassa
Contributor

fmassa commented Nov 21, 2018

Oh, you should not be looking at the output from nvidia-smi, but instead at the value that is logged by the logger.
What actually happens is that PyTorch has a caching memory allocator for CUDA.
This means that, once you create a tensor and then destroy it, instead of giving the memory back to the CUDA driver we keep it allocated. This way, if you request a similar tensor afterwards, we can directly return this cached memory.

The reason we do this is that a cudaFree implies a synchronization point, which is very expensive. And the PyTorch design is to not pre-allocate all the memory beforehand, but to let the user allocate it eagerly, as needed.
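
As a hedged illustration of this caching behaviour (a minimal sketch; the exact numbers depend on the GPU, driver, and PyTorch version):

import torch

MiB = 1024.0 ** 2
x = torch.empty(1024, 1024, 256, device='cuda')  # ~1 GiB of float32
del x                                            # gone from PyTorch's point of view

print(torch.cuda.memory_allocated() / MiB)       # ~0: no live tensors
print(torch.cuda.memory_cached() / MiB)          # ~1024: still held by the caching allocator,
                                                 # and still counted by nvidia-smi (plus the
                                                 # CUDA context overhead)

torch.cuda.empty_cache()                         # hand the cached blocks back to the driver
print(torch.cuda.memory_cached() / MiB)          # drops back towards 0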

Let me know if you have more questions

@jario-jin
Contributor

Thanks for your explanation. The max mem in the log file is about 3.5 GB.

@fmassa
Contributor

fmassa commented Nov 21, 2018

Cool, let me know if something else is not clear. @XudongWang12Sigma does that answer your question as well?

@XudongWang12Sigma
Author

Hi, thank you for your reply. In fact, when I was training Faster R-CNN using the repo https://github.com/jwyang/faster-rcnn.pytorch, I also used nvidia-smi to track the memory usage, and the usage shown in the terminal was about 2.5 GB lower than with your repo. I double-checked my code; I used 8 GPUs with the default settings and the COCO dataset. So I am wondering whether I can do something to lower the memory usage, or at least bring it down to the level of jwyang's repo in nvidia-smi? I need to train on some datasets with larger image sizes, and it sometimes runs out of memory. Thanks

@fmassa
Contributor

fmassa commented Nov 21, 2018

Hi,

One current place where the memory usage can be made more efficient is in the box IoU computation, see #18

A current workaround is to compute the IoU on the CPU. Let me know if that addresses the issue for you, but I'll look into improving the memory usage of that function.
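
For reference, here is a minimal sketch of that kind of workaround on plain [N, 4] and [M, 4] boxes in (x1, y1, x2, y2) format; this is not the repo's boxlist_iou (which operates on BoxList objects), just an illustration of doing the pairwise computation on CPU tensors:

import torch

def pairwise_iou_cpu(boxes1, boxes2):
    # Pairwise IoU of [N, 4] and [M, 4] boxes, computed on the CPU so the
    # large [N, M, 2] intermediates never live on the GPU.
    boxes1, boxes2 = boxes1.cpu(), boxes2.cpu()
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])  # [N, M, 2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])  # [N, M, 2]
    wh = (rb - lt).clamp(min=0)                               # [N, M, 2]
    inter = wh[:, :, 0] * wh[:, :, 1]                         # [N, M]
    return inter / (area1[:, None] + area2[None, :] - inter)

The resulting matrix can be moved back to the GPU with .cuda() if the matcher runs there.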

@fmassa
Contributor

fmassa commented Nov 21, 2018

Also, one thing to do is to print the memory used by the other repo by calling torch.cuda.max_memory_allocated(); this will give a better metric of the real memory usage.
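
For instance, something along these lines could be added to the other repo's training loop (a sketch; dividing by 1024**2 reports MiB):

import torch

# ... after a forward/backward iteration ...
max_mem = torch.cuda.max_memory_allocated() / (1024.0 ** 2)
print('max mem allocated: %.0f MiB' % max_mem)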

@hellock

hellock commented Nov 27, 2018

@fmassa The actual memory required may be larger than torch.cuda.max_memory_allocated(), and it seems to be closer to torch.cuda.max_memory_cached(). (maybe between the two values)
If the GPU memory is 12G and max allocated memory is 5G when batch size is 16, then it is very likely to run out of memory when increasing the batch size to 32.

Here are some memory usage data captured from experiments.

| max allocated (GB) | max cached (GB) | nvidia-smi (GB) |
| --- | --- | --- |
| 3.6 | 5.1 | 6.8 |
| 4.2 | 5.8 | 8.3 |

@fmassa
Contributor

fmassa commented Nov 27, 2018

Hi @hellock,

Indeed, the memory allocated by the CUDA driver (which can be 1 GB or more, and happens when you initialize CUDA) is not counted in torch.cuda.max_memory_allocated(), which might explain the OOM that you experienced. But note that torch.cuda.max_memory_cached() doesn't count the CUDA driver memory either.

I still think though that the torch.cuda.max_memory_allocated() should be closer to the real amount of memory required by the algorithm.

Simple example

As a basic example (run in a new interpreter):

import torch
MiB = 1024.0 ** 2  # values below are reported in MiB
for i in range(10):
    a = torch.rand(1000, 1000, i + 1, device='cuda:0')
    print(torch.cuda.max_memory_allocated() / MiB, torch.cuda.max_memory_cached() / MiB)

will print

4.7001953125 4.875
12.4501953125 12.625
20.0751953125 24.125
27.7001953125 39.5
35.3251953125 58.625
42.9501953125 81.625
50.5751953125 108.375
58.2001953125 139.0
65.8251953125 173.375
73.4501953125 211.625

while (again in a new interpreter)

import torch
MiB = 1024.0 ** 2  # values below are reported in MiB
for i in range(10):
    a = torch.rand(1000, 1000, i + 1, device='cuda:0')
    print(torch.cuda.max_memory_allocated() / MiB, torch.cuda.max_memory_cached() / MiB)
    del a  # now a is not in scope anymore

gives

4.7001953125 4.875
8.5751953125 12.625
12.3251953125 24.125
16.2001953125 39.5
19.9501953125 58.625
23.8251953125 81.625
27.5751953125 108.375
31.4501953125 139.0
35.2001953125 173.375
39.0751953125 211.625

which indicates that what we might indeed want to log is torch.cuda.max_memory_allocated().

In the first example, because a is still bound to the previous tensor while torch.rand allocates the next one, the memory for both has to be allocated at the same time. The second example deletes the tensor right away, so that we indeed only track the total memory used.

In both cases though, we see that max_memory_cached() is not representative of what is being actually required by the system to run, as for example the last run creates a float32 tensor of size 1000 * 1000 * 10 elements, which corresponds to 38.15 MB.

Let me know what you think

@hellock

hellock commented Nov 28, 2018

@fmassa Thanks for your reply. I agree that max_memory_allocated is what we want theoretically. In some simple cases, it can exactly reflect the memory usage. However, it is smaller than the real memory required when training models, even if we take the CUDA driver memory into account, which may be confusing for some users.
A more accurate indicator, though it may not be feasible right now, would be helpful for estimating the available memory and guiding model design and hyper-parameter tuning.

@jario-jin
Contributor

jario-jin commented Dec 25, 2018

When I train on the VisDrone dataset (http://www.aiskyeye.com/), which may contain many objects per image, I get an out-of-memory error at max mem: 4885, even though I have 11 GB of GPU memory.
The error is as follows:
2018-12-25 15:25:13,756 maskrcnn_benchmark.trainer INFO: eta: 14:41:59 iter: 1280 loss: 1.3188 (1.3482) loss_box_reg: 0.2650 (0.2763) loss_rpn_box_reg: 0.3080 (0.2983) loss_classifier: 0.5226 (0.5861) time: 0.2926 (0.2961) loss_objectness: 0.1509 (0.1874) data: 0.0632 (0.0685) lr: 0.005000 max mem: 4885
2018-12-25 15:25:19,666 maskrcnn_benchmark.trainer INFO: eta: 14:41:52 iter: 1300 loss: 1.1438 (1.3447) loss_box_reg: 0.2482 (0.2759) loss_rpn_box_reg: 0.2550 (0.2976) loss_classifier: 0.4935 (0.5849) time: 0.2967 (0.2961) loss_objectness: 0.1087 (0.1864) data: 0.0646 (0.0686) lr: 0.005000 max mem: 4885
Traceback (most recent call last):
  File "tools/train_net.py", line 171, in <module>
    main()
  File "tools/train_net.py", line 164, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "tools/train_net.py", line 73, in train
    arguments,
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 66, in do_train
    loss_dict = model(images, targets)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/distributed.py", line 357, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 50, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 100, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets)
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 119, in _forward_train
    anchors, objectness, rpn_box_regression, targets
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 91, in __call__
    labels, regression_targets = self.prepare_targets(anchors, targets)
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 55, in prepare_targets
    anchors_per_image, targets_per_image
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 37, in match_targets_to_anchors
    match_quality_matrix = boxlist_iou(target, anchor)
  File "/home/jario/spire-net-1812/maskrcnn-benchmark/maskrcnn_benchmark/structures/boxlist_ops.py", line 84, in boxlist_iou
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
RuntimeError: CUDA out of memory. Tried to allocate 1.76 GiB (GPU 0; 10.92 GiB total capacity; 5.17 GiB already allocated; 776.81 MiB free; 106.69 MiB cached)

My environment is as follows:
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: Could not collect

Python version: 2.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: TITAN X (Pascal)

Nvidia driver version: 410.48
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7.4.1
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] numpy (1.15.4)
[pip] torch (1.0.0)
[pip] torchvision (0.2.1)
[conda] Could not collect
Pillow (5.3.0)
2018-12-25 15:18:49,131 maskrcnn_benchmark INFO: Loaded configuration file /home/jario/spire-net-1812/exps/visdrone_baseline_s4/e2e_faster_rcnn_R_50_FPN_1x.yaml
2018-12-25 15:18:49,131 maskrcnn_benchmark INFO:
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ModelDir/R-50.pkl"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
    OUT_CHANNELS: 256
    USE_LIGHT_HEAD: False
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    DETECTIONS_PER_IMG: 500
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("BB180913_vis_drone_train",)  # ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("BB180913_vis_drone_val",)  # ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.005  # 0.0025 * Num-of-GPUs (e.g. 8 GPUs 0.02)
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)  # (480000, 640000) / Num-of-GPUs (e.g. 8 GPUs (60000, 80000))
  MAX_ITER: 180000  # 720000 / Num-of-GPUs (e.g. 8 GPUs 90000)
  IMS_PER_BATCH: 2  # 2 * Num-of-GPUs (e.g. 8 GPUs 16)
TEST:
  IMS_PER_BATCH: 2

@fmassa
Contributor

fmassa commented Dec 25, 2018

@jario-jin There is probably an image in your dataset which contains a lot of objects.
I'd recommend checking the answer in #18 (comment) to avoid the OOM error. The easiest would be to run the box_iou computation on the CPU.
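
As a rough, hedged back-of-envelope check of why crowded images blow up boxlist_iou (the anchor and box counts below are illustrative assumptions, not measured values):

num_anchors = 250000    # assumed: roughly what an ~800-pixel FPN image produces
num_gt_boxes = 900      # assumed: a crowded VisDrone image
bytes_needed = num_anchors * num_gt_boxes * 2 * 4   # the [N, M, 2] float32 intermediate
print(bytes_needed / float(1024 ** 3))              # ~1.7 GiB, the same order as the 1.76 GiB above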

@qianyizhang
Contributor

Any progress on making a more accurate memory indicator?
Perhaps we should raise a feature request in the NVIDIA community forum/GitHub about the CUDA driver memory usage?

@fmassa
Contributor

fmassa commented Jan 28, 2019

@qianyizhang we have recently discovered that, for newer generations of GPUs, allocations via cudaMalloc seem to be rounded up to blocks of 2 MB (which could potentially be split by CUDA afterwards). This was not being taken into account by PyTorch's caching allocator, so whenever we allocated 1048576 + 1 bytes (1 MB + 1 byte), cudaMalloc would actually allocate 2 MB.

Here is an example (which should allocate 10 GB):

import torch

a = []  # keep references so the tensors stay alive
for _ in range(10000):
    a.append(torch.empty(1048576 + 1, dtype=torch.uint8, device='cuda'))

but instead we would have an error as follows:

RuntimeError: CUDA out of memory. Tried to allocate 1.12 MiB (GPU 0; 15.90 GiB total capacity; 8.55 GiB already allocated; 1.56 MiB free; 0 bytes cached)
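
Rough arithmetic for that example (assuming, as described above, that each allocation is rounded up to a 2 MiB block):

n_allocs = 10000
requested = n_allocs * (2 ** 20 + 1)    # ~9.8 GiB asked for by PyTorch
reserved = n_allocs * (2 * 2 ** 20)     # ~19.5 GiB actually taken via cudaMalloc
print('%.2f GiB requested, %.2f GiB reserved' % (requested / float(2 ** 30), reserved / float(2 ** 30)))

Roughly twice the requested memory ends up reserved, which is consistent with the run failing once PyTorch has tracked about half of the card's 15.90 GiB.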

I believe there is a patch being worked on on the PyTorch side which should reduce this gap; it is currently being tested to make sure there are no adverse effects.

@fmassa
Contributor

fmassa commented Feb 14, 2019

Here is a PR that improves the memory reporting in PyTorch. It should allow fitting larger batch sizes with maskrcnn-benchmark: pytorch/pytorch#17120

@XudongWang12Sigma
Author

Thanks! Will try

@soumith
Member

soumith commented Mar 13, 2019

just fyi, the PR is now merged, and is part of pytorch-nightly

@yxchng

yxchng commented Apr 9, 2019

@fmassa this post seems to be talking about why the logged memory is more accurate, which is pointless, because what nvidia-smi reflects is the actual memory needed to run the model (whether the code hits an out-of-memory error depends on this). So I am wondering if there is a way to make the model run with less nvidia-smi memory? Right now, I am running out of memory with a batch size of 1 on a 2080 Ti, which seems ridiculous. Or is it normal that the model in this repo is so memory intensive?

@fmassa
Contributor

fmassa commented Apr 9, 2019

@yxchng the model is generally not memory intensive; the values reported by nvidia-smi are not representative of the total amount of memory required by the model.

If you are running out of memory, maybe you have a lot (>50) of ground-truth boxes in your image and the IoU computation is the bottleneck? Check #18

@yxchng

yxchng commented Apr 9, 2019

@fmassa I am using the COCO 2017 dataset with the config e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml to test. Do you think it should go out of memory with batch size 1 on 11 GB of memory? What should the expected memory consumption be, and for what batch size? The documentation in the model zoo seems to suggest that it only uses 7 GB of memory for a batch size of 8? However, I couldn't get that result, and it is not even close. Do you have any idea why?

@fmassa
Contributor

fmassa commented Apr 9, 2019

@yxchng is this during training or during testing? Also, the memory reported is per GPU, so if there is 1 image per GPU, then you should count 7 GB per image.
Also, you need to change the configs to use a smaller global batch size if using a single GPU, and adapt the learning rate, learning rate schedule, etc. accordingly, following the instructions in the README.

@ShihuaiXu

@fmassa
When I test my segmentation model, I have to set the batch size to 1, because one image uses 7 GB of memory. Is this normal?
