diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 4770719e7ce5d..954fc2381cbc5 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for saving checkpoints with either full state-dict or sharded state dict via `FSDPStrategy(state_dict_type="full"|"sharded")` ([#17526](https://github.com/Lightning-AI/lightning/pull/17526))
 
+- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
+
+
 ### Changed
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 48cd5c68a3155..4139dc819be54 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -51,6 +51,7 @@
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 
 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
+_METADATA_FILENAME = "meta.pt"
 
 
 class FSDPStrategy(ParallelStrategy, _Sharded):
@@ -385,7 +386,7 @@ def save_checkpoint(
             save_state_dict(converted_state, writer)
 
             if self.global_rank == 0:
-                torch.save(metadata, path / "meta.pt")
+                torch.save(metadata, path / _METADATA_FILENAME)
 
         elif self._state_dict_type == "full":
             state_dict_ctx = _get_full_state_dict_context(module)
@@ -425,11 +426,6 @@ def load_checkpoint(
             )
         # broadcast the path from rank 0 to ensure all the states are loaded from a common path
         path = Path(self.broadcast(path))
-        if path.is_file():
-            raise NotImplementedError(
-                f"The path `{path}` is a file, but the `FSDPStrategy` currently only supports loading from a checkpoint"
-                f" with sharded states in a directory."
-            )
 
         from torch.distributed.checkpoint import FileSystemReader, load_state_dict
         from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
@@ -451,39 +447,69 @@ def load_checkpoint(
             )
         module_key, module = list(modules.items())[0]
 
-        state_dict_ctx = _get_sharded_state_dict_context(module)
-        reader = FileSystemReader(path=path)
-
-        with state_dict_ctx:
-            module_state = {module_key: module.state_dict()}
-            load_state_dict(module_state, reader)
-            module.load_state_dict(module_state[module_key])
+        if _is_sharded_checkpoint(path):
+            state_dict_ctx = _get_sharded_state_dict_context(module)
+            reader = FileSystemReader(path=path)
 
-            # the optimizer states must be loaded separately
-            for optim_key, optim in optimizers.items():
-                optim_state = load_sharded_optimizer_state_dict(
-                    model_state_dict=module_state[module_key],
-                    optimizer_key=optim_key,
-                    storage_reader=reader,
-                )
-                flattened_osd = FSDP.optim_state_dict_to_load(
-                    optim_state_dict=optim_state[optim_key],
-                    model=module,
-                    optim=optim,
+            with state_dict_ctx:
+                module_state = {module_key: module.state_dict()}
+                load_state_dict(module_state, reader)
+                module.load_state_dict(module_state[module_key])
+
+                # the optimizer states must be loaded separately
+                for optim_key, optim in optimizers.items():
+                    optim_state = load_sharded_optimizer_state_dict(
+                        model_state_dict=module_state[module_key],
+                        optimizer_key=optim_key,
+                        storage_reader=reader,
+                    )
+                    flattened_osd = FSDP.optim_state_dict_to_load(
+                        optim_state_dict=optim_state[optim_key],
+                        model=module,
+                        optim=optim,
+                    )
+                    optim.load_state_dict(flattened_osd)
+
+            # Load metadata (anything not a module or optimizer)
+            metadata = torch.load(path / _METADATA_FILENAME)
+            for key, obj in state.items():
+                if isinstance(obj, (FSDP, Optimizer)):
+                    continue
+                if key not in metadata:
+                    raise KeyError(f"'{key}' not found in the checkpoint.")
+                state[key] = metadata.pop(key)
+
+            # return the remaining metadata that wasn't requested as part of `state`
+            return metadata
+
+        if _is_full_checkpoint(path):
+            if optimizers:
+                rank_zero_warn(
+                    "Loading a full-state checkpoint into FSDP currently only supports loading the model weights."
+                    " The optimizer state won't be reloaded."
                 )
-                optim.load_state_dict(flattened_osd)
 
-        # Load metadata (anything not a module or optimizer)
-        metadata = torch.load(path / "meta.pt")
-        for key, obj in state.items():
-            if isinstance(obj, (FSDP, Optimizer)):
-                continue
-            if key not in metadata:
-                raise KeyError(f"'{key}' not found in the checkpoint.")
-            state[key] = metadata.pop(key)
+            # This is inefficient, as multiple copies of the checkpoint are held in CPU memory at once.
+            # There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
+            checkpoint = torch.load(path, map_location="cpu")
+            with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
+                module.load_state_dict(checkpoint.pop(module_key))
+
+            # Load metadata (anything not a module or optimizer)
+            for key, obj in state.items():
+                if isinstance(obj, (FSDP, Optimizer)):
+                    continue
+                if key not in checkpoint:
+                    raise KeyError(f"'{key}' not found in the checkpoint.")
+                state[key] = checkpoint.pop(key)
 
-        # return the remaining metadata that wasn't requested as part of `state`
-        return metadata
+            # return the remaining metadata that wasn't requested as part of `state`
+            return checkpoint
+
+        raise ValueError(
+            f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
+            " directory with FSDP checkpoint shards, or a single file with a full checkpoint."
+        )
 
     @classmethod
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
@@ -597,3 +623,13 @@ def _get_full_state_dict_context(module: "FullyShardedDataParallel") -> _Generat
         optim_state_dict_config=optim_state_dict_config,
     )
     return state_dict_type_context
+
+
+def _is_sharded_checkpoint(path: Path) -> bool:
+    """A heuristic check to determine whether the path points to a directory with checkpoint shards."""
+    return path.is_dir() and (path / _METADATA_FILENAME).is_file()
+
+
+def _is_full_checkpoint(path: Path) -> bool:
+    """A heuristic check to determine whether the path points to a single file holding a full checkpoint."""
+    return path.is_file()
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index ef52cc844dee8..d03404a3a294f 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -272,6 +272,17 @@ def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
         strategy.save_checkpoint(path=tmp_path, state={"model": model})
 
 
+@RunIf(min_torch="2.0.0")
+def test_fsdp_load_unknown_checkpoint_type(tmp_path):
+    """Test that the strategy validates the contents at the checkpoint path."""
+    strategy = FSDPStrategy()
+    model = Mock(spec=FullyShardedDataParallel)
+    path = tmp_path / "empty_dir"  # neither a single file nor a directory with meta file
+    path.mkdir()
+    with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
+        strategy.load_checkpoint(path=path, state={"model": model})
+
+
 @RunIf(min_torch="1.12")
 @mock.patch("torch.distributed.init_process_group")
 def test_set_timeout(init_process_group_mock):
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 0a0b67df94bcb..d208512c71ddf 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 from unittest import mock
+from unittest.mock import ANY
 
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
 from torch.nn import Parameter
 
 from lightning.fabric import Fabric
@@ -121,7 +124,7 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
-def test_fsdp_save_load_full_state_dict(tmp_path):
+def test_fsdp_save_full_state_dict(tmp_path):
     """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
     fabric = BoringFabric(
         accelerator="cuda",
@@ -154,6 +157,38 @@ def test_fsdp_save_load_full_state_dict(tmp_path):
     assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))
 
 
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
+def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
+    """Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
+    fabric = BoringFabric(accelerator="cuda", devices=1)
+    fabric.run()
+
+    # Save a full-state-dict checkpoint
+    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
+    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
+    fabric.save(checkpoint_path, state)
+
+    # Create a FSDP sharded model
+    fabric = BoringFabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+    )
+    fabric.run()
+
+    warning_msg = "currently only supports loading the model weights"
+    warns = pytest.warns(UserWarning, match=warning_msg) if fabric.global_rank == 0 else nullcontext()
+    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 44}
+    with warns:
+        fabric.load(checkpoint_path, state)
+    assert state["steps"] == 1
+
+    state = {"model": fabric.model}
+    with no_warning_call(UserWarning, match=warning_msg):
+        remainder = fabric.load(checkpoint_path, state)
+    assert remainder == {"steps": 1, "optimizer": ANY}
+
+
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize("move_to_device", [True, False])
 @mock.patch("lightning.fabric.wrappers._FabricModule")