diff --git a/CHANGELOG.md b/CHANGELOG.md index 046e07cc55736..73c16dbd5e2aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) + + - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 17cfc7eccbc20..6edf896ada01c 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -151,6 +151,19 @@ So you can run it like so: ------------ +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or after it has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -158,7 +171,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloaders=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index c382e67b21a64..74e57e2b5642e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm: def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.main_progress_bar is not None bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): @@ -479,8 +482,10 @@ def print( def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - def _update_bar(self, bar): + def _update_bar(self, bar: Optional[tqdm]) -> None: """ Updates the bar by the refresh rate without overshooting. 
""" + if bar is None: + return if bar.total is not None: delta = min(self.refresh_rate, bar.total - bar.n) else: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 1bf38048ee159..8c539b5ff478d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -22,7 +23,7 @@ class ConfigValidator(object): def __init__(self, trainer): self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule): + def verify_loop_configurations(self, model: LightningModule) -> None: r""" Checks that the model is configured correctly before the run is started. @@ -30,10 +31,16 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if self.trainer.training: + if self.trainer.state == TrainerState.FITTING: self.__verify_train_loop_configuration(model) - elif self.trainer.evaluating: - self.__verify_eval_loop_configuration(model) + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TUNING: + self.__verify_train_loop_configuration(model) + elif self.trainer.state == TrainerState.VALIDATING: + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TESTING: + self.__verify_eval_loop_configuration(model, 'test') + # TODO: add predict def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model): ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' 
) - def __verify_eval_loop_configuration(self, model): - stage = "val" if self.trainer.validating else "test" - + def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: loader_name = f'{stage}_dataloader' - step_name = f'{stage}_step' + step_name = 'validation_step' if stage == 'val' else 'test_step' has_loader = is_overridden(loader_name, model) has_step = is_overridden(step_name, model) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9e08cf031175f..d787f796f3d88 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -93,10 +93,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa def attach_dataloaders( self, model, - train_dataloader=None, - val_dataloaders=None, - test_dataloaders=None, - predict_dataloaders=None, + train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, ): # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations @@ -112,7 +112,7 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None: + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 33a2326c518d5..b1f188ab047fe 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -58,7 +58,7 @@ class RunningStage(LightningEnum): """ TRAINING = 'train' SANITY_CHECKING = 'sanity_check' - VALIDATING = 'validation' + VALIDATING = 'validate' TESTING = 'test' PREDICTING = 'predict' TUNING = 'tune' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ff8be336ee57a..c3039d24aadc0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model): self._running_stage = stage + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Perform one evaluation epoch over the validation set. + + Args: + model: The model to validate. + + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. + + verbose: If True, prints the validation results. + + datamodule: A instance of :class:`LightningDataModule`. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. 
+ If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. + """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + + self.state = TrainerState.VALIDATING + self.validating = True + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' + ) + + model_provided = model is not None + model = model or self.lightning_module + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) + + if not model_provided: + self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + + # run validate + results = self.fit(model) + + assert self.state.stopped + self.validating = False + + return results + def test( self, model: Optional[LightningModule] = None, @@ -833,17 +896,19 @@ def test( fit to make sure you never run on your test set until you want to. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. Default to ``best``. - datamodule: A instance of :class:`LightningDataModule`. - model: The model to test. test_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying test samples. + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. + verbose: If True, prints the test results. + datamodule: A instance of :class:`LightningDataModule`. + Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. 
""" @@ -858,7 +923,7 @@ def test( # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' ) model_provided = model is not None @@ -866,22 +931,25 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) - results = ( - self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else - self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) - ) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + + if not model_provided: + self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + + # run test + results = self.fit(model) assert self.state.stopped self.testing = False return results - def __evaluate_using_weights( + def __load_ckpt_weights( self, model, ckpt_path: Optional[str] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None - ): + ) -> Optional[str]: # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( @@ -894,42 +962,18 @@ def __evaluate_using_weights( if ckpt_path == 'best': ckpt_path = self.checkpoint_callback.best_model_path - if len(ckpt_path) == 0: - rank_zero_warn( - f'`.test()` found no path for the best weights, {ckpt_path}. Please' - ' specify a path for a checkpoint `.test(ckpt_path=PATH)`' + if not ckpt_path: + fn = self.state.value + raise MisconfigurationException( + f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' + ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' ) - return {} self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) - - # attach dataloaders - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - if self.validating: - self.validated_ckpt_path = ckpt_path - else: - self.tested_ckpt_path = ckpt_path - - # run test - results = self.fit(model) - - return results - - def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - # run test - # sets up testing so we short circuit to eval - results = self.fit(model) - - return results + return ckpt_path def predict( self, @@ -970,15 +1014,11 @@ def predict( 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
) - if datamodule is not None: - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - self.model = model results = self.fit(model) assert self.state.stopped diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py new file mode 100644 index 0000000000000..6962af7249d1b --- /dev/null +++ b/tests/accelerators/test_common.py @@ -0,0 +1,44 @@ +import pytest +import torch + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("trainer_kwargs", ( + pytest.param({"gpus": 1}, marks=RunIf(min_gpus=1)), + pytest.param({"accelerator": "dp", "gpus": 2}, marks=RunIf(min_gpus=2)), + pytest.param({"accelerator": "ddp_spawn", "gpus": 2}, marks=RunIf(min_gpus=2)), +)) +def test_evaluate(tmpdir, trainer_kwargs): + tutils.set_random_master_port() + + dm = ClassifDataModule() + model = CustomClassificationModelDP() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + deterministic=True, + **trainer_kwargs + ) + + result = trainer.fit(model, datamodule=dm) + assert result + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + old_weights = model.layer_0.weight.clone().detach().cpu() + + result = trainer.validate(datamodule=dm) + assert result[0]['val_acc'] > 0.55 + + result = trainer.test(datamodule=dm) + assert result[0]['test_acc'] > 0.55 + + # make sure weights didn't change + new_weights = model.layer_0.weight.clone().detach().cpu() + torch.testing.assert_allclose(old_weights, new_weights) diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 52f585409e865..6b84e1a70ae58 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -25,8 +25,6 @@ from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -PRETEND_N_OF_GPUS = 16 - class CustomClassificationModelDP(ClassificationModel): @@ -96,36 +94,6 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') -@RunIf(min_gpus=2) -def test_dp_test(tmpdir): - tutils.set_random_master_port() - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='dp', - ) - trainer.fit(model, datamodule=dm) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test(datamodule=dm) - assert 'test_acc' in results[0] - - old_weights = model.layer_0.weight.clone().detach().cpu() - - results = trainer.test(model, datamodule=dm) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.layer_0.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - class ReductionTestModel(BoringModel): def train_dataloader(self): diff --git 
a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 2426348f770bf..626eb59dffb9c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,8 +19,8 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system_fit(_, tmpdir): - """Test the callback system for fit.""" +def test_trainer_callback_hook_system_fit(_, tmpdir): + """Test the callback hook system for fit.""" model = BoringModel() callback_mock = MagicMock() @@ -97,8 +97,8 @@ def test_trainer_callback_system_fit(_, tmpdir): ] -def test_trainer_callback_system_test(tmpdir): - """Test the callback system for test.""" +def test_trainer_callback_hook_system_test(tmpdir): + """Test the callback hook system for test.""" model = BoringModel() callback_mock = MagicMock() @@ -130,6 +130,42 @@ def test_trainer_callback_system_test(tmpdir): ] +def test_trainer_callback_hook_system_validate(tmpdir): + """Test the callback hook system for validate.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_val_batches=2, + progress_bar_refresh_rate=0, + ) + + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validate'), + call.on_before_accelerator_backend_setup(trainer, model), + call.on_validation_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_batch_start(trainer, model, ANY, 1, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), + call.on_validation_epoch_end(trainer, model), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.teardown(trainer, model, 'validate'), + ] + + +# TODO: add callback tests for predict and tune + + def test_callbacks_configured_in_model(tmpdir): """ Test the callback system with callbacks added through the model hook. 
""" @@ -166,22 +202,29 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback): # .fit() trainer_options.update(callbacks=[trainer_callback_mock]) trainer = Trainer(**trainer_options) + assert trainer_callback_mock in trainer.callbacks assert model_callback_mock not in trainer.callbacks trainer.fit(model) + assert model_callback_mock in trainer.callbacks assert trainer.callbacks[-1] == model_callback_mock assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) # .test() - model_callback_mock.reset_mock() - trainer_callback_mock.reset_mock() - trainer_options.update(callbacks=[trainer_callback_mock]) - trainer = Trainer(**trainer_options) - trainer.test(model) - assert model_callback_mock in trainer.callbacks - assert trainer.callbacks[-1] == model_callback_mock - assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) + for fn in ("test", "validate"): + model_callback_mock.reset_mock() + trainer_callback_mock.reset_mock() + + trainer_options.update(callbacks=[trainer_callback_mock]) + trainer = Trainer(**trainer_options) + + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + assert model_callback_mock in trainer.callbacks + assert trainer.callbacks[-1] == model_callback_mock + assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) def test_configure_callbacks_hook_multiple_calls(tmpdir): @@ -208,10 +251,13 @@ def configure_callbacks(self): callbacks_after_fit = trainer.callbacks.copy() assert callbacks_after_fit == callbacks_before_fit + [model_callback_mock] - trainer.test(model) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + for fn in ("test", "validate"): + trainer_fn = getattr(trainer, fn) + trainer_fn(model) + + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit - trainer.test(ckpt_path=None) - callbacks_after_test = trainer.callbacks.copy() - assert callbacks_after_test == callbacks_after_fit + trainer_fn(ckpt_path=None) + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 67ea5a00cfda3..76f1e4cb0570f 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -90,7 +90,6 @@ def test_progress_bar_totals(tmpdir): trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=1, - limit_val_batches=1.0, max_epochs=1, ) bar = trainer.progress_bar_callback @@ -122,6 +121,12 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + assert bar.val_progress_bar.total == m + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + trainer.test(model) # check test progress bar total @@ -157,6 +162,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -214,8 +226,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert 
progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 845b05aed9b38..d96fe3dcab33d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -865,6 +865,9 @@ def assert_checkpoint_log_dir(idx): assert_checkpoint_log_dir(0) assert_checkpoint_content(ckpt_dir) + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -878,17 +881,24 @@ def assert_checkpoint_log_dir(idx): assert_trainer_init(trainer) model = ExtendedBoringModel() + trainer.test(model) assert not trainer.checkpoint_connector.has_trained # resume_from_checkpoint is resumed when calling `.fit` assert trainer.global_step == 0 assert trainer.current_epoch == 0 + trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(idx) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + def test_configure_model_checkpoint(tmpdir): """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. 
""" diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ab51a87329e2f..2118fec6c207b 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -19,7 +19,6 @@ import pytest import torch -import torch.nn.functional as F from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -29,7 +28,7 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -from tests.helpers.utils import reset_seed, set_random_master_port +from tests.helpers.utils import reset_seed @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @@ -297,20 +296,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ -def test_test_loop_only(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - ) - trainer.test(model, datamodule=dm) - - def test_full_loop(tmpdir): reset_seed() @@ -327,109 +312,17 @@ def test_full_loop(tmpdir): # fit model result = trainer.fit(model, dm) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert dm.trainer is not None assert result - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -def test_trainer_attached_to_dm(tmpdir): - reset_seed() - - dm = BoringDataModule() - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - weights_summary=None, - deterministic=True, - ) - - # fit model - trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + # validate + result = trainer.validate(datamodule=dm) assert dm.trainer is not None + assert result[0]['val_acc'] > 0.7 # test result = trainer.test(datamodule=dm) - result = result[0] assert dm.trainer is not None - - -@RunIf(min_gpus=1) -def test_full_loop_single_gpu(tmpdir): - reset_seed() - - dm = ClassifDataModule() - model = ClassificationModel() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - gpus=1, - deterministic=True, - ) - - # fit model - result = trainer.fit(model, dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) - assert result[0]['test_acc'] > 0.6 - - -@RunIf(min_gpus=2) -def test_full_loop_dp(tmpdir): - set_random_master_port() - - class CustomClassificationModelDP(ClassificationModel): - - def _step(self, batch, batch_idx): - x, y = batch - logits = self(x) - return {'logits': logits, 'y': y} - - def training_step(self, batch, batch_idx): - out = self._step(batch, batch_idx) - loss = F.cross_entropy(out['logits'], out['y']) - return loss - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step(self, batch, batch_idx): - return self._step(batch, batch_idx) - - def test_step_end(self, outputs): - self.log('test_acc', self.test_acc(outputs['logits'], outputs['y'])) - - dm = ClassifDataModule() - model = CustomClassificationModelDP() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - weights_summary=None, - accelerator='dp', - gpus=2, - deterministic=True, 
- ) - - # fit model - result = trainer.fit(model, datamodule=dm) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result - - # test - result = trainer.test(datamodule=dm) assert result[0]['test_acc'] > 0.6 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c53925bd7cc4..0d1c7cf40a2bf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -466,7 +466,23 @@ def teardown(self, stage=None): 'on_fit_end', 'teardown', ] + assert model.called == expected + + model = HookedModel() + trainer.validate(model, verbose=False) + expected = [ + 'on_validation_model_eval', + 'on_validation_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_validation_end', + 'on_validation_model_train', + 'teardown', + ] assert model.called == expected model = HookedModel() diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index b59563f70e4aa..a48f048160ee5 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -259,30 +259,20 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) -def test_ddp_sharded_plugin_test(tmpdir): +@pytest.mark.parametrize("trainer_kwargs", ( + {'num_processes': 2}, + pytest.param({'gpus': 2}, marks=RunIf(min_gpus=2)) +)) +def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """ - Test to ensure we can use test without fit + Test to ensure we can use validate and test without fit """ model = BoringModel() trainer = Trainer( accelerator='ddp_sharded_spawn', - num_processes=2, - fast_dev_run=True, - ) - - trainer.test(model) - - -@RunIf(min_gpus=2, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_test_multigpu(tmpdir): - """ - Test to ensure we can use test without fit - """ - model = BoringModel() - trainer = Trainer( - accelerator='ddp_sharded_spawn', - gpus=2, fast_dev_run=True, + **trainer_kwargs, ) + trainer.validate(model) trainer.test(model) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 01c23ed18fe65..34845c46b45eb 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -287,15 +287,22 @@ def test_configure_optimizers_with_frequency(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_init_optimizers_during_testing(tmpdir): +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_init_optimizers_during_evaluation(tmpdir, fn): """ - Test that optimizers is an empty list during testing. 
+ Test that optimizers is an empty list during evaluation """ - model = EvalModelTemplate() - model.configure_optimizers = model.configure_optimizers__multiple_schedulers - - trainer = Trainer(default_root_dir=tmpdir, limit_test_batches=10) - trainer.test(model, ckpt_path=None) + class TestModel(BoringModel): + def configure_optimizers(self): + optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) + optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=10, limit_test_batches=10) + validate_or_test = getattr(trainer, fn) + validate_or_test(TestModel(), ckpt_path=None) assert len(trainer.lr_schedulers) == 0 assert len(trainer.optimizers) == 0 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 00ad020aa1b57..59e10480a485e 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -13,12 +13,9 @@ # limitations under the License. import pytest -import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate - -# TODO: add matching messages +from tests.helpers import BoringModel def test_wrong_train_setting(tmpdir): @@ -26,49 +23,44 @@ def test_wrong_train_setting(tmpdir): * Test that an error is thrown when no `train_dataloader()` is defined * Test that an error is thrown when no `training_step()` is defined """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): + model = BoringModel() model.train_dataloader = None trainer.fit(model) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(**hparams) + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): + model = BoringModel() model.training_step = None trainer.fit(model) def test_wrong_configure_optimizers(tmpdir): """ Test that an error is thrown when no `configure_optimizers()` is defined """ - tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate() + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): + model = BoringModel() model.configure_optimizers = None trainer.fit(model) -def test_val_loop_config(tmpdir): +def test_fit_val_loop_config(tmpdir): """" When either val loop or val data are missing raise warning """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a 
validation_step but have no val_dataloader'): + model = BoringModel() model.val_dataloader = None trainer.fit(model) @@ -77,17 +69,35 @@ def test_test_loop_config(tmpdir): """" When either test loop or test data are missing """ - hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): + model = BoringModel() model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): - model = EvalModelTemplate(**hparams) + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): + model = BoringModel() model.test_step = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) + trainer.test(model) + + +def test_val_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = BoringModel() + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() + model.validation_step = None + trainer.validate(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 5530779b4f77d..e4aea38fb7f37 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -125,8 +125,7 @@ def test_multiple_val_dataloader(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify there are 2 val loaders - assert len(trainer.val_dataloaders) == 2, \ - 'Multiple val_dataloaders not initiated properly' + assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set for dataloader in trainer.val_dataloaders: @@ -134,18 +133,22 @@ def test_multiple_val_dataloader(tmpdir): @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_test_dataloader(tmpdir, ckpt_path): - """Verify multiple test_dataloader.""" - - model_template = EvalModelTemplate() +def test_multiple_eval_dataloader(tmpdir, ckpt_path): + """Verify multiple evaluation dataloaders.""" class MultipleTestDataloaderModel(EvalModelTemplate): - def test_dataloader(self): return [self.dataloader(train=False), self.dataloader(train=False)] - def test_step(self, batch, batch_idx, *args, **kwargs): - return model_template.test_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + def test_step(self, *args, **kwargs): + return super().test_step__multiple_dataloaders(*args, **kwargs) + + def val_dataloader(self): + return self.test_dataloader() + + def validation_step(self, *args, **kwargs): + output = self.test_step(*args, **kwargs) + return {k.replace("test_", "val_"): v for k, v in output.items()} model = MultipleTestDataloaderModel() @@ -159,18 +162,19 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.fit(model) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - trainer.test(ckpt_path=ckpt_path) - # verify there are 2 test loaders - assert 
len(trainer.test_dataloaders) == 2, 'Multiple test_dataloaders not initiated properly' + trainer.validate(ckpt_path=ckpt_path, verbose=False) + # verify there are 2 loaders + assert len(trainer.val_dataloaders) == 2 + # make sure predictions are good for each dl + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # make sure predictions are good for each test set + trainer.test(ckpt_path=ckpt_path, verbose=False) + assert len(trainer.test_dataloaders) == 2 for dataloader in trainer.test_dataloaders: tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # run the test method - trainer.test(ckpt_path=ckpt_path) - def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ @@ -189,90 +193,45 @@ def test_train_dataloader_passed_to_fit(tmpdir): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_train_val_dataloaders_passed_to_fit(tmpdir): - """ Verify that train & val dataloader can be passed to fit """ - - # train, val passed to fit - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - - trainer.fit(model, **fit_options) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - - @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify train, val & test dataloader(s) can be passed to fit and test method""" +@pytest.mark.parametrize("n", (1, 2)) +def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): + """Verify that dataloaders can be passed.""" model = EvalModelTemplate() + if n == 1: + dataloaders = model.dataloader(train=False) + else: + dataloaders = [model.dataloader(train=False)] * 2 + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders - # train, val and test passed to fit + # train, multiple val and multiple test passed to fit trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) - trainer.fit(model, **fit_options) - - if ckpt_path == 'specific': - ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) - trainer.test(**test_options) + trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert len(trainer.val_dataloaders) == 1, \ - f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - + assert len(trainer.val_dataloaders) == n -@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) -def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path): - """Verify that multiple val & 
test dataloaders can be passed to fit.""" - - model = EvalModelTemplate() - model.validation_step = model.validation_step__multiple_dataloaders - model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders - model.test_step = model.test_step__multiple_dataloaders - - # train, multiple val and multiple test passed to fit - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - fit_options = dict( - train_dataloader=model.dataloader(train=True), - val_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)] - ) - trainer.fit(model, **fit_options) if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path - test_options = dict( - test_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)], ckpt_path=ckpt_path - ) - trainer.test(**test_options) - assert len(trainer.val_dataloaders) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' + trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) + trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) + + assert len(trainer.val_dataloaders) == n + assert len(trainer.test_dataloaders) == n @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + (1.0, 1.0, 1.0), ]) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" @@ -299,8 +258,8 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(10, 10, 10), + (0, 0, 0), + (10, 10, 10), ]) def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify inf train, val & test dataloaders (e.g. 
IterableDataset) passed with batch limit as number""" @@ -327,10 +286,10 @@ def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0.0, 0.0, 0.0), - pytest.param(0, 0, 0.5), - pytest.param(1.0, 1.0, 1.0), - pytest.param(0.2, 0.4, 0.4), + (0.0, 0.0, 0.0), + (0, 0, 0.5), + (1.0, 1.0, 1.0), + (0.2, 0.4, 0.4), ]) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" @@ -362,9 +321,9 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim @pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ - pytest.param(0, 0, 0), - pytest.param(1, 2, 3), - pytest.param(1, 2, 1e50), + (0, 0, 0), + (1, 2, 3), + (1, 2, 1e50), ]) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): @@ -445,10 +404,10 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): if fast_dev_run == 'temp': with pytest.raises(MisconfigurationException, match='either a bool or an int'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) elif fast_dev_run == -1: with pytest.raises(MisconfigurationException, match='should be >= 0'): - trainer = Trainer(**trainer_options) + Trainer(**trainer_options) else: trainer = Trainer(**trainer_options) @@ -1191,12 +1150,6 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): train = RandomDataset(32, 64) context = 'spawn' train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) - - class ExtendedBoringModel(BoringModel): - - def train_dataloader(self): - return train - trainer = Trainer( max_epochs=1, progress_bar_refresh_rate=20, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 385d8c1c6b462..0aab2c25eaf4c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -599,44 +599,57 @@ def test_benchmark_option(tmpdir): assert torch.backends.cudnn.benchmark -@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) -@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) -def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k): - hparams = EvalModelTemplate.get_default_hparams() +@pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) +@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): + self.log("foo", -batch_idx) + return super().validation_step(batch, batch_idx) - model = EvalModelTemplate(**hparams) + model = TestModel() trainer = Trainer( max_epochs=2, progress_bar_refresh_rate=0, default_root_dir=tmpdir, - callbacks=[ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k)], + callbacks=[ModelCheckpoint(monitor="foo", save_top_k=save_top_k)], ) trainer.fit(model) + + test_or_validate = getattr(trainer, fn) if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights if save_top_k == 0: with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): - trainer.test(ckpt_path=ckpt_path) + 
test_or_validate(ckpt_path=ckpt_path) else: - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + else: + assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training - trainer.test(ckpt_path=ckpt_path) + test_or_validate(ckpt_path=ckpt_path) assert trainer.tested_ckpt_path is None + assert trainer.validated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: with pytest.raises(FileNotFoundError): - trainer.test(ckpt_path="random.ckpt") + test_or_validate(ckpt_path="random.ckpt") else: ckpt_path = str( list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir() )[0].absolute() ) - trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == ckpt_path + else: + assert trainer.validated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): @@ -1292,10 +1305,11 @@ def test_trainer_pickle(tmpdir): cloudpickle.dumps(trainer) -def test_trainer_setup_call(tmpdir): - """Test setup call with fit and test call.""" +@pytest.mark.parametrize("stage", ("fit", "validate", "test")) +def test_trainer_setup_call(tmpdir, stage): + """Test setup call gets the correct stage""" - class CurrentModel(EvalModelTemplate): + class CurrentModel(BoringModel): def setup(self, stage): self.stage = stage @@ -1311,21 +1325,23 @@ def setup(self, model, stage): # fit model trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) - trainer.fit(model) - assert trainer.stage == "fit" - assert trainer.lightning_module.stage == "fit" + if stage == "fit": + trainer.fit(model) + elif stage == "validate": + trainer.validate(model, ckpt_path=None) + else: + trainer.test(model, ckpt_path=None) - trainer.test(ckpt_path=None) - assert trainer.stage == "test" - assert trainer.lightning_module.stage == "test" + assert trainer.stage == stage + assert trainer.lightning_module.stage == stage @pytest.mark.parametrize( "train_batches, max_steps, log_interval", [ - pytest.param(10, 10, 1), - pytest.param(3, 10, 1), - pytest.param(3, 10, 5), + (10, 10, 1), + (3, 10, 1), + (3, 10, 5), ], ) @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") @@ -1398,7 +1414,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] model = model or BoringModel() - datamodule = TestLightningDataModule(dataloaders) + dm = TestLightningDataModule(dataloaders) trainer = Trainer( default_root_dir=tmpdir, @@ -1411,7 +1427,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, plugins=plugins, ) if datamodule: - results = trainer.predict(model, datamodule=datamodule) + results = trainer.predict(model, datamodule=dm) else: results = trainer.predict(model, dataloaders=dataloaders) diff --git a/tests/trainer/test_trainer_test_loop.py b/tests/trainer/test_trainer_test_loop.py deleted file mode 100644 index 7e2a9299fc8a0..0000000000000 --- a/tests/trainer/test_trainer_test_loop.py +++ /dev/null @@ 
-1,76 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -import pytorch_lightning as pl -import tests.helpers.utils as tutils -from tests.base import EvalModelTemplate -from tests.helpers.runif import RunIf - - -@RunIf(min_gpus=2) -def test_single_gpu_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0], - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights)) - - -@RunIf(min_gpus=2) -def test_ddp_spawn_test(tmpdir): - tutils.set_random_master_port() - - model = EvalModelTemplate() - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - accelerator='ddp_spawn', - ) - trainer.fit(model) - assert 'ckpt' in trainer.checkpoint_callback.best_model_path - results = trainer.test() - assert 'test_acc' in results[0] - - old_weights = model.c_d1.weight.clone().detach().cpu() - - results = trainer.test(model) - assert 'test_acc' in results[0] - - # make sure weights didn't change - new_weights = model.c_d1.weight.clone().detach().cpu() - - assert torch.all(torch.eq(old_weights, new_weights))
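
Example usage of the new `Trainer.validate()` entry point added above, as a minimal sketch mirroring the tests in this patch. `BoringModel` is the repository's test-helper model (imported the same way the updated tests do); any `LightningModule` that defines `validation_step` and a validation dataloader behaves the same way:

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel

    model = BoringModel()
    trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(model)

    # One evaluation epoch over the validation set, outside the training loop.
    # Passing the model validates its current (just-trained) weights; `ckpt_path`
    # is ignored when a model is given explicitly.
    results = trainer.validate(model)

    # Without a model argument, `ckpt_path='best'` (the default) reloads the best
    # checkpoint saved during `.fit()`, while `ckpt_path=None` keeps the current weights.
    results = trainer.validate()

    # `.test()` now follows the same attach-dataloaders / load-checkpoint path.
    results = trainer.test()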