
Inconsistent recovery from CUDA OOMs #18853

Closed
stephenroller opened this issue Apr 4, 2019 · 25 comments
Labels: has workaround, high priority, module: autograd, module: memory usage, oncall: distributed, triaged

Comments


stephenroller commented Apr 4, 2019

🐛 Bug

Editorial note: Make sure that you are not holding on to tensors via an exception object (the exception carries the stack trace, whose frames keep their tensors alive). Either avoid bringing the exception object into scope, or perform the error recovery outside of the except block. A sketch of this pattern follows the description below.

Catching a RuntimeError on a CUDA OOM should allow one to recover gracefully, for example by lowering the batch size. This is particularly important when using DistributedDataParallel, where workers must sync on backward, so it is important to be able to perform a "dummy batch" after an OOM in order to stay in sync with the other workers.

Observed behavior during a CUDA out-of-memory event is inconsistent across the non-distributed, DataParallel, and DistributedDataParallel modes. Expected behavior is that all modes should be able to recover easily.
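
For context, here is a minimal sketch of the recovery pattern described above, with illustrative names (it mirrors what memtestcase.py below and the linked fairseq trainer do, but is not their exact code). The handler only records that an OOM happened; the recovery, including the tiny "dummy batch" that keeps a DDP rank's gradient allreduce in step with the other workers, runs outside the except block so the exception and its traceback no longer pin the failed batch's tensors.

import torch

def forward_backward(model, batch):
    loss = model(batch).sum()
    loss.backward()

def step_with_oom_recovery(model, batch, tiny_batch):
    oom = False
    try:
        forward_backward(model, batch)
    except RuntimeError as err:
        if 'out of memory' not in str(err):
            raise
        # Only record the failure here: `err` references the traceback, and
        # the traceback frames keep the oversized batch's tensors alive.
        oom = True
    if oom:
        # Recover outside the except block, after the exception is dropped.
        model.zero_grad()
        torch.cuda.empty_cache()
        # Dummy batch: under DDP this keeps the rank's backward (and its
        # gradient allreduce) in sync with the other workers.
        forward_backward(model, tiny_batch)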

To Reproduce

Test case and logs available here:
https://gist.github.com/stephenroller/bd2cd644e7c117c1ec8192639ecf30b6

Steps to reproduce the behavior:

  1. Download memtestcase.py and run.sh
  2. Run run.sh. Observe which test cases pass.

Logs from several environments on the FAIR cluster:

  1. pytorch 1.0.0 (@klshuster's env) (pytorch 1.0.0, cuda 9.0.176): https://gist.github.com/stephenroller/bd2cd644e7c117c1ec8192639ecf30b6#file-kurtlog-pytorch-1-0-0-cuda-9-0-176
  2. fairseqenv (@myleott's) (pytorch 1.0.0.dev20190211, cuda 10.0.130): https://gist.github.com/stephenroller/bd2cd644e7c117c1ec8192639ecf30b6#file-fairseqenv-pytorch-1-0-0-dev20190211-cuda-10-0-130
  3. pytorch stable env (@stephenroller's): https://gist.github.com/stephenroller/bd2cd644e7c117c1ec8192639ecf30b6#file-stable-env-pytorch-1-0-1-post2-cuda-10-0-130

Expected behavior

All test cases pass. At the very least, test cases should produce consistent results across all --modes.

Environment

Here's the environment from the third log (@stephenroller's; PyTorch 1.0.1.post2, CUDA 10.0.130):

Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] pytorch-pretrained-bert==0.6.1
[pip] torch==1.0.1.post2
[pip] torchtext==0.3.1
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.3                      199
[conda] mkl_fft                   1.0.10           py37ha843d7b_0
[conda] mkl_random                1.0.2            py37hd81dba3_0
[conda] pytorch                   1.0.1           py3.7_cuda10.0.130_cudnn7.4.2_2    pytorch
[conda] pytorch-pretrained-bert   0.6.1                     <pip>
[conda] torchtext                 0.3.1                     <pip>

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @aazzolini @xush6528 @osalpekar

@stephenroller
Contributor Author

And again with nightly, where all test cases fail:

Thu Apr  4 10:20:07 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro GP100        Off  | 00000000:81:00.0 Off |                    0 |
| 32%   41C    P0    35W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Quadro GP100        Off  | 00000000:82:00.0 Off |                    0 |
| 28%   40C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

=======================================================================
Activating pytorch nightly
=======================================================================

Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.0.0.dev20190404
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: 
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.0.dev20190404
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.3                      199  
[conda] mkl_fft                   1.0.10           py37ha843d7b_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch-nightly           1.0.0.dev20190404 py3.7_cuda10.0.130_cudnn7.4.2_0    pytorch

Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 101, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 63, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 139, in <module>
    main()
  File "memtestcase.py", line 134, in main
    run_trial(args)
  File "memtestcase.py", line 113, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 68, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 989.50 KiB cached)

Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 101, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 63, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
    return F.threshold(input, self.threshold, self.value, self.inplace)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 915, in threshold
    result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 981.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 139, in <module>
    main()
  File "memtestcase.py", line 134, in main
    run_trial(args)
  File "memtestcase.py", line 113, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 63, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
    return replicate(module, device_ids)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 281, in replicate
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 207, in _broadcast_coalesced_reshape
    tensor_copies = Broadcast.apply(device_group, *tensor_group)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 57.56 MiB free; 29.67 MiB cached) (malloc at /opt/conda/conda-bld/pytorch-nightly_1554354360325/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7efe8e6409c5 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16971 (0x7efe869da971 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17007 (0x7efe869db007 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x9e5 (0x7efe945b0935 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x4d (0x7efe92d959bd in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x1c1 (0x7efe87587ec1 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x594 (0x7efeb89d3c44 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x6f6 (0x7efeb89d47d6 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x5b975e (0x7efeb89d875e in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x129586 (0x7efeb8548586 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: THPFunction_apply(_object*, _object*) + 0x5e9 (0x7efeb87c1d69 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #54: __libc_start_main + 0xe7 (0x7efec7330b97 in /lib/x86_64-linux-gnu/libc.so.6)


Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 101, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 63, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 139, in <module>
    main()
  File "memtestcase.py", line 134, in main
    run_trial(args)
  File "memtestcase.py", line 113, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 68, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 989.50 KiB cached)

Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
  Forward with bs  =   8192
Traceback (most recent call last):
  File "memtestcase.py", line 139, in <module>
    main()
  File "memtestcase.py", line 134, in main
    run_trial(args)
  File "memtestcase.py", line 107, in run_trial
    raise rerr
  File "memtestcase.py", line 101, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 63, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 364, in forward
    self._sync_params()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 396, in _sync_params
    param_data.set_(tensor)
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()

@colesbury
Member

This is quite serious and should be dealt with before the next PyTorch release. It appears to be a leak in the autograd engine: note that if you remove the loss.backward() line in @stephenroller's test case, all the tests pass (except for ddp_multi, which hits a different error that should also be fixed).

Basically, if you're lucky and the OOM happens during the forward pass, it's recoverable. If it happens during backward(), it leaks memory and is effectively unrecoverable. It's unclear whether this is a regression or whether the bug also exists in older versions of PyTorch.

I would suspect that some tasks live on in some of the Engine state. (Are the ready_queues empty after an error?)
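
One way to observe the suspected leak directly (a diagnostic sketch; probe_leak, make_batch, and allocated_mib are illustrative names, not part of the original test case) is to compare torch.cuda.memory_allocated() before a failing step and after the OOM has been caught and all local references dropped. If allocated memory stays near the device limit after an OOM in backward() but returns to roughly the baseline after an OOM in forward(), the remaining references plausibly live inside the autograd engine rather than in user code.

import gc
import torch

def allocated_mib():
    # Current GPU memory occupied by tensors on the default device, in MiB.
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() / (1024 ** 2)

def probe_leak(model, make_batch):
    baseline = allocated_mib()
    oom = False
    loss = None
    try:
        loss = model(make_batch()).sum()
        loss.backward()
    except RuntimeError as err:
        if 'out of memory' not in str(err):
            raise
        oom = True
    # Drop our own references to the graph before measuring, so whatever is
    # still allocated is held by PyTorch internals rather than by this frame.
    del loss
    gc.collect()
    torch.cuda.empty_cache()
    print('OOM hit: {} | allocated before: {:.0f} MiB | after cleanup: {:.0f} MiB'
          .format(oom, baseline, allocated_mib()))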

@fmassa added the high priority, module: autograd, module: memory usage, and triaged labels on Apr 4, 2019
@fmassa added this to the 1.1 milestone on Apr 4, 2019
@stephenroller
Contributor Author

TODO: looks like I messed up the one-proc-per-gpu case.

@stephenroller
Contributor Author

Updated scripts below:

run.sh unchanged, but pasted here:

#!/bin/bash

nvidia-smi

echo "Collect env"
echo "------------------------------------------------------------"
python collect_env.py
echo

for mode in single dp ddp_single ddp_multi
do
    echo "Running mode=$mode"
    echo "------------------------------------------------------------"
    python -u memtestcase.py --mode=$mode 2>&1
    echo
done

memtestcase.py slightly changed:

#!/usr/bin/env python

#SBATCH --gres=gpu:2
#SBATCH --job-name=distributed_example
#SBATCH --partition=dev
#SBATCH --nodes=1
#SBATCH --time=0:10:00
#SBATCH --ntasks-per-node=1
#SBATCH --mem=8G
#SBATCH --cpus-per-task=10

import os
import argparse
import torch
import torch.nn as nn
import torch.distributed as td
import torch.nn.parallel as tp

START_BS = 8 * 1024

# these don't matter, just constants meant to be a "big" model
INPUT_SIZE = 8192
HID_SIZE = 4096
LAYERS = 8
OUT_CLASSES = 4


def wrap_dp(model):
    return tp.DataParallel(model)


def wrap_ddp_multi(model):
    td.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:61337',
        rank=0,
        world_size=1
    )
    model = tp.DistributedDataParallel(
        model,
        device_ids=None,
        broadcast_buffers=False,
    )
    return model


def wrap_ddp_single(model):
    td.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:61337',
        rank=0,
        world_size=1
    )
    model = tp.DistributedDataParallel(
        model,
        device_ids=[0],
        broadcast_buffers=False,
    )
    return model


def create_model(args):
    model = nn.Sequential(
        nn.Linear(INPUT_SIZE, HID_SIZE),
        nn.ReLU(),
    )
    for i in range(LAYERS):
        model.add_module('hidd' + str(i), nn.Linear(HID_SIZE, HID_SIZE))
        model.add_module('relu' + str(i), nn.ReLU())
    model.add_module('output', nn.Linear(HID_SIZE, OUT_CLASSES))
    return model


def fwbw(model, bs):
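    # One forward/backward pass at batch size `bs`; the torch.cuda.synchronize()
    # calls force each stage to complete on the GPU (surfacing any pending CUDA
    # error) before the next stage runs, so failures are attributed correctly.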
    print('  Forward with bs  = {:-6d}'.format(bs))
    X = torch.randn(bs, INPUT_SIZE).cuda()
    torch.cuda.synchronize()
    yhat = model(X)
    torch.cuda.synchronize()
    loss = yhat.sum()
    torch.cuda.synchronize()
    print('  Backward with bs = {:-6d}'.format(bs))
    loss.backward()
    torch.cuda.synchronize()
    model.zero_grad()
    torch.cuda.synchronize()


def run_trial(args):
    print('Conda PREFIX:', os.environ['CONDA_PREFIX'])
    print('Torch version:', torch.version.__version__)
    print('CUDA version:', torch.version.cuda)

    model = create_model(args).cuda()
    if args.mode == 'dp':
        print('Wrapping in DataParallel')
        model = wrap_dp(model)
    elif args.mode == 'ddp_multi':
        print('Wrapping in DistributedDataParallel (equiv to 1 proc per node)')
        model = wrap_ddp_multi(model)
    elif args.mode == 'ddp_single':
        print('Using a single GPU in distributed (equiv to 1 proc per gpu)')
        torch.cuda.set_device(0)
        model = wrap_ddp_single(model)
    elif args.mode == 'single':
        print('Using a single GPU')
        pass
    else:
        raise ValueError('--mode wrong')

    bs = args.bs
    times_oomed = 0
    while times_oomed < args.ooms:
        # continuously double the batch size until we OOM
        try:
            print('Step bs=', bs)
            fwbw(model, bs)
            print('FW/BW succeeded. Doubling BS')
            bs *= 2
        except RuntimeError as rerr:
            if 'memory' not in str(rerr):
                # not the exception we wanted
                raise rerr
            # okay, we found the memory error! Now try to run a NOOP pass
            # for DDP nodes. Production example here:
            # https://github.com/pytorch/fairseq/blob/3658fa329b8cb987d951b2e38ec86c44b9e1fea5/fairseq/trainer.py#L361-L368
            times_oomed += 1
            print('OOM #{}! Running through a tiny batch to catch up worker'.format(times_oomed))
            fwbw(model, 2)
            print('Succeeded on the oom batch.')
            # start the doubling procedure again
            bs = args.bs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode', default='ddp_multi', choices=('dp', 'ddp_multi', 'ddp_single', 'single'),
        help='DataParallel, DistributedDataParallel, or single gpu'
    )
    parser.add_argument(
        '--ooms', default=1, type=int,
        help='Number of times to OOM'
    )
    parser.add_argument(
        '--bs', default=START_BS, type=int,
        help='Initial batch size',
    )
    args = parser.parse_args()
    run_trial(args)
    print('Test passed.')


if __name__ == '__main__':
    main()

Log (1.0.1.post2):

devfair0237 parlai hybriddp » bash run.sh
Mon Apr  8 11:34:40 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro GP100        Off  | 00000000:81:00.0 Off |                    0 |
| 26%   38C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Quadro GP100        Off  | 00000000:82:00.0 Off |                    0 |
| 26%   38C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] pytorch-pretrained-bert==0.6.1
[pip] torch==1.0.1.post2
[pip] torchtext==0.3.1
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.3                      199
[conda] mkl_fft                   1.0.10           py37ha843d7b_0
[conda] mkl_random                1.0.2            py37hd81dba3_0
[conda] pytorch                   1.0.1           py3.7_cuda10.0.130_cudnn7.4.2_2    pytorch
[conda] pytorch-pretrained-bert   0.6.1                     <pip>
[conda] torchtext                 0.3.1                     <pip>

Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
    ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 25.56 MiB free; 607.00 KiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 83, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 25.56 MiB free; 989.50 KiB cached)

Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
    return F.threshold(input, self.threshold, self.value, self.inplace)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 840, in threshold
    result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 997.56 MiB free; 607.00 KiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 142, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 147, in replicate
    return replicate(module, device_ids)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 13, in replicate
    param_copies = Broadcast.apply(devices, *params)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/cuda/comm.py", line 40, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 64.12 MiB (GPU 1; 15.90 GiB total capacity; 15.19 GiB already allocated; 9.56 MiB free; 911.50 KiB cached) (malloc at /opt/conda/conda-bld/pytorch_1549636813070/work/aten/src/THC/THCCachingAllocator.cpp:231)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fae1f0f1cf5 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1239bc1 (0x7fae233d3bc1 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #2: <unknown function> + 0x123a53a (0x7fae233d453a in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, at::TensorOptions const&) + 0x2d6 (0x7fae24a3edb6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x161 (0x7fae232f2311 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x179 (0x7fae1832a209 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x545 (0x7fae467c3725 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x7e6 (0x7fae467c4396 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x4f2be6 (0x7fae467c8be6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x111af6 (0x7fae463e7af6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: THPFunction_apply(_object*, _object*) + 0x5a1 (0x7fae465e3061 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #51: __libc_start_main + 0xe7 (0x7fae57b99b97 in /lib/x86_64-linux-gnu/libc.so.6)


Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 357, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
    ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 19.56 MiB free; 607.00 KiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 83, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 19.56 MiB free; 989.50 KiB cached)

Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 358, in forward
    outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 365, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
    return F.threshold(input, self.threshold, self.value, self.inplace)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 840, in threshold
    result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 991.56 MiB free; 607.00 KiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 355, in forward
    self._sync_params()
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 384, in _sync_params
    self.broadcast_bucket_size)
  File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/cuda/comm.py", line 40, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 128.12 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 89.56 MiB free; 992.00 KiB cached) (malloc at /opt/conda/conda-bld/pytorch_1549636813070/work/aten/src/THC/THCCachingAllocator.cpp:231)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f1e8e50acf5 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1239bc1 (0x7f1e927ecbc1 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #2: <unknown function> + 0x123a53a (0x7f1e927ed53a in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, at::TensorOptions const&) + 0x2d6 (0x7f1e93e57db6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x161 (0x7f1e9270b311 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x179 (0x7f1e87743209 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x545 (0x7f1eb5bdc725 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x7e6 (0x7f1eb5bdd396 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x4f2be6 (0x7f1eb5be1be6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x111af6 (0x7f1eb5800af6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #43: __libc_start_main + 0xe7 (0x7f1ec6fb2b97 in /lib/x86_64-linux-gnu/libc.so.6)

Log (nightly 2019-04-04):

Mon Apr  8 11:40:28 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro GP100        Off  | 00000000:81:00.0 Off |                    0 |
| 27%   39C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Quadro GP100        Off  | 00000000:82:00.0 Off |                    0 |
| 26%   39C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.0.0.dev20190404
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: 
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.0.dev20190404
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.3                      199  
[conda] mkl_fft                   1.0.10           py37ha843d7b_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch-nightly           1.0.0.dev20190404 py3.7_cuda10.0.130_cudnn7.4.2_0    pytorch

Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 83, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 989.50 KiB cached)

Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
    return F.threshold(input, self.threshold, self.value, self.inplace)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 915, in threshold
    result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 981.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
    return replicate(module, device_ids)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 281, in replicate
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 207, in _broadcast_coalesced_reshape
    tensor_copies = Broadcast.apply(device_group, *tensor_group)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 57.56 MiB free; 29.67 MiB cached) (malloc at /opt/conda/conda-bld/pytorch-nightly_1554354360325/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fd0230309c5 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16971 (0x7fd01b3ca971 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17007 (0x7fd01b3cb007 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x9e5 (0x7fd028fa0935 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x4d (0x7fd0277859bd in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x1c1 (0x7fd01bf77ec1 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x594 (0x7fd04d3c3c44 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x6f6 (0x7fd04d3c47d6 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x5b975e (0x7fd04d3c875e in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x129586 (0x7fd04cf38586 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: THPFunction_apply(_object*, _object*) + 0x5e9 (0x7fd04d1b1d69 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #54: __libc_start_main + 0xe7 (0x7fd05bd1fb97 in /lib/x86_64-linux-gnu/libc.so.6)


Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 366, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 3.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 83, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 3.56 MiB free; 989.50 KiB cached)

Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
  Forward with bs  =   8192
Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 123, in run_trial
    raise rerr
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 364, in forward
    self._sync_params()
  File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 396, in _sync_params
    param_data.set_(tensor)
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()

@stephenroller
Contributor Author

Just ran with pytorch 1.1 given the release:

Fri May  3 07:32:41 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro GP100        Off  | 00000000:81:00.0 Off |                    0 |
| 26%   34C    P0    33W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Quadro GP100        Off  | 00000000:82:00.0 Off |                    0 |
| 26%   34C    P0    34W / 235W |      0MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: 
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.3                      199  
[conda] mkl_fft                   1.0.12           py37ha843d7b_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch                   1.1.0           py3.7_cuda10.0.130_cudnn7.5.1_0    pytorch
[conda] torchvision               0.2.2                      py_3    pytorch

Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 1406, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 1.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 83, in fwbw
    loss.backward()
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 1.56 MiB free; 989.50 KiB cached)

Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 99, in forward
    return F.relu(input, inplace=self.inplace)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 943, in relu
    result = torch.relu(input)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 973.56 MiB free; 1.59 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
    return replicate(module, device_ids)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 97, in replicate
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 80, in _broadcast_coalesced_reshape
    tensor_copies = Broadcast.apply(devices, *tensors)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 49.56 MiB free; 29.67 MiB cached) (malloc at /opt/conda/conda-bld/pytorch_1556653114079/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f07289d3dc5 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16ca7 (0x7f0720c9cca7 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17347 (0x7f0720c9d347 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x274 (0x7f072e850f14 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x19b (0x7f072cfa04bb in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x268 (0x7f07218debf8 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: at::native::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) + 0x687 (0x7f07293a1457 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #7: at::TypeDefault::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) const + 0x1b (0x7f07296217eb in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #8: torch::autograd::VariableType::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) const + 0x312 (0x7f07217ca5c2 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #9: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x100 (0x7f0721d1f1d0 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #10: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x439 (0x7f0721d1f829 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #11: <unknown function> + 0x5a912e (0x7f0757e5912e in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x12d07a (0x7f07579dd07a in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #21: THPFunction_apply(_object*, _object*) + 0x691 (0x7f0757c5f891 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #57: __libc_start_main + 0xe7 (0x7f07667bdb97 in /lib/x86_64-linux-gnu/libc.so.6)


Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
  Backward with bs =      2
Succeeded on the oom batch.
Test passed.

Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
  Forward with bs  =   8192
  Backward with bs =   8192
FW/BW succeeded. Doubling BS
Step bs= 16384
  Forward with bs  =  16384
  Backward with bs =  16384
FW/BW succeeded. Doubling BS
Step bs= 32768
  Forward with bs  =  32768
  Backward with bs =  32768
FW/BW succeeded. Doubling BS
Step bs= 65536
  Forward with bs  =  65536
  Backward with bs =  65536
FW/BW succeeded. Doubling BS
Step bs= 131072
  Forward with bs  = 131072
OOM #1! Running through a tiny batch to catch up worker
  Forward with bs  =      2
Traceback (most recent call last):
  File "memtestcase.py", line 117, in run_trial
    fwbw(model, bs)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 378, in forward
    outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 399, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 99, in forward
    return F.relu(input, inplace=self.inplace)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 943, in relu
    result = torch.relu(input)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.88 GiB already allocated; 193.56 MiB free; 137.39 MiB cached)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "memtestcase.py", line 155, in <module>
    main()
  File "memtestcase.py", line 150, in main
    run_trial(args)
  File "memtestcase.py", line 129, in run_trial
    fwbw(model, 2)
  File "memtestcase.py", line 78, in fwbw
    yhat = model(X)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 372, in forward
    self._sync_params()
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 420, in _sync_params
    self.broadcast_bucket_size)
  File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 258.00 MiB (GPU 0; 15.90 GiB total capacity; 14.88 GiB already allocated; 193.56 MiB free; 137.33 MiB cached) (malloc at /opt/conda/conda-bld/pytorch_1556653114079/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f82d47c1dc5 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16ca7 (0x7f82cca8aca7 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17347 (0x7f82cca8b347 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: THCStorage_resize + 0x96 (0x7f82d8eac706 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: THCTensor_resizeNd + 0x4d8 (0x7f82d8ec2e68 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: THCudaTensor_catArray + 0x4f1 (0x7f82d917abe1 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #6: at::CUDAType::_th_cat(c10::ArrayRef<at::Tensor>, long) const + 0x968 (0x7f82d8d910a8 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #7: at::native::cat(c10::ArrayRef<at::Tensor>, long) + 0xa4 (0x7f82d51ba2b4 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #8: at::TypeDefault::cat(c10::ArrayRef<at::Tensor>, long) const + 0x4f (0x7f82d54715bf in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #9: torch::autograd::VariableType::cat(c10::ArrayRef<at::Tensor>, long) const + 0x7f1 (0x7f82cd588ac1 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #10: <unknown function> + 0xbd0c86 (0x7f82cda72c86 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #11: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x412 (0x7f82cdb0d802 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #12: <unknown function> + 0x5a912e (0x7f8303c4712e in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x12d07a (0x7f83037cb07a in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #47: __libc_start_main + 0xe7 (0x7f83125a6b97 in /lib/x86_64-linux-gnu/libc.so.6)

@mrshenli mrshenli added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 10, 2019
@mrshenli
Contributor

@albanD @ezyang @colesbury do you have any thoughts on the autograd memory leak issue?

@zhangguanheng66
Contributor

This may not be relevant, but we recently ran into a CUDA OOM problem and were able to resolve it by removing some "weak_script" decorators (#20563).

@stephenroller
Contributor Author

Super interesting. I'm not using weak_script anywhere (and neither is the test case).

@stephenroller
Contributor Author

Any updates?

@mrshenli
Contributor

mrshenli commented Jun 6, 2019

Check out #21344; it is not exactly the same issue, but I suspect the OOM error also has something to do with DDP hooks.

@gchanan gchanan removed this from the 1.1 milestone Jul 16, 2019
@gchanan
Contributor

gchanan commented Jul 16, 2019

removing 1.1 milestone -- it didn't make it and doesn't seem blocking for 1.2.

@mrshenli
Contributor

Hey @stephenroller, I am not sure if this relates to the memory leak problem above, but the OOM recovery code you linked above might hit desync problem. Say we have 2 DDP processes (p0, p1), and each organizes its backward reduction into 2 buckets (b0,b1). If the OOM occurs on p0 after reducing b0, p1 would try to reduce b1, but p0 would retry both b0 and b1. It then leads to allreducing two different buckets (b0 in p0 and b1 in p1) together, and hence the desync. To avoid this problem, the trainer will need to destroy and reconstruct ProcessGroup and DDP objects on all processes.

@Ali2500

Ali2500 commented Aug 24, 2019

Hey @stephenroller, I am not sure if this relates to the memory leak problem above, but the OOM recovery code you linked above might hit desync problem. Say we have 2 DDP processes (p0, p1), and each organizes its backward reduction into 2 buckets (b0,b1). If the OOM occurs on p0 after reducing b0, p1 would try to reduce b1, but p0 would retry both b0 and b1. It then leads to allreducing two different buckets (b0 in p0 and b1 in p1) together, and hence the desync. To avoid this problem, the trainer will need to destroy and reconstruct ProcessGroup and DDP objects on all processes.

Would it be possible for you to write a basic code snippet on how to implement such a routine?
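For reference, here is a minimal sketch of what such a destroy-and-reconstruct routine could look like (an illustration only, not code from this issue; the backend and init_method below are placeholders that have to match whatever the trainer originally used):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def rebuild_after_oom(module, rank, world_size, device_id):
    # Tear down the current process group so that no rank is left holding a
    # partially-reduced bucket from the failed iteration.
    dist.destroy_process_group()
    # Every rank must reach this point; then re-create the group and re-wrap
    # the (still intact) module in a fresh DDP instance.
    dist.init_process_group(
        backend='nccl',           # placeholder: match the original setup
        init_method='env://',     # placeholder: match the original setup
        rank=rank,
        world_size=world_size,
    )
    return DDP(module, device_ids=[device_id])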

@rgommers rgommers added the quansight-nack High-prio issues that have been reviewed by Quansight and are judged to be not actionable. label Jan 26, 2020
@peterbell10
Collaborator

Note that if you remove the loss.backward() line in @stephenroller's test case, all the tests pass.

I was able to reproduce the failure without autograd by modifying the script slightly. I removed backward from fwbw and used torch.no_grad to disable autograd:

@torch.no_grad()
def fwbw(model, bs):
    print('  Forward with bs  = {:-6d}'.format(bs))
    X = torch.randn(bs, INPUT_SIZE, device='cuda')
    torch.cuda.synchronize()
    yhat = model(X)
    torch.cuda.synchronize()
    loss = yhat.sum()
    torch.cuda.synchronize()

Without autograd, this uses much less CUDA memory for the same batch size, so for the OOM batch to fail to allocate, its size has to be scaled in proportion to the failed batch size.

If I use fwbw(model, bs//4), I get the OOM batch failing consistently for mode=single but succeeding for mode=dp. If I use fwbw(model, bs//2), it fails for all modes. That seems odd to me since we just successfully ran at that batch size. This happens even if I call gc.collect() before the OOM batch.

@rgommers rgommers removed the quansight-nack High-prio issues that have been reviewed by Quansight and are judged to be not actionable. label Jan 31, 2020
@peterbell10 peterbell10 self-assigned this Feb 1, 2020
@peterbell10
Collaborator

This appears to be the same underlying issue as #27600; the memory for the tensors isn't freed until the exception object has gone out of scope.

I was able to get single, dp and ddp_single to work in the script by moving the oom batch outside of the except clause:

try:
    fwbw(model, bs)
    bs *= 2
    oom = False
except RuntimeError as rerr:
    if 'memory' not in str(rerr):
        # not the exception we wanted
        raise rerr
    oom = True

if oom:
    # do recovery batch

ddp_multi still fails for me but now with a different error:

terminate called after throwing an instance of 'c10::Error'
  what():  graph_task INTERNAL ASSERT FAILED at ../torch/csrc/autograd/engine.cpp:157, please report a bug to PyTorch. GraphTask is no longer valid! (getReentrantDepth at ../torch/csrc/autograd/engine.cpp:157)

@mrshenli could this be related to the desync issue you mentioned above?

@stephenroller
Contributor Author

Oh wow, that’s a great workaround @peterbell10, thanks for documenting!

@albanD
Collaborator

albanD commented Feb 9, 2020

cc @pritamdamania87, who introduced this error message in #27940. Do you know why it is triggered here?

@peterbell10
Collaborator

@pritamdamania87, @mrshenli any thoughts as to why the graph task might be expiring here?

@peterbell10
Collaborator

I ran the test again with a newer build of pytorch and ddp_multi succeeds as well now. Running through the history, 05d18ff was the first commit to succeed. It looks like the TORCH_CHECK that was failing has been removed completely. So, I guess this can be closed now?

cc @pritamdamania87

@ezyang
Contributor

ezyang commented Mar 9, 2020

@peterbell10 Yes, this looks resolved to our satisfaction! Perhaps the last remaining thing is to think about a good place to add some documentation about this.

@vadimkantorov
Contributor

vadimkantorov commented Sep 24, 2020

@ezyang We're also experiencing this issue (DDP doesn't sync well once an OOM recovery has happened, because some variables in the autograd graph never had a backward pass run on them). Should we do something similar to https://github.com/facebookresearch/ParlAI/blob/e9b95d8be32c7b08f487423441de68d05709a615/parlai/core/torch_generator_agent.py, which re-passes a dummy batch to trigger the backward sync?

Is there a way to manually tell DDP that some variables should be marked "as if the backward pass happened on them"? Then we could just call zero_grad at OOM recovery and trigger all the autograd hooks / a dummy backward, without the hack of passing a dummy batch.
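For context, here is a minimal sketch of the dummy-batch pattern referred to above (an illustration only, not the linked ParlAI code; make_dummy_batch is a hypothetical helper):

import torch

def step_with_oom_recovery(ddp_model, optimizer, batch, make_dummy_batch):
    oom = False
    try:
        loss = ddp_model(batch).sum()
        loss.backward()
    except RuntimeError as err:
        if 'out of memory' not in str(err):
            raise
        # Only record the failure here; do the recovery outside the except
        # block so the exception's traceback (and the tensors it references)
        # can be freed.
        oom = True

    if oom:
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        tiny = make_dummy_batch()    # e.g. batch size 2
        loss = ddp_model(tiny).sum()
        loss.backward()              # keeps this rank's allreduce schedule in sync
        optimizer.zero_grad()        # discard the meaningless dummy gradients
    else:
        optimizer.step()
        optimizer.zero_grad()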

@stephenroller
Contributor Author

Not sure; I haven't dug into the guts of DDP enough to know whether there's a way to sync using manually set gradients. The "standard" way is that it hooks into backward(), which makes that difficult.

@stephenroller
Contributor Author

It looks like DDP.join might serve the purpose? Please ping on this thread if you try that and it works.

@stephenroller
Contributor Author

Or maybe DDP._match_all_reduce_for_bwd_pass
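For anyone who wants to try the join() route, here is a speculative sketch (PyTorch >= 1.7; join() is documented for uneven numbers of inputs per process, and whether it also covers aborting a batch after an OOM is exactly the open question here; module, rank, batches and optimizer are placeholders):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train_loop(module, rank, batches, optimizer):
    model = DDP(module.to(rank), device_ids=[rank])
    with model.join():  # shadows collectives for ranks that stop iterating early
        for X in batches:
            try:
                loss = model(X).sum()
                loss.backward()
            except RuntimeError as err:
                if 'out of memory' not in str(err):
                    raise
                # Stop consuming batches on this rank; the other ranks can
                # finish their remaining batches under join().
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                break
            optimizer.step()
            optimizer.zero_grad()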

@albertz
Contributor

albertz commented Jan 12, 2024

I was able to get single, dp and ddp_single to work in the script by moving the oom batch outside of the except clause:

try:
    fwbw(model, bs)
    bs *= 2
    oom = False
except RuntimeError as rerr:
    if 'memory' not in str(rerr):
        # not the exception we wanted
        raise rerr
    oom = True

if oom:
    # do recovery batch

Note that moving the OOM recovery logic outside of the except clause is not strictly necessary. The problem is that the exception (rerr here) also holds the traceback (rerr.__traceback__), which holds references to all the stack frames, which in turn hold references to all the local variables. But there is the function traceback.clear_frames, which can remove those references. I.e. you could do this instead:

import traceback

...

try:
    fwbw(model, bs)
    bs *= 2
except RuntimeError as rerr:
    if 'memory' not in str(rerr):
        # not the exception we wanted
        raise rerr
    # Got OOM.
    traceback.clear_frames(rerr.__traceback__)  # free locals
    # do recovery batch

Note, one thing to be careful about (and this applies whether you handle the OOM recovery inside or outside the except clause): whenever you (or something else, e.g. your debugger) accesses tb_frame.f_locals for any of the frames of the traceback, you might hit this CPython bug: python/cpython#113939

In that case, it is tricky to free the references to any Torch tensors on the exception stack; just calling traceback.clear_frames(rerr.__traceback__) is not enough. There is a workaround to free the locals then: access f_locals again after the frame objects have been cleared, which potentially cleans up a previous copy of the locals (due to python/cpython#113939). I.e. something like this:

tb = rerr.__traceback__
while tb:
    try:
        tb.tb_frame.clear()
    except RuntimeError:
        pass  # can happen if the frame is still executing, e.g. the current frame
    else:
        # Using this code triggers that the ref actually goes out of scope, otherwise it does not!
        # https://github.com/python/cpython/issues/113939
        tb.tb_frame.f_locals  # noqa
    tb = tb.tb_next
