From 7668a6bf598d27c85c8859bbcc5e39bccf1277ca Mon Sep 17 00:00:00 2001 From: Liyang90 Date: Wed, 5 Jun 2024 17:15:03 -0700 Subject: [PATCH] Flexible and easy to use HSDP setting (#19504) Co-authored-by: awaelchli --- src/lightning/fabric/CHANGELOG.md | 2 + src/lightning/fabric/strategies/fsdp.py | 20 +++++++++- src/lightning/pytorch/CHANGELOG.md | 2 + src/lightning/pytorch/strategies/fsdp.py | 41 +++++++++++++++++++-- tests/tests_fabric/strategies/test_fsdp.py | 7 +++- tests/tests_pytorch/strategies/test_fsdp.py | 7 +++- 6 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index ea595b2635138..2ee0243a0d5a5 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931)) +- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504)) + ### Changed diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index eb125d191df94..9a711b8449c3e 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -74,12 +74,14 @@ from lightning.fabric.utilities.types import _PATH, _Stateful if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] + _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload") @@ -117,10 +119,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded): - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated. - ``"NO_SHARD"``: No sharding (identical to regular DDP). - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but - replicates across machines. + replicates across machines. See also the `device_mesh` parameter below. Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value. + device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and + replicate the model. The product of the two numbers must equal the world size. Only valid in combination + with the `HYBRID_SHARD` sharding strategy. + state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file. 
@@ -146,6 +152,7 @@ def __init__( activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", + device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -163,6 +170,11 @@ def __init__( # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()` self._fsdp_kwargs.setdefault("use_orig_params", True) + if device_mesh is not None: + if not _TORCH_GREATER_EQUAL_2_2: + raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.") + self._fsdp_kwargs["device_mesh"] = device_mesh + self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs( activation_checkpointing, activation_checkpointing_policy ) @@ -244,6 +256,12 @@ def setup_environment(self) -> None: super().setup_environment() self._setup_distributed() + # if 'device_mesh' in the `_fsdp_kwargs` is provided as a tuple, update it into the `DeviceMesh` object here + if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple): + from torch.distributed.device_mesh import init_device_mesh + + self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"]) + @override def setup_module_and_optimizers( self, module: Module, optimizers: List[Optimizer] diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1e3ae02dd2c1b..b47c01592882d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931)) +- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504)) + ### Changed diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 3c352e8174ddc..90f6c1febdccb 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -16,7 +16,21 @@ from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -53,7 +67,10 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 +from lightning.fabric.utilities.imports import ( + _TORCH_GREATER_EQUAL_2_1, + _TORCH_GREATER_EQUAL_2_2, +) from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors from lightning.fabric.utilities.optimizer import _optimizers_to_device @@ -70,6 +87,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh from 
torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy @@ -114,10 +132,14 @@ class FSDPStrategy(ParallelStrategy): - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated. - ``"NO_SHARD"``: No sharding (identical to regular DDP). - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but - replicates across machines. + replicates across machines. See also the `device_mesh` parameter below. Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value. + device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and + replicate the model. The product of the two numbers must equal the world size. Only valid in combination + with the `HYBRID_SHARD` sharding strategy. + state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file. @@ -147,6 +169,7 @@ def __init__( activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "full", + device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -162,6 +185,12 @@ def __init__( self.cpu_offload = _init_cpu_offload(cpu_offload) self.mixed_precision = mixed_precision self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs) + + if device_mesh is not None: + if not _TORCH_GREATER_EQUAL_2_2: + raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.") + self.kwargs["device_mesh"] = device_mesh + self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs) # Avoids the need for user to reference params in `configure_optimizers` via @@ -242,6 +271,12 @@ def setup_environment(self) -> None: assert self.cluster_environment is not None _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here + if isinstance(self.kwargs.get("device_mesh"), tuple): + from torch.distributed.device_mesh import init_device_mesh + + self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"]) + def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index ed0dda85ffaef..1cf2a4d2f1f63 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -72,7 +72,7 @@ def test_sharding_strategy(): @pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]) -def test_hybrid_shard_configuration(sharding_strategy): +def test_hybrid_shard_configuration(sharding_strategy, monkeypatch): """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg.""" with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"): FSDPStrategy(sharding_strategy=sharding_strategy) @@ -85,6 +85,11 @@ def test_hybrid_shard_configuration(sharding_strategy): assert strategy.sharding_strategy.name == 
sharding_strategy assert strategy._fsdp_kwargs["process_group"] is process_group + monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False) + with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."): + FSDPStrategy(device_mesh=Mock()) + + monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True) device_mesh = Mock() strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh) assert strategy.sharding_strategy.name == sharding_strategy diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 5557c07df9960..04eeabbbd7c49 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -501,7 +501,7 @@ def test_sharding_strategy(): @pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]) -def test_hybrid_sharding_strategy(sharding_strategy): +def test_hybrid_shard_configuration(sharding_strategy, monkeypatch): """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg.""" with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"): FSDPStrategy(sharding_strategy=sharding_strategy) @@ -514,6 +514,11 @@ def test_hybrid_sharding_strategy(sharding_strategy): assert strategy.sharding_strategy.name == sharding_strategy assert strategy.kwargs["process_group"] is process_group + monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False) + with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."): + FSDPStrategy(device_mesh=Mock()) + + monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True) device_mesh = Mock() strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh) assert strategy.sharding_strategy.name == sharding_strategy