
torchvision classification train.py script fails with DistributedDataParallel and --apex #1119

Closed
andravin opened this issue Jul 15, 2019 · 8 comments

andravin commented Jul 15, 2019

Environment:

torch 1.1, CUDA 10.0, cuDNN 7.5, torchvision 0.3, apex built from master (commit 574fe2449cbe6ae4c8af53c6ecb1b5fc13877234)

Summary:

The torchvision references train.py script fails when used with DistributedDataParallel and --apex. The error indicates that "the parallel wrappers should only be applied to the model(s) AFTER the model(s) have been returned from amp.initialize".

Commandline:

OMP_NUM_THREADS=1
WORKERS=16
CUDA_PATH=/usr/local/cuda

COMMENT="benchmark-apex"
MODEL=mobilenet_v2

DATADIR=/data/imagenet

NAME=${MODEL}${COMMENT}
LOGDIR=/data/log/${NAME}
MODELDIR=/data/trained-models/${NAME}
mkdir -p $LOGDIR

python3 -m torch.distributed.launch --use_env --nproc_per_node=8\
        train.py --model $MODEL\
        --epochs 300\
        --lr 0.045\
        --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98\
        --apex\
        --data-path=${DATADIR}\
        --output-dir=${MODELDIR}\
        --batch-size=32\
        --print-freq=100\
        --workers=$WORKERS |& tee -a $LOGDIR/train-${NAME}.log

Output:

$ source train.bash                                                                                                                                                                                                          
| distributed init (rank 5): env://
| distributed init (rank 4): env://
| distributed init (rank 1): env://
| distributed init (rank 6): env://
| distributed init (rank 7): env://
| distributed init (rank 2): env://
| distributed init (rank 3): env://
| distributed init (rank 0): env://
Namespace(apex=True, apex_opt_level='O1', batch_size=32, cache_dataset=False, data_path='/data/imagenet', device='cuda', dist_backend='nccl', dist_url='env://', distributed=True, epochs=300, gpu=0, lr=0.045, lr_gamma=0.98, lr_step_size=1, model='mobilenet_v2', momentum=0.9, 
output_dir='/data/trained-models/mobilenet_v2benchmark-apex', pretrained=False, print_freq=100, rank=0, resume='', start_epoch=0, sync_bn=False, test_only=False, weight_decay=4e-05, workers=16, world_size=8)
Loading data
Loading training data
Took 4.644477128982544
Loading validation data
Creating data loaders
Creating model
Traceback (most recent call last):
  File "train.py", line 305, in <module>
Traceback (most recent call last):
  File "train.py", line 305, in <module>
    main(args)
  File "train.py", line 193, in main
    opt_level=args.apex_opt_level
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/frontend.py", line 357, in initialize
    main(args)
  File "train.py", line 193, in main
    opt_level=args.apex_opt_level
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/frontend.py", line 357, in initialize
    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 174, in _initialize
    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 174, in _initialize
    check_models(models)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 72, in check_models
    "Parallel wrappers should only be applied to the model(s) AFTER \n"
    check_models(models)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 72, in check_models
RuntimeError: Incoming model is an instance of torch.nn.parallel.DistributedDataParallel. Parallel wrappers should only be applied to the model(s) AFTER
the model(s) have been returned from amp.initialize.
    "Parallel wrappers should only be applied to the model(s) AFTER \n"
RuntimeError: Incoming model is an instance of torch.nn.parallel.DistributedDataParallel. Parallel wrappers should only be applied to the model(s) AFTER
the model(s) have been returned from amp.initialize.
Traceback (most recent call last):
  File "train.py", line 305, in <module>
    main(args)
  File "train.py", line 193, in main
    opt_level=args.apex_opt_level
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/frontend.py", line 357, in initialize
    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 174, in _initialize
    check_models(models)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/apex/amp/_initialize.py", line 72, in check_models
    "Parallel wrappers should only be applied to the model(s) AFTER \n"
[...]

Fix:

Simply moving the torch.nn.parallel.DistributedDataParallel call down a few lines in the script, so that it happens after the amp.initialize call, seems to fix the issue, though I have not yet tested it thoroughly with different combinations of commandline arguments.

andravin changed the title from "torchvision classification train.py and fails with DistributedDataParallel and --apex" to "torchvision classification train.py script fails with DistributedDataParallel and --apex" on Jul 15, 2019
@andravin (Author)

Here is the relevant section of the train.py script:

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level)
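
Reordered as described in the fix above, that section would look roughly like this (a sketch of the intended change, not the exact patch):

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # amp.initialize must see the bare model and optimizer first
    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # the parallel wrapper is applied only after amp.initialize has returned the model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module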


fmassa commented Jul 15, 2019

@vinhngx given that you originally added support for APEX mixed precision, can you have a look?

I have no experience with APEX and I'm busy with other things right now, so I won't have a chance to look at this.


vinhngx commented Jul 16, 2019

Glad someone spotted this :)
I can confirm that this is a step in the right direction. According to https://github.com/NVIDIA/apex/tree/master/examples/imagenet:

To use DDP with apex.amp, the only gotcha is that

model, optimizer = amp.initialize(model, optimizer, flags...)

must precede

model = DDP(model)

In addition:

With both Apex DDP and Torch DDP, you must also call torch.cuda.set_device(args.local_rank) within each process prior to initializing your model or any other tensors. 

Therefore, we also need to add torch.cuda.set_device(args.gpu) to the apex initialization, and move the apex initialization code to after utils.init_distributed_mode(args):

    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                               "to enable mixed-precision training.")
        if args.distributed:
            torch.cuda.set_device(args.gpu)

I've created a PR here: #1124
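
For reference, a minimal sketch of the ordering the Apex README describes, outside of train.py (the local_rank handling here is illustrative, assuming launch via torch.distributed.launch --use_env, and is not taken from the script):

    import os

    import torch
    import torchvision
    from apex import amp

    local_rank = int(os.environ['LOCAL_RANK'])   # set by torch.distributed.launch --use_env
    torch.cuda.set_device(local_rank)            # before constructing the model or any other tensors
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    model = torchvision.models.mobilenet_v2().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.045, momentum=0.9, weight_decay=4e-05)

    # 1) amp.initialize on the bare model and optimizer...
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    # 2) ...then wrap with DistributedDataParallel
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])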

@andravin (Author)

@vinhngx torch.cuda.set_device(args.gpu) appears to be already called in init_distributed_mode.


vinhngx commented Jul 17, 2019

> @vinhngx torch.cuda.set_device(args.gpu) appears to be already called in init_distributed_mode.

Indeed it is. Thanks for pointing this out. So we don't need to do it in the Apex initialization code.
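
For context, the relevant part of utils.init_distributed_mode in the torchvision references looks roughly like this (a paraphrase, not a verbatim copy):

    def init_distributed_mode(args):
        # rank / world_size / gpu come from the environment variables set by
        # torch.distributed.launch --use_env (RANK, WORLD_SIZE, LOCAL_RANK)
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])

        torch.cuda.set_device(args.gpu)   # already handled here, before the model is created
        print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
        torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.rank)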


fmassa commented Jul 19, 2019

@vinhngx One question: is this a change in behavior in APEX, or was it a bug from the beginning?


vinhngx commented Jul 23, 2019

@fmassa I suppose it was a bug from the beginning.


fmassa commented Jul 23, 2019

Fixed in #1124

fmassa closed this as completed Jul 23, 2019