Skip to content

Commit

Permalink
Fix trainer.save_checkpoint after trainer.test with FSDP (#18992)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Nov 13, 2023
1 parent 532c723 commit 3acea8d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))


- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/lightning/issues/18992))


## [2.1.0] - 2023-10-11

### Added
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ def setup(self, trainer: "pl.Trainer") -> None:

@override
def setup_optimizers(self, trainer: "pl.Trainer") -> None:
# If we're setting up for evaluation after fitting, we need to discard the optimizers
# since we're rewrapping the model, otherwise optimizer param references are no longer valid
# and subsequent checkpoint saving can fail
self._reset_optimizers_and_schedulers()

if self.kwargs.get("use_orig_params"):
return super().setup_optimizers(trainer)

Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,11 @@ def on_exception(self, exception: BaseException) -> None:
"""Called when the trainer execution is interrupted by an exception."""
pass

def _reset_optimizers_and_schedulers(self) -> None:
self._optimizers = []
self._lightning_optimizers = []
self.lr_scheduler_configs = []

def __getstate__(self) -> Dict:
# `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
state = dict(vars(self)) # copy
Expand Down
10 changes: 7 additions & 3 deletions tests/tests_pytorch/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ def _assert_layer_fsdp_instance(self) -> None:

def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
trainer.test(model)

model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path)

# Save another checkpoint after testing, without optimizer states
trainer.save_checkpoint(model_path.with_name("after-test"))
trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)
Expand Down Expand Up @@ -270,13 +274,13 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
model = TestFSDPModel()
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp", precision=precision, max_epochs=1
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
)
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))

Expand Down

0 comments on commit 3acea8d

Please sign in to comment.