diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 1f58027de54ef..5c6e0137fe06e 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -54,8 +54,8 @@ def __init__(self, optimizer: Optimizer):
         self._strategy: Optional[pl.strategies.Strategy] = None
         self._optimizer_idx = 0
         # to inject logic around the optimizer step, particularly useful with manual optimization
-        self.on_before_step = do_nothing_closure
-        self.on_after_step = do_nothing_closure
+        self._on_before_step = do_nothing_closure
+        self._on_after_step = do_nothing_closure

     @property
     def optimizer(self) -> Optimizer:
@@ -157,7 +157,7 @@ def closure_dis():
                 with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                     opt_dis.step(closure=closure_dis)
         """
-        self.on_before_step()
+        self._on_before_step()

         if closure is None:
             closure = do_nothing_closure
@@ -173,7 +173,7 @@ def closure_dis():
         with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
             step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

-        self.on_after_step()
+        self._on_after_step()

         return step_output

diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py
index eebfe405e7c5f..dc92fde5a1ac1 100644
--- a/pytorch_lightning/loops/optimization/manual_loop.py
+++ b/pytorch_lightning/loops/optimization/manual_loop.py
@@ -94,8 +94,8 @@ def reset(self) -> None:
     def on_run_start(self, *_: Any, **__: Any) -> None:
         # inject logic around the optimizer step
         for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
-            lightning_optimizer.on_before_step = self._on_before_step
-            lightning_optimizer.on_after_step = self._on_after_step
+            lightning_optimizer._on_before_step = self._on_before_step
+            lightning_optimizer._on_after_step = self._on_after_step

     def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         """Performs the training step for manual optimization.
@@ -140,8 +140,8 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         output, self._output = self._output, {}  # free memory
         # reset logic around the optimizer step
         for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
-            lightning_optimizer.on_before_step = do_nothing_closure
-            lightning_optimizer.on_after_step = do_nothing_closure
+            lightning_optimizer._on_before_step = do_nothing_closure
+            lightning_optimizer._on_after_step = do_nothing_closure
         return output

     def _on_before_step(self) -> None:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index bb6889d797f5c..0712f5df85f51 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -333,7 +333,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         checkpoint = {
             # the epoch is saved for compatibility but it's not relevant for restoration
             "epoch": self.trainer.current_epoch,
-            "global_step": self.trainer.global_step + 1,
+            "global_step": self.trainer.global_step + model.automatic_optimization,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
             "loops": self._get_loops_state_dict(),
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index 3b2fe7bfde8cc..0ee82fcc2e4ec 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -152,7 +152,8 @@ def test_state(tmpdir):
     lightning_dict = {
         k: v
         for k, v in lightning_optimizer.__dict__.items()
-        if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"}
+        if k
+        not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
     }
     assert lightning_dict == optimizer.__dict__
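Note on the `dump_checkpoint` change above: `LightningModule.automatic_optimization` is a boolean, so adding it to `self.trainer.global_step` bumps the saved value by 1 under automatic optimization and leaves it unchanged under manual optimization. A minimal sketch of that arithmetic, using a hypothetical `FakeModule` stand-in rather than a real `LightningModule`:

```python
from dataclasses import dataclass


@dataclass
class FakeModule:
    """Hypothetical stand-in exposing only the flag read by dump_checkpoint."""

    automatic_optimization: bool


def checkpointed_global_step(global_step: int, model: FakeModule) -> int:
    # Mirrors `"global_step": self.trainer.global_step + model.automatic_optimization`:
    # a bool participates in integer arithmetic as 0 or 1.
    return global_step + model.automatic_optimization


assert checkpointed_global_step(10, FakeModule(automatic_optimization=True)) == 11  # automatic optimization
assert checkpointed_global_step(10, FakeModule(automatic_optimization=False)) == 10  # manual optimization
```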