Avoid FSDP deprecations during save/load with newer torch versions (#19463)

* Avoid FSDP deprecations during save/load with newer torch versions

* Refactor

* Tests
carmocca authored Feb 14, 2024
1 parent 59e45d6 commit 6745994
Showing 5 changed files with 92 additions and 59 deletions.
78 changes: 53 additions & 25 deletions src/lightning/fabric/strategies/fsdp.py
@@ -66,6 +66,7 @@
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
_TORCH_GREATER_EQUAL_2_3,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
@@ -448,7 +449,6 @@ def save_checkpoint(
if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

modules = [module for module in state.values() if _has_fsdp_modules(module)]
@@ -491,9 +491,7 @@ def save_checkpoint(
target_dict = metadata
_apply_filter(key, filter or {}, converted, target_dict)

# FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
writer = FileSystemWriter(path=path, single_file_per_rank=True)
save_state_dict(converted_state, writer)
_distributed_checkpoint_save(converted_state, path)

if self.global_rank == 0:
torch.save(metadata, path / _METADATA_FILENAME)
@@ -555,16 +553,10 @@ def load_checkpoint(
"Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy."
)

from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import load
else:
from torch.distributed.checkpoint import load_state_dict as load # deprecated

modules = {key: module for key, module in state.items() if _has_fsdp_modules(module)}
if len(modules) == 0:
raise ValueError(
@@ -583,26 +575,30 @@

if _is_sharded_checkpoint(path):
state_dict_ctx = _get_sharded_state_dict_context(module)
reader = FileSystemReader(path=path)

with state_dict_ctx:
module_state = {module_key: module.state_dict()}
load(module_state, reader)
_distributed_checkpoint_load(module_state, path)
module.load_state_dict(module_state[module_key], strict=strict)

# the optimizer states must be loaded separately
for optim_key, optim in optimizers.items():
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=module_state[module_key],
optimizer_key=optim_key,
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=optim_state[optim_key],
model=module,
optim=optim,
)
optim.load_state_dict(flattened_osd)
if optimizers:
from torch.distributed.checkpoint import FileSystemReader
# TODO: replace with newer APIs
# https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
reader = FileSystemReader(path=path)
# the optimizer states must be loaded separately
for optim_key, optim in optimizers.items():
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=module_state[module_key],
optimizer_key=optim_key,
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=optim_state[optim_key],
model=module,
optim=optim,
)
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
@@ -920,3 +916,35 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device)

for metric in (m for m in module.modules() if isinstance(m, Metric)):
metric.to(device) # `.to()` is in-place


def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None:
if _TORCH_GREATER_EQUAL_2_3:
from torch.distributed.checkpoint import save
# let torch automatically infer the writer to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
save(converted_state, checkpoint_id=path) # type: ignore[call-arg]
else: # deprecated
from torch.distributed.checkpoint import FileSystemWriter
if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import save
else:
from torch.distributed.checkpoint import save_state_dict as save
# FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
writer = FileSystemWriter(path=path, single_file_per_rank=True)
save(converted_state, writer)

def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
if _TORCH_GREATER_EQUAL_2_3:
from torch.distributed.checkpoint import load
# let torch automatically infer the reader to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
load(module_state, checkpoint_id=path) # type: ignore[call-arg]
else: # deprecated
from torch.distributed.checkpoint import FileSystemReader
if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import load
else:
from torch.distributed.checkpoint import load_state_dict as load
reader = FileSystemReader(path=path)
load(module_state, reader)
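
For reference, a minimal usage sketch of the two helpers added above (not part of the diff). It assumes torch.distributed is already initialized, `module` is an FSDP-wrapped model whose state dict is gathered under the sharded state-dict context, and `checkpoint_dir`/`metadata` are hypothetical names chosen for illustration; the strategies above perform these same steps inside save_checkpoint/load_checkpoint.

# Hedged sketch: how the new helpers are meant to be called. Illustrative only.
from pathlib import Path

import torch

from lightning.fabric.strategies.fsdp import _distributed_checkpoint_load, _distributed_checkpoint_save
from lightning.fabric.utilities.load import _METADATA_FILENAME

checkpoint_dir = Path("checkpoints/step-1000")  # hypothetical target directory
metadata = {"steps": 1000}  # hypothetical non-tensor state saved alongside the shards

# Saving: every rank streams its shards into the directory; rank 0 also writes the metadata file.
converted_state = {"model": module.state_dict()}  # gathered inside the sharded state-dict context
_distributed_checkpoint_save(converted_state, checkpoint_dir)
if torch.distributed.get_rank() == 0:
    torch.save(metadata, checkpoint_dir / _METADATA_FILENAME)

# Loading: the in-memory state dict is filled from the shards, then handed back to the module.
module_state = {"model": module.state_dict()}
_distributed_checkpoint_load(module_state, checkpoint_dir)
module.load_state_dict(module_state["model"])
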
5 changes: 3 additions & 2 deletions src/lightning/fabric/utilities/imports.py
@@ -26,8 +26,9 @@
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0")
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
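
The 2.3 gate keeps use_base_version=True because torch 2.3 was not yet released at the time of this change, so only nightly/pre-release builds can satisfy it, while the released 2.1/2.2 gates drop the flag. A short sketch of the difference, assuming compare_version from lightning_utilities behaves as it is used in this file:

# Hedged sketch of the use_base_version semantics assumed above.
import operator

from lightning_utilities.core.imports import compare_version

# With use_base_version=True, a nightly such as "2.3.0.dev20240213" is reduced to its base
# release "2.3.0" before comparing, so it satisfies >= "2.3.0".
pre_release_counts = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)

# Without the flag, the same nightly compares below "2.3.0" (PEP 440 ordering), so the check
# stays False until the final 2.3.0 release is installed.
strict_check = compare_version("torch", operator.ge, "2.3.0")
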
24 changes: 9 additions & 15 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -33,6 +33,8 @@
_METADATA_FILENAME,
_activation_checkpointing_kwargs,
_auto_wrap_policy_kwargs,
_distributed_checkpoint_load,
_distributed_checkpoint_save,
_get_full_state_dict_context,
_get_sharded_state_dict_context,
_has_meta_device_parameters,
@@ -55,7 +57,6 @@
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
@@ -561,8 +562,6 @@ def save_checkpoint(
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

if self._state_dict_type == "sharded":
from torch.distributed.checkpoint import FileSystemWriter, save_state_dict

if path.is_file():
path.unlink()
path.mkdir(parents=True, exist_ok=True)
@@ -572,9 +571,7 @@
{f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
)

# FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
writer = FileSystemWriter(path=path, single_file_per_rank=True)
save_state_dict(converted_state, writer)
_distributed_checkpoint_save(converted_state, path)

if self.global_rank == 0:
torch.save(checkpoint, path / _METADATA_FILENAME)
Expand All @@ -596,23 +593,20 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
assert self.lightning_module is not None

if _is_sharded_checkpoint(path):
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import load
else:
from torch.distributed.checkpoint import load_state_dict as load # deprecated

state_dict_ctx = _get_sharded_state_dict_context(self.model)
reader = FileSystemReader(path=path)

with state_dict_ctx:
module_state = {"model": self.model.state_dict()}
load(module_state, reader)
_distributed_checkpoint_load(module_state, path)
self.model.load_state_dict(module_state["model"], strict=self.lightning_module.strict_loading)

if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers:
from torch.distributed.checkpoint import FileSystemReader
# TODO: replace with newer APIs
# https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
reader = FileSystemReader(path=path)
# the optimizer states must be loaded separately
for idx, optim in enumerate(self.optimizers):
optim_key = f"optimizer_{idx}"
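
The hunk above is truncated before the body of the optimizer loop. As a hedged sketch, it presumably mirrors the Fabric implementation shown earlier in this commit (read the sharded optimizer state, re-flatten it for the wrapped model, then load it into the optimizer); the exact body is not visible in this view:

# Hedged sketch of the truncated loop body, assuming it mirrors the Fabric strategy above.
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

for idx, optim in enumerate(self.optimizers):
    optim_key = f"optimizer_{idx}"
    # Read this optimizer's sharded state from the checkpoint directory.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=module_state["model"],
        optimizer_key=optim_key,
        storage_reader=reader,
    )
    # Re-key/flatten the state to match the FSDP-wrapped parameters before loading it.
    flattened_osd = FSDP.optim_state_dict_to_load(
        optim_state_dict=optim_state[optim_key],
        model=self.model,
        optim=optim,
    )
    optim.load_state_dict(flattened_osd)
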
22 changes: 13 additions & 9 deletions tests/tests_fabric/strategies/test_fsdp.py
@@ -29,7 +29,7 @@
_has_meta_device_parameters,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.optim import Adam

@@ -241,13 +241,12 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):


@RunIf(min_torch="2.0.0")
@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp.torch.save", return_value=Mock())
@mock.patch("lightning.fabric.strategies.fsdp.shutil", return_value=MagicMock())
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp.torch.save")
@mock.patch("lightning.fabric.strategies.fsdp.shutil")
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")

# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -278,22 +277,27 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
torch_save_mock.assert_called_once()

strategy = FSDPStrategy(state_dict_type="sharded")
save_mock = mock.patch(
"torch.distributed.checkpoint.save"
if _TORCH_GREATER_EQUAL_2_2 else "torch.distributed.checkpoint.save_state_dict")

# state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
with save_mock:
strategy.save_checkpoint(path=path, state={"model": model})
assert (path / "file").exists()

# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
with save_mock:
strategy.save_checkpoint(path=path, state={"model": model})
assert path.is_dir()


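
A note on the test change above: after the refactor, the save entry point is imported lazily inside _distributed_checkpoint_save rather than at module scope in the strategy, so the test now patches the torch-level symbol and picks the target by torch version. A minimal sketch of the pattern (the final assertion is illustrative, not part of the original test):

# Hedged sketch of the conditional patch target used in the test above.
from unittest import mock

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

target = (
    "torch.distributed.checkpoint.save"  # path taken by the helper on torch >= 2.2
    if _TORCH_GREATER_EQUAL_2_2
    else "torch.distributed.checkpoint.save_state_dict"  # deprecated fallback on older torch
)
with mock.patch(target) as save_mock:
    strategy.save_checkpoint(path=path, state={"model": model})
save_mock.assert_called_once()  # illustrative assertion
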
22 changes: 14 additions & 8 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -17,6 +17,7 @@
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
@@ -801,13 +802,12 @@ def test_save_checkpoint_storage_options(tmp_path):


@RunIf(min_torch="2.0.0")
@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock())
@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock())
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
@mock.patch("lightning.pytorch.strategies.fsdp.shutil")
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")

# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -839,21 +839,27 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,

strategy = FSDPStrategy(state_dict_type="sharded")

save_mock = mock.patch(
"torch.distributed.checkpoint.save"
if _TORCH_GREATER_EQUAL_2_2 else "torch.distributed.checkpoint.save_state_dict")

# state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
with save_mock:
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert (path / "file").exists()

# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
with save_mock:
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert path.is_dir()


