Inconsistent recovery from CUDA OOMs #18853
And again with nightly, where all test cases fail:
Thu Apr 4 10:20:07 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Quadro GP100 Off | 00000000:81:00.0 Off | 0 |
| 32% 41C P0 35W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
| 1 Quadro GP100 Off | 00000000:82:00.0 Off | 0 |
| 28% 40C P0 34W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
=======================================================================
Activating pytorch nightly
=======================================================================
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.0.0.dev20190404
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.0.dev20190404
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-nightly 1.0.0.dev20190404 py3.7_cuda10.0.130_cudnn7.4.2_0 pytorch
Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 101, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 63, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
input = module(input)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
return F.linear(input, self.weight, self.bias)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 1.59 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 139, in <module>
main()
File "memtestcase.py", line 134, in main
run_trial(args)
File "memtestcase.py", line 113, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 68, in fwbw
loss.backward()
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 989.50 KiB cached)
Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
Backward with bs = 65536
FW/BW succeeded. Doubling BS
Step bs= 131072
Forward with bs = 131072
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 101, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 63, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
input = module(input)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
return F.threshold(input, self.threshold, self.value, self.inplace)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 915, in threshold
result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 981.56 MiB free; 1.59 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 139, in <module>
main()
File "memtestcase.py", line 134, in main
run_trial(args)
File "memtestcase.py", line 113, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 63, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
return replicate(module, device_ids)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 281, in replicate
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 207, in _broadcast_coalesced_reshape
tensor_copies = Broadcast.apply(device_group, *tensor_group)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 57.56 MiB free; 29.67 MiB cached) (malloc at /opt/conda/conda-bld/pytorch-nightly_1554354360325/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7efe8e6409c5 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16971 (0x7efe869da971 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17007 (0x7efe869db007 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x9e5 (0x7efe945b0935 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x4d (0x7efe92d959bd in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x1c1 (0x7efe87587ec1 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x594 (0x7efeb89d3c44 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x6f6 (0x7efeb89d47d6 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x5b975e (0x7efeb89d875e in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x129586 (0x7efeb8548586 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: THPFunction_apply(_object*, _object*) + 0x5e9 (0x7efeb87c1d69 in /private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #54: __libc_start_main + 0xe7 (0x7efec7330b97 in /lib/x86_64-linux-gnu/libc.so.6)
Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 101, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 63, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/container.py", line 97, in forward
input = module(input)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
return F.linear(input, self.weight, self.bias)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/functional.py", line 1400, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 1.59 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 139, in <module>
main()
File "memtestcase.py", line 134, in main
run_trial(args)
File "memtestcase.py", line 113, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 68, in fwbw
loss.backward()
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 9.56 MiB free; 989.50 KiB cached)
Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/nightly
Torch version: 1.0.0.dev20190404
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
Forward with bs = 8192
Traceback (most recent call last):
File "memtestcase.py", line 139, in <module>
main()
File "memtestcase.py", line 134, in main
run_trial(args)
File "memtestcase.py", line 107, in run_trial
raise rerr
File "memtestcase.py", line 101, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 63, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 364, in forward
self._sync_params()
File "/private/home/roller/.conda/envs/nightly/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 396, in _sync_params
param_data.set_(tensor)
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()
This is quite serious and should be dealt with before the next PyTorch release. It appears to be a leak in the autograd engine -- note that if you remove the …

Basically, if you're lucky and the OOM happens during forward it's recoverable. If it happens during backward() it leaks memory and is effectively unrecoverable. It's unclear whether this is a regression or also a bug in older versions of PyTorch. I would suspect that some tasks live on in some of the Engine state. (Are the ready_queues empty after an error?)
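For reference, the recovery pattern being exercised boils down to something like the sketch below. This is a minimal, hypothetical variant of the fwbw/run_trial logic from the script further down; the explicit del and the torch.cuda.empty_cache() call are not in the original script, they are just the usual suggestions for releasing the partially-built graph before retrying:

import torch

def fwbw_with_recovery(model, bs):
    # Attempt a full forward/backward. On CUDA OOM, drop references to the
    # half-built graph, return cached blocks to the allocator, and run a tiny
    # "catch-up" batch so a DDP worker stays in sync with its peers.
    try:
        fwbw(model, bs)   # fwbw as defined in memtestcase.py below
        return True
    except RuntimeError as err:
        if 'out of memory' not in str(err):
            raise
        del err                    # the traceback keeps activations alive
        model.zero_grad()          # drop any gradients allocated so far
        torch.cuda.empty_cache()   # optional: hand cached blocks back to CUDA
        fwbw(model, 2)             # the tiny catch-up batch
        return False

In the failing runs here, even the tiny catch-up batch OOMs, with the allocator still reporting ~15 GiB already allocated -- consistent with something holding on to the memory from the failed step.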
TODO: looks like I messed up the one-proc-per-gpu case.
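For completeness, a sketch of what the one-proc-per-gpu case presumably should look like (this is an assumption about the intended setup, not something from the test script): one process per GPU, each pinned to its own device, with the process group initialized from the environment:

import torch
import torch.distributed as td
import torch.nn.parallel as tp

def wrap_ddp_per_gpu(model, local_rank):
    # Hypothetical replacement for wrap_ddp_single: each launched process
    # owns exactly one GPU and joins the same NCCL process group.
    torch.cuda.set_device(local_rank)
    td.init_process_group(backend='nccl', init_method='env://')
    model = model.cuda(local_rank)
    return tp.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        broadcast_buffers=False,
    )

# launched with one process per GPU, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=2 memtestcase.py --mode=ddp_single
# where --local_rank is parsed from the command line and passed in above.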
Updated scripts below:
#!/bin/bash
nvidia-smi
echo "Collect env"
echo "------------------------------------------------------------"
python collect_env.py
echo
for mode in single dp ddp_single ddp_multi
do
echo "Running mode=$mode"
echo "------------------------------------------------------------"
python -u memtestcase.py --mode=$mode 2>&1
echo
done
#!/usr/bin/env python
#SBATCH --gres=gpu:2
#SBATCH --job-name=distributed_example
#SBATCH --partition=dev
#SBATCH --nodes=1
#SBATCH --time=0:10:00
#SBATCH --ntasks-per-node=1
#SBATCH --mem=8G
#SBATCH --cpus-per-task=10
import os
import argparse
import torch
import torch.nn as nn
import torch.distributed as td
import torch.nn.parallel as tp
START_BS = 8 * 1024
# these don't matter, just constants meant to be a "big" model
INPUT_SIZE = 8192
HID_SIZE = 4096
LAYERS = 8
OUT_CLASSES = 4
def wrap_dp(model):
return tp.DataParallel(model)
def wrap_ddp_multi(model):
td.init_process_group(
backend='nccl',
init_method='tcp://localhost:61337',
rank=0,
world_size=1
)
model = tp.DistributedDataParallel(
model,
device_ids=None,
broadcast_buffers=False,
)
return model
def wrap_ddp_single(model):
td.init_process_group(
backend='nccl',
init_method='tcp://localhost:61337',
rank=0,
world_size=1
)
model = tp.DistributedDataParallel(
model,
device_ids=[0],
broadcast_buffers=False,
)
return model
def create_model(args):
model = nn.Sequential(
nn.Linear(INPUT_SIZE, HID_SIZE),
nn.ReLU(),
)
for i in range(LAYERS):
model.add_module('hidd' + str(i), nn.Linear(HID_SIZE, HID_SIZE))
model.add_module('relu' + str(i), nn.ReLU())
model.add_module('output', nn.Linear(HID_SIZE, OUT_CLASSES))
return model
def fwbw(model, bs):
print(' Forward with bs = {:-6d}'.format(bs))
X = torch.randn(bs, INPUT_SIZE).cuda()
torch.cuda.synchronize()
yhat = model(X)
torch.cuda.synchronize()
loss = yhat.sum()
torch.cuda.synchronize()
print(' Backward with bs = {:-6d}'.format(bs))
loss.backward()
torch.cuda.synchronize()
model.zero_grad()
torch.cuda.synchronize()
def run_trial(args):
print('Conda PREFIX:', os.environ['CONDA_PREFIX'])
print('Torch version:', torch.version.__version__)
print('CUDA version:', torch.version.cuda)
model = create_model(args).cuda()
if args.mode == 'dp':
print('Wrapping in DataParallel')
model = wrap_dp(model)
elif args.mode == 'ddp_multi':
print('Wrapping in DistributedDataParallel (equiv to 1 proc per node)')
model = wrap_ddp_multi(model)
elif args.mode == 'ddp_single':
print('Using a single GPU in distributed (equiv to 1 proc per gpu)')
torch.cuda.set_device(0)
model = wrap_ddp_single(model)
elif args.mode == 'single':
print('Using a single GPU')
pass
else:
raise ValueError('--mode wrong')
bs = args.bs
times_oomed = 0
while times_oomed < args.ooms:
# continuously double the batch size until we OOM
try:
print('Step bs=', bs)
fwbw(model, bs)
print('FW/BW succeeded. Doubling BS')
bs *= 2
except RuntimeError as rerr:
if 'memory' not in str(rerr):
# not the exception we wanted
raise rerr
# okay, we found the memory error! Now try to run a NOOP pass
# for DDP nodes. Production example here:
# https://github.com/pytorch/fairseq/blob/3658fa329b8cb987d951b2e38ec86c44b9e1fea5/fairseq/trainer.py#L361-L368
times_oomed += 1
print('OOM #{}! Running through a tiny batch to catch up worker'.format(times_oomed))
fwbw(model, 2)
print('Succeeded on the oom batch.')
# start the doubling procedure again
bs = args.bs
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--mode', default='ddp', choices=('dp', 'ddp_multi', 'ddp_single', 'single'),
help='DataParallel, DistributedDataParallel, or single gpu'
)
parser.add_argument(
'--ooms', default=1, type=int,
help='Number of times to OOM'
)
parser.add_argument(
'--bs', default=START_BS, type=int,
help='Initial batch size',
)
args = parser.parse_args()
run_trial(args)
print('Test passed.')
if __name__ == '__main__':
    main()

Log (1.0.1post2):
devfair0237 parlai hybriddp » bash run.sh
Mon Apr 8 11:34:40 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Quadro GP100 Off | 00000000:81:00.0 Off | 0 |
| 26% 38C P0 34W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
| 1 Quadro GP100 Off | 00000000:82:00.0 Off | 0 |
| 26% 38C P0 34W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] pytorch-pretrained-bert==0.6.1
[pip] torch==1.0.1.post2
[pip] torchtext==0.3.1
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 py3.7_cuda10.0.130_cudnn7.4.2_2 pytorch
[conda] pytorch-pretrained-bert 0.6.1 <pip>
[conda] torchtext 0.3.1 <pip>
Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
return F.linear(input, self.weight, self.bias)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 25.56 MiB free; 607.00 KiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 83, in fwbw
loss.backward()
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 25.56 MiB free; 989.50 KiB cached)
Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
Backward with bs = 65536
FW/BW succeeded. Doubling BS
Step bs= 131072
Forward with bs = 131072
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
return F.threshold(input, self.threshold, self.value, self.inplace)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 840, in threshold
result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 997.56 MiB free; 607.00 KiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 142, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 147, in replicate
return replicate(module, device_ids)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 13, in replicate
param_copies = Broadcast.apply(devices, *params)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/cuda/comm.py", line 40, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 64.12 MiB (GPU 1; 15.90 GiB total capacity; 15.19 GiB already allocated; 9.56 MiB free; 911.50 KiB cached) (malloc at /opt/conda/conda-bld/pytorch_1549636813070/work/aten/src/THC/THCCachingAllocator.cpp:231)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fae1f0f1cf5 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1239bc1 (0x7fae233d3bc1 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #2: <unknown function> + 0x123a53a (0x7fae233d453a in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, at::TensorOptions const&) + 0x2d6 (0x7fae24a3edb6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x161 (0x7fae232f2311 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x179 (0x7fae1832a209 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x545 (0x7fae467c3725 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x7e6 (0x7fae467c4396 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x4f2be6 (0x7fae467c8be6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x111af6 (0x7fae463e7af6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: THPFunction_apply(_object*, _object*) + 0x5a1 (0x7fae465e3061 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #51: __libc_start_main + 0xe7 (0x7fae57b99b97 in /lib/x86_64-linux-gnu/libc.so.6)
Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 357, in forward
return self.module(*inputs[0], **kwargs[0])
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
return F.linear(input, self.weight, self.bias)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 19.56 MiB free; 607.00 KiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 83, in fwbw
loss.backward()
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 19.56 MiB free; 989.50 KiB cached)
Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/retry-20190211
Torch version: 1.0.1.post2
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
Backward with bs = 65536
FW/BW succeeded. Doubling BS
Step bs= 131072
Forward with bs = 131072
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 358, in forward
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 365, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 50, in forward
return F.threshold(input, self.threshold, self.value, self.inplace)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/functional.py", line 840, in threshold
result = _VF.threshold(input, threshold, value)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 991.56 MiB free; 607.00 KiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 355, in forward
self._sync_params()
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 384, in _sync_params
self.broadcast_bucket_size)
File "/private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/cuda/comm.py", line 40, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 128.12 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 89.56 MiB free; 992.00 KiB cached) (malloc at /opt/conda/conda-bld/pytorch_1549636813070/work/aten/src/THC/THCCachingAllocator.cpp:231)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f1e8e50acf5 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1239bc1 (0x7f1e927ecbc1 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #2: <unknown function> + 0x123a53a (0x7f1e927ed53a in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, at::TensorOptions const&) + 0x2d6 (0x7f1e93e57db6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAFloatType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x161 (0x7f1e9270b311 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, at::TensorOptions const&) const + 0x179 (0x7f1e87743209 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x545 (0x7f1eb5bdc725 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x7e6 (0x7f1eb5bdd396 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x4f2be6 (0x7f1eb5be1be6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x111af6 (0x7f1eb5800af6 in /private/home/roller/.conda/envs/retry-20190211/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #43: __libc_start_main + 0xe7 (0x7f1ec6fb2b97 in /lib/x86_64-linux-gnu/libc.so.6)
Log (nightly 2019-04-04):
Just ran with pytorch 1.1 given the release:
Fri May 3 07:32:41 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Quadro GP100 Off | 00000000:81:00.0 Off | 0 |
| 26% 34C P0 33W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
| 1 Quadro GP100 Off | 00000000:82:00.0 Off | 0 |
| 26% 34C P0 34W / 235W | 0MiB / 16278MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Collect env
------------------------------------------------------------
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchvision 0.2.2 py_3 pytorch
Running mode=single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Using a single GPU
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 92, in forward
return F.linear(input, self.weight, self.bias)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 1406, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 1.56 MiB free; 1.59 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 83, in fwbw
loss.backward()
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 15.25 GiB already allocated; 1.56 MiB free; 989.50 KiB cached)
Running mode=dp
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Wrapping in DataParallel
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
Backward with bs = 65536
FW/BW succeeded. Doubling BS
Step bs= 131072
Forward with bs = 131072
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 99, in forward
return F.relu(input, inplace=self.inplace)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 943, in relu
result = torch.relu(input)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.25 GiB already allocated; 973.56 MiB free; 1.59 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 151, in forward
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 156, in replicate
return replicate(module, device_ids)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 97, in replicate
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/replicate.py", line 80, in _broadcast_coalesced_reshape
tensor_copies = Broadcast.apply(devices, *tensors)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 66.00 MiB (GPU 1; 15.90 GiB total capacity; 15.13 GiB already allocated; 49.56 MiB free; 29.67 MiB cached) (malloc at /opt/conda/conda-bld/pytorch_1556653114079/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f07289d3dc5 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16ca7 (0x7f0720c9cca7 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17347 (0x7f0720c9d347 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x274 (0x7f072e850f14 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: at::CUDAType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x19b (0x7f072cfa04bb in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: torch::autograd::VariableType::empty(c10::ArrayRef<long>, c10::TensorOptions const&) const + 0x268 (0x7f07218debf8 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: at::native::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) + 0x687 (0x7f07293a1457 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #7: at::TypeDefault::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) const + 0x1b (0x7f07296217eb in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #8: torch::autograd::VariableType::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool) const + 0x312 (0x7f07217ca5c2 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #9: torch::cuda::broadcast(at::Tensor const&, c10::ArrayRef<long>) + 0x100 (0x7f0721d1f1d0 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #10: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x439 (0x7f0721d1f829 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #11: <unknown function> + 0x5a912e (0x7f0757e5912e in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x12d07a (0x7f07579dd07a in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #21: THPFunction_apply(_object*, _object*) + 0x691 (0x7f0757c5f891 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #57: __libc_start_main + 0xe7 (0x7f07667bdb97 in /lib/x86_64-linux-gnu/libc.so.6)
Running mode=ddp_single
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Using a single GPU in distributed (equiv to 1 proc per gpu)
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Backward with bs = 2
Succeeded on the oom batch.
Test passed.
Running mode=ddp_multi
------------------------------------------------------------
Conda PREFIX: /private/home/roller/.conda/envs/pytorch11
Torch version: 1.1.0
CUDA version: 10.0.130
Wrapping in DistributedDataParallel (equiv to 1 proc per node)
Step bs= 8192
Forward with bs = 8192
Backward with bs = 8192
FW/BW succeeded. Doubling BS
Step bs= 16384
Forward with bs = 16384
Backward with bs = 16384
FW/BW succeeded. Doubling BS
Step bs= 32768
Forward with bs = 32768
Backward with bs = 32768
FW/BW succeeded. Doubling BS
Step bs= 65536
Forward with bs = 65536
Backward with bs = 65536
FW/BW succeeded. Doubling BS
Step bs= 131072
Forward with bs = 131072
OOM #1! Running through a tiny batch to catch up worker
Forward with bs = 2
Traceback (most recent call last):
File "memtestcase.py", line 117, in run_trial
fwbw(model, bs)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 378, in forward
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 399, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 99, in forward
return F.relu(input, inplace=self.inplace)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/functional.py", line 943, in relu
result = torch.relu(input)
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.90 GiB total capacity; 14.88 GiB already allocated; 193.56 MiB free; 137.39 MiB cached)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "memtestcase.py", line 155, in <module>
main()
File "memtestcase.py", line 150, in main
run_trial(args)
File "memtestcase.py", line 129, in run_trial
fwbw(model, 2)
File "memtestcase.py", line 78, in fwbw
yhat = model(X)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 372, in forward
self._sync_params()
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 420, in _sync_params
self.broadcast_bucket_size)
File "/private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/cuda/comm.py", line 39, in broadcast_coalesced
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: CUDA out of memory. Tried to allocate 258.00 MiB (GPU 0; 15.90 GiB total capacity; 14.88 GiB already allocated; 193.56 MiB free; 137.33 MiB cached) (malloc at /opt/conda/conda-bld/pytorch_1556653114079/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f82d47c1dc5 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x16ca7 (0x7f82cca8aca7 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x17347 (0x7f82cca8b347 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: THCStorage_resize + 0x96 (0x7f82d8eac706 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #4: THCTensor_resizeNd + 0x4d8 (0x7f82d8ec2e68 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #5: THCudaTensor_catArray + 0x4f1 (0x7f82d917abe1 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #6: at::CUDAType::_th_cat(c10::ArrayRef<at::Tensor>, long) const + 0x968 (0x7f82d8d910a8 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2_gpu.so)
frame #7: at::native::cat(c10::ArrayRef<at::Tensor>, long) + 0xa4 (0x7f82d51ba2b4 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #8: at::TypeDefault::cat(c10::ArrayRef<at::Tensor>, long) const + 0x4f (0x7f82d54715bf in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libcaffe2.so)
frame #9: torch::autograd::VariableType::cat(c10::ArrayRef<at::Tensor>, long) const + 0x7f1 (0x7f82cd588ac1 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #10: <unknown function> + 0xbd0c86 (0x7f82cda72c86 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #11: torch::cuda::broadcast_coalesced(c10::ArrayRef<at::Tensor>, c10::ArrayRef<long>, unsigned long) + 0x412 (0x7f82cdb0d802 in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #12: <unknown function> + 0x5a912e (0x7f8303c4712e in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x12d07a (0x7f83037cb07a in /private/home/roller/.conda/envs/pytorch11/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #47: __libc_start_main + 0xe7 (0x7f83125a6b97 in /lib/x86_64-linux-gnu/libc.so.6)
@albanD @ezyang @colesbury do you have any thoughts on the autograd memory leak issue?
This may not be relevant, but we recently hit the CUDA OOM problem as well. We were able to resolve the issue by removing some "weak_script" decorators (#20563).
Super interesting. I’m not using weak script anywhere (and neither is the test case)
Any updates?
Check out #21344; it is not exactly the same issue, but I suspect the OOM error also has something to do with DDP hooks.
removing 1.1 milestone -- it didn't make it and doesn't seem blocking for 1.2.
Hey @stephenroller, I am not sure if this relates to the memory leak problem above, but the OOM recovery code you linked might hit a desync problem. Say we have 2 DDP processes (p0, p1), and each organizes its backward reduction into 2 buckets (b0, b1). If the OOM occurs on p0 after reducing b0, p1 would try to reduce b1, but p0 would retry both b0 and b1. This leads to allreducing two different buckets (b0 in p0 and b1 in p1) together, and hence the desync. To avoid this problem, the trainer will need to destroy and reconstruct the ProcessGroup and DDP objects on all processes.
Would it be possible for you to write a basic code snippet on how to implement such a routine?
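(Editorial note: a minimal sketch of what such a reset routine could look like, under the assumption that every rank detects the failure and calls it at the same point; the helper name, the NCCL backend, and the env:// rendezvous are illustrative choices, not an official recipe.)

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def rebuild_ddp_after_oom(module, rank, world_size):
    # Tear down the current process group so that no half-finished bucket
    # allreduces are left in flight on any rank.
    dist.destroy_process_group()

    # Return cached blocks so the retried batch has the best chance to fit.
    torch.cuda.empty_cache()

    # Every rank must reach this point and re-initialize together; the
    # OOM-ing rank has to signal the others out-of-band to get here.
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=rank, world_size=world_size)

    # Re-wrap the raw module in a fresh DDP instance.
    return DDP(module, device_ids=[rank])
```

Whether this is cheap enough to do on every OOM depends on the setup; re-initializing NCCL is not free.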
I was able to reproduce the failure without autograd by modifying the script slightly: I removed the backward pass and ran the forward under `@torch.no_grad()`:

    @torch.no_grad()
    def fwbw(model, bs):
        print(' Forward with bs = {:-6d}'.format(bs))
        X = torch.randn(bs, INPUT_SIZE, device='cuda')
        torch.cuda.synchronize()
        yhat = model(X)
        torch.cuda.synchronize()
        loss = yhat.sum()
        torch.cuda.synchronize()

Without autograd, this uses much less CUDA memory for the same batch size. So, for the OOM batch to fail to allocate, we need to increase the batch size in proportion to the failed batch size. If I use
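(Editorial note: to quantify "much less CUDA memory", the peak allocation of each variant can be checked with the memory stats API available in newer PyTorch releases; a sketch, with `fwbw`, `model`, and `bs` as in the test script above.)

```python
import torch

def report_peak(fn, *args):
    # Reset the high-water mark, run the candidate forward/backward,
    # and report the peak memory the call allocated.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    peak_mib = torch.cuda.max_memory_allocated() / 2 ** 20
    print('peak allocated: {:.1f} MiB'.format(peak_mib))

# e.g. report_peak(fwbw, model, 32768) for both the autograd and no-grad variants
```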
This appears to be the same underlying issue as #27600; the memory for the tensors isn't freed until the exception object has gone out of scope. I was able to get the recovery batch to succeed by moving it outside of the `except` block:

    try:
        fwbw(model, bs)
        bs *= 2
        oom = False
    except RuntimeError as rerr:
        if 'memory' not in str(rerr):
            # not the exception we wanted
            raise rerr
        oom = True

    if oom:
        # do recovery batch
@mrshenli could this be related to the desync issue you mentioned above?
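(Editorial note: spelled out as a complete loop, the workaround looks roughly like the sketch below; `fwbw`, the batch-size doubling, and the tiny recovery batch follow the gist's test script, but the function itself is illustrative rather than the script's verbatim code.)

```python
def run_oom_trial(model, start_bs=8192, max_ooms=1):
    """Double the batch size until an OOM, then recover with a tiny batch.

    The recovery forward/backward runs *after* the except block, once the
    exception object (and the tensors kept alive through its traceback)
    has gone out of scope.
    """
    bs = start_bs
    ooms = 0
    while ooms < max_ooms:
        oom = False
        try:
            fwbw(model, bs)
            bs *= 2
        except RuntimeError as rerr:
            if 'memory' not in str(rerr):
                raise  # not the exception we wanted
            oom = True
        if oom:
            ooms += 1
            print('OOM! Running through a tiny batch to catch up worker')
            fwbw(model, 2)  # dummy batch keeps DDP workers in sync
    print('Test passed.')
```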
Oh wow, that’s a great workaround @peterbell10, thanks for documenting!
cc @pritamdamania87, who introduced this error message in #27940. Do you know why it is triggered here?
@pritamdamania87, @mrshenli any thoughts as to why the graph task might be expiring here?
I ran the test again with a newer build of pytorch and
@peterbell10 Yes, this looks resolved to our satisfaction! Perhaps the only last thing is to maybe think about a good place to add some documentation about this.
@ezyang We're also experiencing this issue (DDP doesn't sync well once an OOM recovery has happened, because some variables of the autograd graph never had a backward pass). Should we do something similar to what is done in https://github.com/facebookresearch/ParlAI/blob/e9b95d8be32c7b08f487423441de68d05709a615/parlai/core/torch_generator_agent.py, which re-passes a dummy batch to trigger the backward sync? Is there a way to manually tell DDP that some variables should be marked "as if the backward pass happened on them"? Then we could just call
Not sure; I haven't dug into the guts of DDP enough to know if there's a way to sync using manually set gradients. The "standard" way is that DDP hooks into backward(), which makes that difficult.
It looks like DDP.join might serve the purpose? Please ping on this thread if you try that and it works.
Or maybe DDP._match_all_reduce_for_bwd_pass
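(Editorial note: the dummy-batch idea asked about above, roughly sketched; `dummy_X` is an assumed tiny pre-built batch, and scaling the loss by zero is one common way to fire DDP's backward hooks without perturbing the gradients. This is not the ParlAI implementation referenced earlier.)

```python
import torch

def fwbw_with_oom_fallback(ddp_model, optimizer, X, dummy_X):
    """Forward/backward that falls back to a zero-weighted dummy batch on OOM,
    so this rank still participates in DDP's gradient allreduce."""
    oom = False
    try:
        loss = ddp_model(X).sum()
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        oom = True
    if oom:
        # The exception is out of scope here, so the failed batch's tensors
        # can actually be freed before retrying.
        torch.cuda.empty_cache()
        loss = ddp_model(dummy_X).sum() * 0.0
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Note this only covers an OOM during the forward pass; as discussed above, an OOM in the middle of backward can still desync the bucket allreduces.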
Note that moving the OOM recovery logic outside of the `except` block is not the only option: you can also free the locals held alive by the exception's traceback with `traceback.clear_frames`:

    import traceback
    ...
    try:
        fwbw(model, bs)
        bs *= 2
    except RuntimeError as rerr:
        if 'memory' not in str(rerr):
            # not the exception we wanted
            raise rerr
        # Got OOM.
        traceback.clear_frames(rerr.__traceback__)  # free locals
        # do recovery batch

Note, one thing to be careful about (and this applies no matter whether you handle the OOM recovery inside or outside the `except` block): it can be tricky to free the references to any Torch tensors on the exception stack. Just calling `traceback.clear_frames` may not be enough; the loop below additionally touches `f_locals`, which works around a CPython quirk where the cleared references otherwise do not actually go out of scope:

    tb = rerr.__traceback__
    while tb:
        try:
            tb.tb_frame.clear()
        except RuntimeError:
            pass  # can happen if the frame is still executing, e.g. the current frame
        else:
            # Accessing f_locals triggers the refs to actually go out of scope,
            # otherwise they do not!
            # https://github.com/python/cpython/issues/113939
            tb.tb_frame.f_locals  # noqa
        tb = tb.tb_next
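(Editorial note: a small helper consolidating the above, as a sketch; the name `release_oom_exception` is illustrative, and the trailing `gc.collect()` / `torch.cuda.empty_cache()` calls are optional extras.)

```python
import gc

import torch

def release_oom_exception(exc):
    """Best-effort release of tensors kept alive by an exception's traceback."""
    tb = exc.__traceback__
    while tb is not None:
        try:
            tb.tb_frame.clear()
        except RuntimeError:
            pass  # the frame may still be executing, e.g. the current frame
        else:
            tb.tb_frame.f_locals  # force the cleared refs to actually drop
        tb = tb.tb_next
    # Optional extras: break any lingering reference cycles and hand
    # cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()
```

It would be called from inside the `except RuntimeError as rerr:` handler, before running the recovery batch.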
🐛 Bug
Editorial note: Make sure that you are not holding on to tensors via an exception object (which contains the stack trace and will therefore retain tensors). Either avoid bringing the exception object into scope, or do the error recovery outside of the catch block.
Catching a RuntimeError on a CUDA OOM should allow one to gracefully recover, for example by lowering the batch size. This is particularly important when using DistributedDataParallel, where workers must sync on backward, so it's important that we be able to perform a "dummy batch" after an OOM in order to stay in sync with other workers.
Observed behavior during a CUDA out-of-memory event is inconsistent across non-distributed, DataParallel, and DistributedDataParallel modes. Expected behavior is that all modes should be able to recover easily.
To Reproduce
Test case and logs available here:
https://gist.github.com/stephenroller/bd2cd644e7c117c1ec8192639ecf30b6
Steps to reproduce the behavior: grab `memtestcase.py` and `run.sh` from the gist above, run `run.sh`, and observe which test cases pass.
Logs from several environments on the FAIR cluster:
Expected behavior
All test cases pass. At the very least, test cases should produce consistent results across all `--mode`s.
Environment
Here's the environment from the third log (@stephenroller's) (pytorch 1.0.1.post2, cuda 10.0.130):
cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @aazzolini @xush6528 @osalpekar