Add auto wrapping for DDPFullyShardedNativeStrategy (#14252)
rohitgr7 authored Aug 26, 2022
1 parent 70deac2 commit 6d00f31
Showing 6 changed files with 166 additions and 33 deletions.
29 changes: 23 additions & 6 deletions docs/source-pytorch/advanced/model_parallel.rst
Expand Up @@ -212,14 +212,31 @@ PyTorch Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch has its own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_ which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_. The API is pretty similar to that of FairScale.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_, but it is recommended to use it with PyTorch v1.12 or later, which is
what Lightning supports. The API is pretty similar to that of FairScale.

.. note::
Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.

To activate parameter sharding, you must wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
Auto Wrapping
"""""""""""""
Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. code-block:: python

    model = BoringModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
    trainer.fit(model)

Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.
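
As a rough illustration of what an auto-wrap policy looks like, the following pure-Python sketch mirrors the call shape of ``custom_auto_wrap_policy`` from the tests in this commit. The name ``toy_size_based_policy`` and the threshold values are hypothetical. The reading of ``recurse`` (``True`` when FSDP asks whether to descend into a module's children, ``False`` when it asks whether to wrap the module itself) is an assumption about PyTorch 1.12, not something this commit states.

```python
from functools import partial


def toy_size_based_policy(module, recurse, unwrapped_params, min_num_params=int(1e8)):
    """Hypothetical policy: wrap (or keep descending) once enough parameters are unwrapped.

    `module` and `recurse` are accepted only to match the expected call shape;
    this toy gives the same answer for both kinds of question.
    """
    return unwrapped_params >= min_num_params


# Lower the threshold without changing the call shape (the value is illustrative).
small_policy = partial(toy_size_based_policy, min_num_params=1000)

print(small_policy(None, True, 2048))   # enough parameters: keep going / wrap
print(small_policy(None, False, 10))    # too few parameters: leave unwrapped
```

A callable of this shape is what the updated test passes as ``DDPFullyShardedNativeStrategy(auto_wrap_policy=...)``.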


Manual Wrapping
"""""""""""""""

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded, these ``wrap`` calls are no-ops. This means that once the changes have been made, there is no need to remove them for other strategies.

Expand Down
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))


- Added support for auto wrapping for `DDPFullyShardedNativeStrategy` ([#14252](https://github.com/Lightning-AI/lightning/issues/14252))


- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


Expand Down
Expand Up @@ -25,14 +25,13 @@
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
else:
MixedPrecision = None # type: ignore[misc,assignment]
ShardedGradScaler = None # type: ignore[misc,assignment]


class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Fully Sharded Native Training."""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
Expand Down
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.
from typing import Optional, Union

import torch

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
Expand All @@ -29,9 +27,7 @@
class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Sharded Training."""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
"You have asked for sharded AMP but you have not installed it."
Expand Down
55 changes: 55 additions & 0 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
Expand Up @@ -20,6 +20,7 @@
from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
Expand All @@ -38,9 +39,11 @@
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
Expand All @@ -51,6 +54,7 @@
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
MixedPrecision = None # type: ignore[misc,assignment]
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]
Expand Down Expand Up @@ -201,6 +205,28 @@ def _configure_launcher(self) -> None:
        self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
        self._rank_0_will_call_children_scripts = True

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
        assert self.lightning_module is not None
        if (
            any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
            and "auto_wrap_policy" in self.kwargs
        ):
            del self.kwargs["auto_wrap_policy"]

        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
            mixed_precision=self.mixed_precision_config,
            device_id=self.root_device.index,
            **self.kwargs,
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.accelerator is not None
        self.accelerator.setup(trainer)
Expand All @@ -215,9 +241,20 @@ def setup(self, trainer: "pl.Trainer") -> None:
        assert self.lightning_module is not None
        self.lightning_module._device = self.root_device

        assert isinstance(self.model, pl.LightningModule)
        self.model = _LightningModuleWrapperBase(self.model)
        if is_overridden("configure_sharded_model", self.lightning_module):
            rank_zero_info(
                "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
                " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
            )
        else:
            self.model = self._setup_model(self.model)
        self.barrier()

        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)

        self.setup_precision_plugin()

    def model_to_device(self) -> None:
Expand Down Expand Up @@ -273,6 +310,24 @@ def reduce(
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # we don't need precision context since casting is done by FSDP
        # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
        assert self.model is not None
        return self.model(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        assert self.model is not None
        return self.model(*args, **kwargs)

    def _determine_device_ids(self) -> List[int]:
        return [self.root_device.index]
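
The guard added in `_setup_model` above can be sketched without PyTorch: if any submodule is already wrapped, the `auto_wrap_policy` argument is dropped before the root is wrapped. In this minimal sketch, `ToyModule`, `ToyFSDP` and `setup_model` are hypothetical stand-ins. Unlike the commit, which deletes the key from `self.kwargs` in place and inspects `self.lightning_module`, the sketch copies the kwargs and inspects the model it is given.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module with a recursive `modules()` iterator."""

    def __init__(self, *children):
        self.children = children

    def modules(self):
        # Yield this module first, then every descendant.
        yield self
        for child in self.children:
            yield from child.modules()


class ToyFSDP(ToyModule):
    """Hypothetical stand-in for FullyShardedDataParallel; it records its kwargs."""

    def __init__(self, module, **kwargs):
        super().__init__(module)
        self.kwargs = kwargs


def setup_model(model, kwargs):
    """Wrap the root, dropping `auto_wrap_policy` if a submodule is already wrapped."""
    kwargs = dict(kwargs)  # copy, so the caller's dict is left untouched
    already_wrapped = any(isinstance(mod, ToyFSDP) for mod in model.modules())
    if already_wrapped and "auto_wrap_policy" in kwargs:
        del kwargs["auto_wrap_policy"]
    return ToyFSDP(model, **kwargs)


def policy(module, recurse, unwrapped_params):
    return unwrapped_params >= 2


plain_root = setup_model(ToyModule(ToyModule()), {"auto_wrap_policy": policy})
nested_root = setup_model(ToyModule(ToyFSDP(ToyModule())), {"auto_wrap_policy": policy})
```

Here `plain_root` keeps the policy because nothing was wrapped beforehand, while `nested_root` is wrapped without it.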

Expand Down
101 changes: 82 additions & 19 deletions tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
Expand Up @@ -11,14 +11,22 @@
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
Expand Down Expand Up @@ -78,38 +86,73 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
self._assert_layer_fsdp_instance()

def on_test_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_validation_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
assert isinstance(self.layer.module[0], FullyShardedDataParallel)
assert isinstance(self.layer.module[2], FullyShardedDataParallel)
# root should not be resharding
assert self.layer.reshard_after_forward is False
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[0].reshard_after_forward is True
assert self.layer.module[2].reshard_after_forward is True

precision = torch.float16 if self.precision == 16 else torch.bfloat16
assert self.layer.mixed_precision.param_dtype == precision
assert self.layer.mixed_precision.reduce_dtype == precision
assert self.layer.mixed_precision.buffer_dtype == precision

for layer_num in [0, 2]:
assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[layer_num].reshard_after_forward is True

assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


class TestFSDPModelAutoWrapped(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_optimizers(self):
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        self._assert_layer_fsdp_instance()

    def _assert_layer_fsdp_instance(self) -> None:
        assert isinstance(self.layer, torch.nn.Sequential)
        assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)

        precision = torch.float16 if self.precision == 16 else torch.bfloat16
        for layer_num in [0, 2]:
            assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
            # Assert that the nested layers are set reshard_after_forward to True
            assert self.layer[layer_num].reshard_after_forward

            assert self.layer[layer_num].mixed_precision.param_dtype == precision
            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
Expand Down Expand Up @@ -140,18 +183,32 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
@pytest.mark.parametrize(
"model, strategy",
[
(TestFSDPModel(), "fsdp_native"),
(TestFSDPModelAutoWrapped(), DDPFullyShardedNativeStrategy),
],
)
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

model = TestFSDPModel()
ck = ModelCheckpoint(save_last=True)

if not isinstance(strategy, str):
strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)

trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy="fsdp_native",
strategy=strategy,
precision=16,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
limit_predict_batches=2,
callbacks=[ck],
)
_run_multiple_stages(trainer, model)
Expand All @@ -164,14 +221,20 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):

trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=TestFSDPModel)
_assert_save_equality(trainer, model_path, cls=model.__class__)

# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model and load and then call configure_sharded_model to wrap
# provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model and load and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
Expand Down
