From 50a9828679d075772a0875a5b2488fb9febb1082 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 30 Aug 2021 18:35:07 +0200 Subject: [PATCH] DDP `torch.jit.trace()` `--sync-bn` fix (#4615) * Remove assert * debug0 * trace=not opt.sync * sync to sync_bn fix * Cleanup --- train.py | 3 +-- utils/loggers/__init__.py | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 2fe38ef043d..36492edb8f0 100644 --- a/train.py +++ b/train.py @@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) - callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) + callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn) # end batch ------------------------------------------------------------------------------------------------ # Scheduler @@ -499,7 +499,6 @@ def main(opt): assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count' assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' assert not opt.evolve, '--evolve argument is not compatible with DDP training' - assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998' torch.cuda.set_device(LOCAL_RANK) device = torch.device('cuda', LOCAL_RANK) dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 775803abf06..0750be6c882 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -69,13 +69,14 @@ def on_pretrain_routine_end(self): if self.wandb: self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) - def on_train_batch_end(self, ni, model, imgs, targets, paths, plots): + def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn): # Callback runs on train batch end if plots: if ni == 0: - with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress jit trace warning - self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) + if not sync_bn: # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754 + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress jit trace warning + self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) if ni < 3: f = self.save_dir / f'train_batch{ni}.jpg' # filename Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()