Skip to content

Commit

Permalink
Adding world_size
Browse files Browse the repository at this point in the history
Reduce calls to torch.distributed. For use in create_dataloader.
  • Loading branch information
NanoCode012 authored Jul 14, 2020
1 parent e742dd9 commit d738487
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device):
loss, loss_items = compute_loss(pred, targets.to(device), model)
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
if local_rank != -1:
loss *= dist.get_world_size()
loss *= opt.world_size
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
Expand Down Expand Up @@ -451,15 +451,17 @@ def train(hyp, tb_writer, opt, device):
opt.total_batch_size = opt.batch_size
if device.type == 'cpu':
mixed_precision = False
opt.world_size = 1
elif opt.local_rank != -1:
# DDP mode
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device("cuda", opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

assert opt.batch_size % dist.get_world_size() == 0
opt.batch_size = opt.total_batch_size // dist.get_world_size()

opt.world_size = dist.get_world_size()
assert opt.batch_size % opt.world_size == 0
opt.batch_size = opt.total_batch_size // opt.world_size
print(opt)

# Train
Expand Down

0 comments on commit d738487

Please sign in to comment.