Skip to content

Commit

Permalink
Enable loading full state dict checkpoints with FSDP (#17623)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 31, 2023
1 parent e0ce34e commit fd296e0
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 36 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for saving checkpoints with either full state-dict or sharded state dict via `FSDPStrategy(state_dict_type="full"|"sharded")` ([#17526](https://github.com/Lightning-AI/lightning/pull/17526))


- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
Expand Down
105 changes: 70 additions & 35 deletions src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
_METADATA_FILENAME = "meta.pt"


class FSDPStrategy(ParallelStrategy, _Sharded):
Expand Down Expand Up @@ -385,7 +386,7 @@ def save_checkpoint(
save_state_dict(converted_state, writer)

if self.global_rank == 0:
torch.save(metadata, path / "meta.pt")
torch.save(metadata, path / _METADATA_FILENAME)

elif self._state_dict_type == "full":
state_dict_ctx = _get_full_state_dict_context(module)
Expand Down Expand Up @@ -425,11 +426,6 @@ def load_checkpoint(
)
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(path))
if path.is_file():
raise NotImplementedError(
f"The path `{path}` is a file, but the `FSDPStrategy` currently only supports loading from a checkpoint"
f" with sharded states in a directory."
)

from torch.distributed.checkpoint import FileSystemReader, load_state_dict
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
Expand All @@ -451,39 +447,69 @@ def load_checkpoint(
)
module_key, module = list(modules.items())[0]

state_dict_ctx = _get_sharded_state_dict_context(module)
reader = FileSystemReader(path=path)

with state_dict_ctx:
module_state = {module_key: module.state_dict()}
load_state_dict(module_state, reader)
module.load_state_dict(module_state[module_key])
if _is_sharded_checkpoint(path):
state_dict_ctx = _get_sharded_state_dict_context(module)
reader = FileSystemReader(path=path)

# the optimizer states must be loaded separately
for optim_key, optim in optimizers.items():
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=module_state[module_key],
optimizer_key=optim_key,
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=optim_state[optim_key],
model=module,
optim=optim,
with state_dict_ctx:
module_state = {module_key: module.state_dict()}
load_state_dict(module_state, reader)
module.load_state_dict(module_state[module_key])

# the optimizer states must be loaded separately
for optim_key, optim in optimizers.items():
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=module_state[module_key],
optimizer_key=optim_key,
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=optim_state[optim_key],
model=module,
optim=optim,
)
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
for key, obj in state.items():
if isinstance(obj, (FSDP, Optimizer)):
continue
if key not in metadata:
raise KeyError(f"'{key}' not found in the checkpoint.")
state[key] = metadata.pop(key)

# return the remaining metadata that wasn't requested as part of `state`
return metadata

if _is_full_checkpoint(path):
if optimizers:
rank_zero_warn(
"Loading a full-state checkpoint into FSDP currently only supports loading the model weights."
" The optimizer state won't be reloaded."
)
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / "meta.pt")
for key, obj in state.items():
if isinstance(obj, (FSDP, Optimizer)):
continue
if key not in metadata:
raise KeyError(f"'{key}' not found in the checkpoint.")
state[key] = metadata.pop(key)
# This is inefficient, as multiple copies of the checkpoint are held in CPU memory at once.
# There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
checkpoint = torch.load(path, map_location="cpu")
with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
module.load_state_dict(checkpoint.pop(module_key))

# Load metadata (anything not a module or optimizer)
for key, obj in state.items():
if isinstance(obj, (FSDP, Optimizer)):
continue
if key not in checkpoint:
raise KeyError(f"'{key}' not found in the checkpoint.")
state[key] = checkpoint.pop(key)

# return the remaining metadata that wasn't requested as part of `state`
return metadata
# return the remaining metadata that wasn't requested as part of `state`
return checkpoint

raise ValueError(
f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
" directory with FSDP checkpoint shards, or a single file with a full checkpoint."
)

@classmethod
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
Expand Down Expand Up @@ -597,3 +623,12 @@ def _get_full_state_dict_context(module: "FullyShardedDataParallel") -> _Generat
optim_state_dict_config=optim_state_dict_config,
)
return state_dict_type_context


def _is_sharded_checkpoint(path: Path) -> bool:
"""A heuristic check to determine whether the path points to a directory with checkpoint shards."""
return path.is_dir() and (path / _METADATA_FILENAME).is_file()


def _is_full_checkpoint(path: Path) -> bool:
return path.is_file()
11 changes: 11 additions & 0 deletions tests/tests_fabric/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,17 @@ def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model": model})


@RunIf(min_torch="2.0.0")
def test_fsdp_load_unkown_checkpoint_type(tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
strategy = FSDPStrategy()
model = Mock(spec=FullyShardedDataParallel)
path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file
path.mkdir()
with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
strategy.load_checkpoint(path=path, state={"model": model})


@RunIf(min_torch="1.12")
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
Expand Down
37 changes: 36 additions & 1 deletion tests/tests_fabric/strategies/test_fsdp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import ANY

import pytest
import torch
from lightning_utilities.test.warning import no_warning_call
from torch.nn import Parameter

from lightning.fabric import Fabric
Expand Down Expand Up @@ -121,7 +124,7 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_load_full_state_dict(tmp_path):
def test_fsdp_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
fabric = BoringFabric(
accelerator="cuda",
Expand Down Expand Up @@ -154,6 +157,38 @@ def test_fsdp_save_load_full_state_dict(tmp_path):
assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
fabric = BoringFabric(accelerator="cuda", devices=1)
fabric.run()

# Save a full-state-dict checkpoint
checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)

# Create a FSDP sharded model
fabric = BoringFabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.run()

warning_msg = "currently only supports loading the model weights"
warns = pytest.warns(UserWarning, match=warning_msg) if fabric.global_rank == 0 else nullcontext()
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 44}
with warns:
fabric.load(checkpoint_path, state)
assert state["steps"] == 1

state = {"model": fabric.model}
with no_warning_call(UserWarning, match=warning_msg):
remainder = fabric.load(checkpoint_path, state)
assert remainder == {"steps": 1, "optimizer": ANY}


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
Expand Down

0 comments on commit fd296e0

Please sign in to comment.