
Commit

Fixes
carmocca committed Feb 11, 2022
1 parent 72cc318 commit 124e9ce
Showing 4 changed files with 11 additions and 10 deletions.
8 changes: 4 additions & 4 deletions pytorch_lightning/core/optimizer.py
@@ -54,8 +54,8 @@ def __init__(self, optimizer: Optimizer):
self._strategy: Optional[pl.strategies.Strategy] = None
self._optimizer_idx = 0
# to inject logic around the optimizer step, particularly useful with manual optimization
- self.on_before_step = do_nothing_closure
- self.on_after_step = do_nothing_closure
+ self._on_before_step = do_nothing_closure
+ self._on_after_step = do_nothing_closure

@property
def optimizer(self) -> Optimizer:
@@ -157,7 +157,7 @@ def closure_dis():
with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
opt_dis.step(closure=closure_dis)
"""
- self.on_before_step()
+ self._on_before_step()

if closure is None:
closure = do_nothing_closure
@@ -173,7 +173,7 @@ def closure_dis():
with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

- self.on_after_step()
+ self._on_after_step()

return step_output

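The change in pytorch_lightning/core/optimizer.py turns the two step hooks into private attributes. Below is a minimal sketch of the pattern, using a simplified stand-in class rather than the real LightningOptimizer, to show the default no-op callbacks and where they fire around the wrapped step:

from typing import Any, Callable, Optional


def do_nothing_closure() -> None:
    return None


class LightningOptimizerSketch:
    """Simplified stand-in for the LightningOptimizer wrapper (not the real class)."""

    def __init__(self, optimizer: Any) -> None:
        self._optimizer = optimizer
        # private by default so they are not mistaken for public API;
        # the manual-optimization loop swaps them in at run start
        self._on_before_step: Callable[[], None] = do_nothing_closure
        self._on_after_step: Callable[[], None] = do_nothing_closure

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        self._on_before_step()  # injected logic runs before the real step
        if closure is None:
            closure = do_nothing_closure
        step_output = self._optimizer.step(closure=closure)
        self._on_after_step()  # and again after it
        return step_output
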
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/optimization/manual_loop.py
@@ -94,8 +94,8 @@ def reset(self) -> None:
def on_run_start(self, *_: Any, **__: Any) -> None:
# inject logic around the optimizer step
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
- lightning_optimizer.on_before_step = self._on_before_step
- lightning_optimizer.on_after_step = self._on_after_step
+ lightning_optimizer._on_before_step = self._on_before_step
+ lightning_optimizer._on_after_step = self._on_after_step

def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
"""Performs the training step for manual optimization.
@@ -140,8 +140,8 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
output, self._output = self._output, {} # free memory
# reset logic around the optimizer step
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
- lightning_optimizer.on_before_step = do_nothing_closure
- lightning_optimizer.on_after_step = do_nothing_closure
+ lightning_optimizer._on_before_step = do_nothing_closure
+ lightning_optimizer._on_after_step = do_nothing_closure
return output

def _on_before_step(self) -> None:
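From a user's perspective these hooks wrap every optimizer.step() call made during manual optimization. A hedged usage sketch follows, with an illustrative module and layer sizes, of where the injected callbacks end up firing:

import torch
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    """Illustrative LightningModule that opts into manual optimization."""

    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False  # hand stepping control to training_step
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # the LightningOptimizer wrapper
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        # the loop's _on_before_step / _on_after_step run around this call
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
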
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -333,7 +333,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint = {
# the epoch is saved for compatibility but it's not relevant for restoration
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step + 1,
"global_step": self.trainer.global_step + model.automatic_optimization,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
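The dump_checkpoint change relies on Python bools being integers: the saved "global_step" keeps the previous +1 offset only when the module uses automatic optimization, and stores the raw counter under manual optimization. A short worked example with an arbitrary counter value:

global_step = 41

# automatic optimization: automatic_optimization is True, so 1 is added
assert global_step + True == 42

# manual optimization: automatic_optimization is False, so nothing is added
assert global_step + False == 41
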
3 changes: 2 additions & 1 deletion tests/core/test_lightning_optimizer.py
@@ -152,7 +152,8 @@ def test_state(tmpdir):
lightning_dict = {
k: v
for k, v in lightning_optimizer.__dict__.items()
- if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"}
+ if k
+ not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
}

assert lightning_dict == optimizer.__dict__
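The test keeps asserting that, once the wrapper's own private attributes are filtered out, a LightningOptimizer exposes exactly the wrapped optimizer's state, so the two new callback attributes join the exclusion set. A minimal illustration of that filtering, using plain dictionaries and placeholder values rather than real optimizers:

wrapped_state = {"defaults": {"lr": 0.1}, "state": {}, "param_groups": []}
wrapper_state = {
    **wrapped_state,
    "_optimizer": object(),
    "_on_before_step": print,  # placeholder callables for the new attributes
    "_on_after_step": print,
}

private = {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
assert {k: v for k, v in wrapper_state.items() if k not in private} == wrapped_state
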
