Add Trainer.validate(…) method to run one validation epoch #4948

Merged: 45 commits, Mar 11, 2021
The diff below shows changes from 8 of the 45 commits.

Commits
edb3e83
Refactor Trainer in advance of implementing Trainer.validate
EliaCereda Dec 2, 2020
03d7994
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 2, 2020
5a54485
Add Trainer.validate(...) method to perform one evaluation epoch over…
EliaCereda Dec 2, 2020
e06775c
Rename methods in Trainer and Accelerator to reflect that they are us…
EliaCereda Dec 2, 2020
b4e409c
Update docs to mention the new Trainer.validate method and associated…
EliaCereda Dec 2, 2020
96e42ba
Add tests for Trainer.validate(…)
EliaCereda Dec 2, 2020
85b3c9f
Update CHANGELOG.md
EliaCereda Dec 2, 2020
39113dc
Merge branch 'master' into feature/trainer-validate-2
tchaton Dec 3, 2020
a6be0d8
Replace usages of Trainer.testing with Trainer.evaluating, should be …
EliaCereda Dec 4, 2020
a922d57
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 4, 2020
595f4e8
Clean up calls to LightningDataModule.setup()
EliaCereda Dec 8, 2020
0b09248
Update test_trainer_validate_loop.py to use BoringModel instead of Ev…
EliaCereda Dec 8, 2020
d691d79
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 8, 2020
52eaa70
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Dec 8, 2020
06b4419
Fix ShardedPlugin when evaluating
EliaCereda Dec 8, 2020
e6a8be9
Merge remote-tracking branch 'origin/feature/trainer-validate-1' into…
EliaCereda Dec 8, 2020
389940e
Add tests for Trainer.validate with ShardedPlugin
EliaCereda Dec 8, 2020
6d0a95a
Remove superfluous calls to LoggerConnector.set_stage in validate() a…
EliaCereda Dec 10, 2020
704b121
Update more docstrings to mention Trainer.validate
EliaCereda Dec 10, 2020
f6e0759
Merge branch 'release/1.2-dev' into feature/trainer-validate-1
tchaton Jan 11, 2021
90e59c7
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
EliaCereda Jan 26, 2021
45d7e0a
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Jan 26, 2021
12a85b3
Pass {fit,validate,test,predict} to setup()
carmocca Mar 7, 2021
d49ccd1
Fix doctest
carmocca Mar 7, 2021
23db135
stage: Optional[str] = None
carmocca Mar 7, 2021
84f5fdb
Trailing whitespace
carmocca Mar 7, 2021
188b9fe
Update docs and CHANGELOG
carmocca Mar 7, 2021
37473f0
Mention teardown
carmocca Mar 7, 2021
0a30abf
Self-review
carmocca Mar 7, 2021
0e9d69c
Address Borda's comments
carmocca Mar 7, 2021
04343ce
Merge branch 'deleteme-carmocca' into feature/trainer-validate-2
carmocca Mar 7, 2021
9758c7b
Fixing conflicts
carmocca Mar 7, 2021
18280df
Implement Trainer.validate
carmocca Mar 7, 2021
e582d58
Refactor
carmocca Mar 7, 2021
1a5b620
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
5b99ec0
flake8
carmocca Mar 8, 2021
9f4dce2
Refactor
carmocca Mar 8, 2021
088d4bc
Missing import
carmocca Mar 8, 2021
58fcca4
Fix test
carmocca Mar 8, 2021
babb73d
Same threshold
carmocca Mar 8, 2021
235dc27
Address tchaton's comments
carmocca Mar 8, 2021
73dd265
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
e423b98
Missing import
carmocca Mar 10, 2021
cdec83b
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 10, 2021
8fab50f
Apply suggestions from code review
carmocca Mar 10, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Pytorch Geometric` integration example with Lightning ([#4568](https://github.com/PyTorchLightning/pytorch-lightning/pull/4568))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set (
[#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707))


### Changed

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
15 changes: 14 additions & 1 deletion docs/source/trainer.rst
@@ -148,14 +148,27 @@ So you can run it like so:

------------

Validation
----------
You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be
useful if you want to collect new metrics from a model that has just been initialized
or that has already been trained.

.. code-block:: python

trainer.validate(val_dataloaders=val_dataloaders)
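
As an aside (not part of the diff above), a minimal end-to-end sketch of how the new method might be used; `ToyModel`, the random tensors, and the assumption that `validate` accepts the model as its first argument are illustrative only, inferred from the snippet above:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    # Purely illustrative module with a single linear layer.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', torch.nn.functional.cross_entropy(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

val_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=16)

trainer = pl.Trainer()
# Runs a single evaluation epoch over the validation set, with no training involved.
results = trainer.validate(ToyModel(), val_dataloaders=val_loader)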

------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

trainer.test(test_dataloader=test_dataloader)
trainer.test(test_dataloaders=test_dataloaders)

------------

8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -59,9 +59,9 @@ def barrier(self, name: Optional[str] = None):
def broadcast(self, obj, src=0):
return obj

def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
def train_or_evaluate(self):
if self.trainer.evaluating:
results = self.trainer.run_test_or_validate()
else:
results = self.trainer.train()
return results
@@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

def setup_optimizers(self, model):
if self.trainer.testing is True:
if self.trainer.evaluating:
return

optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -57,8 +57,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()
return results

def training_step(self, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -181,8 +181,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -275,8 +275,8 @@ def ddp_train(self, process_idx, model):
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -145,8 +145,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -174,8 +174,8 @@ def ddp_train(self, process_idx, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -157,8 +157,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -106,8 +106,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -62,8 +62,9 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

def training_step(self, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -111,8 +111,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(self.trainer.model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# Make sure all workers have finished training before returning to the user
hvd.join()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -129,8 +129,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# save weights at the end of training
self.__save_end_of_training_weights(model, trainer)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -28,11 +28,11 @@ class Callback(abc.ABC):
"""

def setup(self, trainer, pl_module, stage: str):
"""Called when fit or test begins"""
"""Called when fit, validate, or test begins"""
pass

def teardown(self, trainer, pl_module, stage: str):
"""Called when fit or test ends"""
"""Called when fit, validate, or test ends"""
pass
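
For illustration only (not part of the diff), a hypothetical callback sketch showing the stage argument that setup/teardown receive, which with this change also runs for a standalone validation pass; the class name and prints are placeholders:

from pytorch_lightning.callbacks import Callback

class StageLoggingCallback(Callback):
    # Purely illustrative: report which stage is beginning and ending.
    def setup(self, trainer, pl_module, stage: str):
        print(f"setup: stage={stage}")

    def teardown(self, trainer, pl_module, stage: str):
        print(f"teardown: stage={stage}")

# Usage sketch: Trainer(callbacks=[StageLoggingCallback()])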

def on_init_start(self, trainer):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state):
self.patience = checkpointed_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated
or self.last_global_step_saved == global_step # already saved at the last step
):
return
22 changes: 18 additions & 4 deletions pytorch_lightning/callbacks/progress.py
@@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """

# The main progress bar doesn't exist in trainer.validate(...)
has_main_bar = int(self.main_progress_bar is not None)

bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
@@ -341,19 +345,29 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar) # fill up remaining

self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
15 changes: 14 additions & 1 deletion pytorch_lightning/core/datamodule.py
@@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs):
if fn.__name__ == "setup":

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# If not provided, set call status of 'fit', 'validation', and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

if stage == "fit" or stage is None:
obj._has_setup_fit = True

if stage == "validation" or stage is None:
obj._has_setup_validation = True

if stage == "test" or stage is None:
obj._has_setup_test = True
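
A quick hypothetical sketch (again, not part of the diff): a datamodule whose setup() branches on the same stage values tracked above; the random tensors and the choice to also build the validation split for stage='fit' are assumptions:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

class ToyDataModule(LightningDataModule):
    # Purely illustrative; uses random tensors instead of real data.
    def setup(self, stage=None):
        full = TensorDataset(torch.randn(100, 32), torch.randint(0, 2, (100,)))
        if stage == "fit" or stage is None:
            self.train_set = full
        if stage == "validation" or stage == "fit" or stage is None:
            self.val_set = full
        if stage == "test" or stage is None:
            self.test_set = full

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=16)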

@@ -155,6 +158,7 @@ def __init__(
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_validation = False
self._has_setup_test = False

@property
@@ -230,6 +234,15 @@ def has_setup_fit(self):
"""
return self._has_setup_fit

@property
def has_setup_validation(self):
"""Return bool letting you know if datamodule.setup('validation') has been called or not.

Returns:
bool: True if datamodule.setup('validation') has been called. False by default.
"""
return self._has_setup_validation

@property
def has_setup_test(self):
"""Return bool letting you know if datamodule.setup('test') has been called or not.
8 changes: 4 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -26,12 +26,12 @@ class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str):
"""
Called at the beginning of fit and test.
Called at the beginning of fit (training + validation), validation, and test.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'

Example::

@@ -54,10 +54,10 @@ def setup(stage):

def teardown(self, stage: str):
"""
Called at the end of fit and test.
Called at the end of fit (training + validation), validation, and test.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'
"""

def on_fit_start(self):
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
model: The model to check the configuration.

"""
if not self.trainer.testing:
if not self.trainer.evaluating:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
# check test loop configuration
self.__verify_eval_loop_configuration(model, 'test')
# check evaluation loop configurations
self.__verify_eval_loop_configuration(model, self.trainer.evaluating)

def __verify_train_loop_configuration(self, model):
# -----------------------------------
@@ -265,7 +265,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):
def get_evaluate_epoch_results(self):
if not self.trainer.running_sanity_check:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self):

self.prepare_eval_loop_results()

# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
# log results of evaluation
if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate:
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} TEST RESULTS')
print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS')
pprint(results)
print('-' * 80)

6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/connectors/model_connector.py
@@ -36,7 +36,11 @@ def copy_trainer_model_properties(self, model):
m.use_ddp2 = self.trainer.use_ddp2
m.use_ddp = self.trainer.use_ddp
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
# Currently, the only users of m.testing appear to be DP and DDP,
# which use it to determine whether the model is currently inside
# the validation or test loop. For this reason it must check if
# trainer.evaluating is equal to "test" specifically.
m.testing = self.trainer.evaluating == 'test'
m.use_single_gpu = self.trainer.use_single_gpu
m.use_tpu = self.trainer.use_tpu
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank