diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0f76c072291f4..35027cf3c1ac6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -851,16 +851,18 @@ def fit( if self.is_function_implemented('on_fit_start'): model.on_fit_start() - self.setup('fit') - if self.is_function_implemented('setup'): - model.setup('fit') - # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): model.prepare_data() self._is_data_prepared = True + self.barrier('fit_prepare_data') + + self.setup('fit') + if self.is_function_implemented('setup'): + model.setup('fit') + # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): @@ -1150,6 +1152,8 @@ def test( model_ref = self.model if model is None else model model_ref.setup('test') + self.barrier('test_setup') + if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.') @@ -1251,6 +1255,14 @@ def check_model_configuration(self, model: LightningModule): raise MisconfigurationException('You have defined `test_step()` but did not' ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.') + def barrier(self, name): + if self.use_ddp or self.use_ddp2: + torch_distrib.barrier() + + if self.on_tpu and XLA_AVAILABLE: + # wait for all processes to catch up + torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}') + class _PatchDataLoader(object): r"""