From 15184c694c41f20812ec25e0aefd2c9e98013211 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Thu, 8 Dec 2022 14:37:29 +0100
Subject: [PATCH] Fix restarting attribute for lr finder (#15620)

---
 src/pytorch_lightning/CHANGELOG.md           |  4 ++
 src/pytorch_lightning/callbacks/lr_finder.py |  2 +-
 src/pytorch_lightning/tuner/lr_finder.py     | 15 ++++---
 tests/tests_pytorch/tuner/test_lr_finder.py  | 47 ++++++++++++++++++++
 4 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 67197271d07ca..06a0418231c9d 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `pytorch_lightning.utilities.memory.get_gpu_memory_map` in favor of `pytorch_lightning.accelerators.cuda.get_nvidia_gpu_stats` ([#15617](https://github.com/Lightning-AI/lightning/pull/15617))
 
+
 - Temporarily removed support for Hydra multi-run ([#15737](https://github.com/Lightning-AI/lightning/pull/15737))
 
 
@@ -87,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))
 
+- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))
+
 
 ## [1.8.3] - 2022-11-22
 
@@ -104,6 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed the automatic fallback from `Trainer(strategy="ddp_spawn", ...)` to `Trainer(strategy="ddp", ...)` when on an LSF cluster ([#15103](https://github.com/PyTorchLightning/pytorch-lightning/issues/15103))
 
+
 ## [1.8.1] - 2022-11-10
 
 ### Added
diff --git a/src/pytorch_lightning/callbacks/lr_finder.py b/src/pytorch_lightning/callbacks/lr_finder.py
index 4d235751ca791..1c950e64086b9 100644
--- a/src/pytorch_lightning/callbacks/lr_finder.py
+++ b/src/pytorch_lightning/callbacks/lr_finder.py
@@ -85,7 +85,7 @@ def __init__(
         max_lr: float = 1,
         num_training_steps: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = False,
     ) -> None:
         mode = mode.lower()
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 29a5d47776a9e..2652267c93ae6 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -208,7 +208,7 @@ def lr_find(
     max_lr: float = 1,
     num_training: int = 100,
     mode: str = "exponential",
-    early_stop_threshold: float = 4.0,
+    early_stop_threshold: Optional[float] = 4.0,
     update_attr: bool = False,
 ) -> Optional[_LRFinder]:
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
@@ -225,6 +225,8 @@ def lr_find(
     ckpt_path = trainer.strategy.broadcast(ckpt_path)
     trainer.save_checkpoint(ckpt_path)
 
+    start_steps = trainer.global_step
+
     # Arguments we adjust during the lr finder, save for restoring
     params = __lr_finder_dump_params(trainer)
 
@@ -245,7 +247,7 @@ def lr_find(
     _try_loop_run(trainer, params)
 
     # Prompt if we stopped early
-    if trainer.global_step != num_training:
+    if trainer.global_step != num_training + start_steps:
         log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
 
     # Transfer results from callback to lr finder object
@@ -270,6 +272,7 @@ def lr_find(
     # Restore initial state of model
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
+    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
 
     return lr_finder
 
@@ -289,7 +292,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
     }
 
 
-def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
+def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None:
     from pytorch_lightning.loggers.logger import DummyLogger
 
     trainer.strategy.lr_scheduler_configs = []
@@ -300,8 +303,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
     # No logging
     trainer.logger = DummyLogger() if trainer.logger is not None else None
-    # Max step set to number of iterations
-    trainer.fit_loop.max_steps = num_training
+    # Max step set to number of iterations starting at current number of iterations
+    trainer.fit_loop.max_steps = num_training + trainer.global_step
     trainer.limit_val_batches = num_training
 
 
@@ -340,7 +343,7 @@ class _LRCallback(Callback):
     def __init__(
         self,
         num_training: int,
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         progress_bar_refresh_rate: int = 0,
         beta: float = 0.98,
     ):
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index ed4d9d33430f0..25fdcd35f31f7 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -441,6 +441,53 @@ def test_if_lr_finder_callback_already_configured():
         trainer.tune(model)
 
 
+def test_lr_finder_callback_restarting(tmpdir):
+    """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""
+
+    num_lr_steps = 100
+
+    class MyBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.123
+
+        def on_train_batch_start(self, batch, batch_idx):
+            if getattr(self, "_expected_max_steps", None) is not None:
+                assert self.trainer.fit_loop.max_steps == self._expected_max_steps
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+
+    class CustomLearningRateFinder(LearningRateFinder):
+        milestones = (1,)
+
+        def lr_find(self, trainer, pl_module) -> None:
+            pl_module._expected_max_steps = trainer.global_step + self._num_training_steps
+            super().lr_find(trainer, pl_module)
+            pl_module._expected_max_steps = None
+            assert not trainer.fit_loop.restarting
+
+        def on_train_epoch_start(self, trainer, pl_module):
+            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
+                self.lr_find(trainer, pl_module)
+
+    model = MyBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        callbacks=[
+            CustomLearningRateFinder(early_stop_threshold=None, update_attr=True, num_training_steps=num_lr_steps)
+        ],
+        limit_train_batches=10,
+        limit_val_batches=0,
+        limit_test_batches=0,
+        num_sanity_val_steps=0,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
+
+
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @RunIf(standalone=True)
 def test_lr_finder_with_ddp(tmpdir):
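
Usage note (not part of the patch): a minimal sketch of the scenario this fix
targets, assuming the 1.8-style API (`LearningRateFinder` from
`pytorch_lightning.callbacks`, `BoringModel` from
`pytorch_lightning.demos.boring_classes`). The `TunedModel` class and the
trainer flags are illustrative only, loosely mirroring the test added above.

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateFinder
    from pytorch_lightning.demos.boring_classes import BoringModel


    class TunedModel(BoringModel):  # hypothetical example model
        def __init__(self):
            super().__init__()
            # attribute the finder reads/updates when update_attr=True
            self.learning_rate = 0.1

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


    trainer = Trainer(
        max_epochs=3,
        limit_train_batches=10,
        # run the lr search at the start of fit, then continue training
        callbacks=[LearningRateFinder(num_training_steps=100, update_attr=True)],
    )
    trainer.fit(TunedModel())
    # Before this fix, restoring the pre-search checkpoint left
    # trainer.fit_loop.restarting stuck at True for the remaining epochs;
    # with the fix the flag is reset to False after the search.
    assert not trainer.fit_loop.restarting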