added barrier #2245

Merged 6 commits on Jun 19, 2020

Changes from all commits
pytorch_lightning/trainer/trainer.py: 20 changes (16 additions, 4 deletions)
@@ -851,16 +851,18 @@ def fit(
         if self.is_function_implemented('on_fit_start'):
             model.on_fit_start()
 
-        self.setup('fit')
-        if self.is_function_implemented('setup'):
-            model.setup('fit')
-
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
         if self.can_prepare_data():
            model.prepare_data()
            self._is_data_prepared = True
 
+        self.barrier('fit_prepare_data')
+
+        self.setup('fit')
+        if self.is_function_implemented('setup'):
+            model.setup('fit')
+
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
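
Note on the reordering in this hunk: `can_prepare_data()` restricts `prepare_data()` to a single process per node (local_rank 0), so without a synchronization point the other DDP processes could reach `setup('fit')` before the data exists on disk. Moving `setup('fit')` below `prepare_data()` and inserting `barrier('fit_prepare_data')` makes every process wait until the download has finished. A minimal self-contained sketch of the same pattern with raw `torch.distributed` (the `download_dataset` helper, the gloo backend, and the paths are illustrative assumptions, not Lightning code):

```python
import os
import torch.distributed as dist
import torch.multiprocessing as mp


def download_dataset(root):
    # hypothetical stand-in for LightningModule.prepare_data()
    os.makedirs(root, exist_ok=True)
    with open(os.path.join(root, "data.txt"), "w") as f:
        f.write("dataset contents")


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        # only one process downloads, as can_prepare_data() arranges
        download_dataset("/tmp/barrier_demo")

    # every process blocks here until all ranks have arrived, so
    # non-zero ranks cannot touch the data before it is written
    dist.barrier()

    with open("/tmp/barrier_demo/data.txt") as f:
        print(f"rank {rank} reads: {f.read()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Without the `dist.barrier()` line, rank 1 would intermittently hit a `FileNotFoundError`, which is exactly the race this PR closes.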
@@ -1150,6 +1152,8 @@ def test(
         model_ref = self.model if model is None else model
         model_ref.setup('test')
 
+        self.barrier('test_setup')
+
         if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
             raise MisconfigurationException(
                 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')
@@ -1251,6 +1255,14 @@ def check_model_configuration(self, model: LightningModule):
             raise MisconfigurationException('You have defined `test_step()` but did not'
                                             ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')
 
+    def barrier(self, name):
+        if self.use_ddp or self.use_ddp2:
+            torch_distrib.barrier()
+
+        if self.on_tpu and XLA_AVAILABLE:
+            # wait for all processes to catch up
+            torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
+
 
 class _PatchDataLoader(object):
     r"""
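
For the TPU branch, `torch_xla.core.xla_model.rendezvous(tag)` blocks each process until every XLA process has entered a rendezvous with the same tag, so deriving the tag from the call site (`f'pl.Trainer.{name}'`) keeps distinct synchronization points from pairing up with each other. A free-standing restatement of the dispatch, assuming the `torch_distrib` alias and XLA availability check the module already relies on:

```python
import torch.distributed as torch_distrib

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def barrier(name, use_ddp=False, on_tpu=False):
    """Sketch of the backend dispatch performed by the new Trainer.barrier."""
    if use_ddp:
        # blocks until every rank in the default process group calls it;
        # requires torch.distributed.init_process_group to have run
        torch_distrib.barrier()

    if on_tpu and XLA_AVAILABLE:
        # blocks until every XLA process reaches a rendezvous with this tag
        xm.rendezvous(f'pl.Trainer.{name}')
```

With both flags false the call is a no-op, which matches the Trainer's behavior on CPU or single-GPU runs, where there is nothing to synchronize.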