Update DDP for torch.distributed.run with gloo backend #3680

Merged · 35 commits · Jun 19, 2021
007902e
Update DDP for `torch.distributed.run`
glenn-jocher Jun 18, 2021
9bcb4ad
Add LOCAL_RANK
glenn-jocher Jun 18, 2021
b32bae0
remove opt.local_rank
glenn-jocher Jun 18, 2021
b467501
backend="gloo|nccl"
glenn-jocher Jun 18, 2021
c886538
print
glenn-jocher Jun 18, 2021
5d847dc
print
glenn-jocher Jun 18, 2021
26d0ecf
debug
glenn-jocher Jun 18, 2021
832ba4c
debug
glenn-jocher Jun 18, 2021
9a1bb01
os.getenv
glenn-jocher Jun 18, 2021
0e912df
gloo
glenn-jocher Jun 18, 2021
5f5e428
gloo
glenn-jocher Jun 18, 2021
e8493c6
gloo
glenn-jocher Jun 18, 2021
fb342fc
cleanup
glenn-jocher Jun 18, 2021
382ce4f
fix getenv
glenn-jocher Jun 18, 2021
b09b415
cleanup
glenn-jocher Jun 18, 2021
9c4ac05
cleanup destroy
glenn-jocher Jun 18, 2021
8ae9ea1
try nccl
glenn-jocher Jun 18, 2021
a18f933
merge master
glenn-jocher Jun 19, 2021
2435775
return opt
glenn-jocher Jun 19, 2021
56a4ab4
add --local_rank
glenn-jocher Jun 19, 2021
c4d839b
add timeout
glenn-jocher Jun 19, 2021
0584e7e
add init_method
glenn-jocher Jun 19, 2021
d917341
gloo
glenn-jocher Jun 19, 2021
6a1cc64
move destroy
glenn-jocher Jun 19, 2021
3581c76
move destroy
glenn-jocher Jun 19, 2021
5f5d122
move print(opt) under if RANK
glenn-jocher Jun 19, 2021
5451fc2
destroy only RANK 0
glenn-jocher Jun 19, 2021
9aa229e
move destroy inside train()
glenn-jocher Jun 19, 2021
94363ce
restore destroy outside train()
glenn-jocher Jun 19, 2021
9647379
update print(opt)
glenn-jocher Jun 19, 2021
cb8395d
merge master
glenn-jocher Jun 19, 2021
96686fd
cleanup
glenn-jocher Jun 19, 2021
446c610
nccl
glenn-jocher Jun 19, 2021
49bb0b7
gloo with 60 second timeout
glenn-jocher Jun 19, 2021
b5decde
update namespace printing
glenn-jocher Jun 19, 2021
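
As the commit history shows, the PR iterates between nccl and gloo before settling on gloo with a 60-second timeout, and moves rank handling from the `--local_rank` argparse flag to the environment variables that `torch.distributed.run` exports. A minimal sketch of the new rank plumbing, mirroring the constants added at the top of train.py (the launch command in the comment is illustrative, not part of this diff):

# assumed launch: python -m torch.distributed.run --nproc_per_node 2 train.py --device 0,1
import os

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index on this node, -1 = not launched via DDP
RANK = int(os.getenv('RANK', -1))              # global process index across all nodes
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of DDP processes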
detect.py (6 changes: 3 additions & 3 deletions)

@@ -8,8 +8,8 @@

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized

@@ -202,7 +202,7 @@ def parse_opt():


def main(opt):
print(opt)
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))
detect(**vars(opt))

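detect.py now prints its options as a single prefixed line rather than the raw Namespace repr. A self-contained sketch of the output format, using a no-op stand-in for `utils.general.colorstr` (the real helper only adds ANSI color codes) and hypothetical option values:

import argparse

def colorstr(s):  # stand-in for utils.general.colorstr, assumed here so the snippet runs on its own
    return s

opt = argparse.Namespace(weights='yolov5s.pt', imgsz=640, conf_thres=0.25)  # hypothetical values
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
# -> detect: weights=yolov5s.pt, imgsz=640, conf_thres=0.25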
models/export.py (2 changes: 1 addition & 1 deletion)

@@ -163,8 +163,8 @@ def parse_opt():


def main(opt):
print(opt)
set_logging()
print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
export(**vars(opt))


test.py (4 changes: 2 additions & 2 deletions)

@@ -51,7 +51,6 @@ def test(data,
device = next(model.parameters()).device # get model device

else: # called directly
set_logging()
device = select_device(device, batch_size=batch_size)

# Directories
@@ -323,7 +322,8 @@ def parse_opt():


def main(opt):
print(opt)
set_logging()
print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))

if opt.task in ('train', 'val', 'test'): # run normally
train.py (95 changes: 46 additions & 49 deletions)

@@ -37,15 +37,17 @@
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))


def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
):
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
opt.single_cls
save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls

# Directories
wdir = save_dir / 'weights'
@@ -69,13 +71,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Configure
plots = not opt.evolve # create plots
cuda = device.type != 'cpu'
init_seeds(2 + rank)
init_seeds(2 + RANK)
with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict

# Loggers
loggers = {'wandb': None, 'tb': None} # loggers dict
if rank in [-1, 0]:
if RANK in [-1, 0]:
# TensorBoard
if not opt.evolve:
prefix = colorstr('tensorboard: ')
@@ -99,7 +101,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Model
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
@@ -110,7 +112,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']
@@ -158,7 +160,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# plot_lr_scheduler(optimizer, scheduler, epochs)

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None
ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
@@ -194,28 +196,28 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples

# DP mode
if cuda and rank == -1 and torch.cuda.device_count() > 1:
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
if opt.sync_bn and cuda and RANK != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
logger.info('Using SyncBatchNorm()')

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size, workers=opt.workers,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
workers=opt.workers,
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

# Process 0
if rank in [-1, 0]:
if RANK in [-1, 0]:
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers,
workers=opt.workers,
pad=0.5, prefix=colorstr('val: '))[0]

if not opt.resume:
@@ -234,8 +236,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model.half().float() # pre-reduce anchor precision

# DDP mode
if cuda and rank != -1:
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
# nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))

@@ -269,27 +271,27 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Update image weights (optional)
if opt.image_weights:
# Generate indices
if rank in [-1, 0]:
if RANK in [-1, 0]:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Broadcast if DDP
if rank != -1:
indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
if RANK != -1:
indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
dist.broadcast(indices, 0)
if rank != 0:
if RANK != 0:
dataset.indices = indices.cpu().numpy()

# Update mosaic border
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
if rank != -1:
if RANK != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
if rank in [-1, 0]:
if RANK in [-1, 0]:
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
@@ -319,8 +321,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
with amp.autocast(enabled=cuda):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if RANK != -1:
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
if opt.quad:
loss *= 4.

@@ -336,7 +338,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ema.update(model)

# Print
if rank in [-1, 0]:
if RANK in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
@@ -362,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
scheduler.step()

# DDP process 0 or single-GPU
if rank in [-1, 0]:
if RANK in [-1, 0]:
# mAP
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
@@ -424,7 +426,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
if rank in [-1, 0]:
if RANK in [-1, 0]:
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png
@@ -457,8 +459,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run()
else:
dist.destroy_process_group()

torch.cuda.empty_cache()
return results

@@ -486,7 +487,6 @@ def parse_opt():
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--entity', default=None, help='W&B entity')
@@ -499,18 +499,15 @@ def parse_opt():
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args()

# Set DDP variables
opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
opt.global_rank = int(getattr(os.environ, 'RANK', -1))
return opt


def main(opt):
print(opt)
set_logging(opt.global_rank)
if opt.global_rank in [-1, 0]:
set_logging(RANK)
if RANK in [-1, 0]:
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_git_status()
check_requirements(exclude=['thop'])

@@ -519,11 +516,9 @@ def main(opt):
if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
apriori = opt.global_rank, opt.local_rank
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
'', ckpt, True, opt.total_batch_size, *apriori # reinstate
opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate
logger.info('Resuming training from %s' % ckpt)
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
@@ -536,19 +531,21 @@ def main(opt):
# DDP mode
opt.total_batch_size = opt.batch_size
device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
if LOCAL_RANK != -1:
from datetime import timedelta
assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // opt.world_size
opt.batch_size = opt.total_batch_size // WORLD_SIZE

# Train
logger.info(opt)
if not opt.evolve:
train(opt.hyp, opt, device)
if WORLD_SIZE > 1 and RANK == 0:
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]

# Evolve hyperparameters (optional)
else:
@@ -584,7 +581,7 @@ def main(opt):

with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
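Taken together, the train.py changes read LOCAL_RANK/RANK/WORLD_SIZE from the environment, initialize the process group with gloo and a 60-second timeout, and destroy the group once from rank 0 after training. A condensed sketch of that flow under those assumptions, with the surrounding YOLOv5 logic elided:

import os
from datetime import timedelta

import torch
import torch.distributed as dist

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

if LOCAL_RANK != -1:  # launched by torch.distributed.run
    assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUs for DDP command'
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device('cuda', LOCAL_RANK)
    # default env:// rendezvous; MASTER_ADDR/MASTER_PORT are also set by torch.distributed.run
    dist.init_process_group(backend='gloo', timeout=timedelta(seconds=60))

# ... train(hyp, opt, device) runs here; inside it the model is wrapped as
# DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) when RANK != -1 ...

if WORLD_SIZE > 1 and RANK == 0:  # tear down once, from the rank-0 process only
    print('Destroying process group... ', end='')
    dist.destroy_process_group()
    print('Done.')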
utils/datasets.py (4 changes: 2 additions & 2 deletions)

@@ -64,7 +64,7 @@ def exif_size(img):


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
prefix=prefix)

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
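create_dataloader loses its world_size parameter, so the worker-count formula drops the division by world_size: each DDP process now caps its dataloader workers by the machine's CPU count, the batch size, and the --workers flag. A small sketch of the new calculation (the helper name is illustrative; in utils/datasets.py this is inline):

import os

def dataloader_workers(batch_size, workers=8):
    # per-process worker budget after this PR (previously os.cpu_count() // world_size)
    return min(os.cpu_count(), batch_size if batch_size > 1 else 0, workers)

print(dataloader_workers(16))  # on an 8-CPU machine: min(8, 16, 8) -> 8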
utils/torch_utils.py (5 changes: 3 additions & 2 deletions)

@@ -13,6 +13,7 @@

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
@@ -30,10 +31,10 @@ def torch_distributed_zero_first(local_rank: int):
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
dist.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
dist.barrier()


def init_torch_seeds(seed=0):
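torch_distributed_zero_first only swaps the fully qualified torch.distributed calls for the dist alias, but its barrier pattern is the mechanism that lets rank 0 do one-time work (dataset checks, weight downloads) while the other ranks wait. A self-contained sketch of the pattern, assuming the process group has already been initialized as in train.py:

from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    if local_rank not in [-1, 0]:
        dist.barrier()  # non-zero ranks block here until rank 0 leaves the with-block
    yield
    if local_rank == 0:
        dist.barrier()  # rank 0 releases the waiting ranks once its work is done

# as used in train.py: only rank 0 downloads, the other ranks reuse the cached file
# with torch_distributed_zero_first(RANK):
#     weights = attempt_download(weights)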
utils/wandb_logging/wandb_utils.py (6 changes: 4 additions & 2 deletions)
@@ -1,5 +1,6 @@
"""Utilities and tools for tracking runs with Weights & Biases."""
import logging
import os
import sys
from contextlib import contextmanager
from pathlib import Path
@@ -18,6 +19,7 @@
except ImportError:
wandb = None

RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'


@@ -42,10 +44,10 @@ def get_run_info(run_path):


def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs
if RANK not in [-1, 0]: # For resuming DDP runs
entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api()
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')