diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index e33a7be45ef35..a04481b5eddf9 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -


-
 ## [2.1.1] - 2023-11-06

 ### Fixed
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 1407b61d081d8..ec4c165f349e4 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -329,6 +329,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()

     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        # If we're setting up for evaluation after fitting, we need to discard the optimizers
+        # since we're rewrapping the model, otherwise optimizer param references are no longer valid
+        # and subsequent checkpoint saving can fail
+        self._reset_optimizers_and_schedulers()
+
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)

diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index 5ea8d19310f0c..d50ed7e11d7e8 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -575,6 +575,11 @@ def on_exception(self, exception: BaseException) -> None:
         """Called when the trainer execution is interrupted by an exception."""
         pass

+    def _reset_optimizers_and_schedulers(self) -> None:
+        self._optimizers = []
+        self._lightning_optimizers = []
+        self.lr_scheduler_configs = []
+
     def __getstate__(self) -> Dict:
         # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
         state = dict(vars(self))  # copy
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index bcb338e05e7cf..c21c11d499492 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -173,9 +173,13 @@ def _assert_layer_fsdp_instance(self) -> None:

 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
+    trainer.test(model)
+
     model_path = trainer.strategy.broadcast(model_path)
-    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
+    model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path)

+    # Save another checkpoint after testing, without optimizer states
+    trainer.save_checkpoint(model_path.with_name("after-test"))
     trainer.save_checkpoint(model_path, weights_only=True)

     _assert_save_equality(trainer, model_path, cls=model.__class__)
@@ -270,13 +274,13 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)


-@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_fsdp_strategy_checkpoint(tmpdir, precision):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
     model = TestFSDPModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp", precision=precision, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
     )
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))

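For context, the user-facing sequence this patch targets looks roughly like the sketch below: fit with FSDP, run `trainer.test()` (which re-wraps the model for evaluation), then save a checkpoint. This is a minimal illustration using Lightning's public API and its `BoringModel` demo class, not part of the patch; the device count and checkpoint path are arbitrary examples.

```python
# Minimal sketch of the fit -> test -> save_checkpoint flow the fix addresses.
# Assumes a machine with 2 CUDA GPUs; BoringModel is Lightning's demo model.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=1)
trainer.fit(model)
trainer.test(model)  # sets up the model again for evaluation after fitting
trainer.save_checkpoint("after-test.ckpt")  # the saving step this patch makes reliable
```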
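The comment added in `setup_optimizers` says that optimizer parameter references become invalid once the model is re-wrapped. The standalone PyTorch snippet below (an illustration only, not Lightning or FSDP code) shows the underlying issue: an optimizer built against the old parameter objects no longer points at the parameters of the re-created module, so any optimizer state keyed to those parameters is orphaned.

```python
# Illustration of stale optimizer references after a module's parameters are replaced.
import torch
from torch import nn

module = nn.Linear(4, 4)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# Simulate re-wrapping: the live module's parameters become brand-new tensor objects.
module = nn.Linear(4, 4)

live = {id(p) for p in module.parameters()}
held = {id(p) for group in optimizer.param_groups for p in group["params"]}
print(live & held)  # set() -- the optimizer no longer references any live parameter
```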