Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto wrapping for DDPFullyShardedNativeStrategy #14252

Merged
merged 20 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions docs/source-pytorch/advanced/model_parallel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,28 @@ PyTorch Fully Sharded Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch has it's own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_ which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_. The API is pretty similar to that of FairScale.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ but it is recommended to use it with PyTorch v1.12 or more and that's what
Lightning supports. The API is pretty similar to that of FairScale.

.. note::
Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.

To activate parameter sharding, you must wrap your model using the``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
Auto Wrapping
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""""""""""""
Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. code-block:: python

model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
trainer.fit(model)


Manual Wrapping
"""""""""""""""

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.

Expand Down
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))


- Added support for auto wrapping for `DDPFullyShardedNativeStrategy` ([#14252](https://github.com/Lightning-AI/lightning/issues/14252))


- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
else:
MixedPrecision = None # type: ignore[misc,assignment]
ShardedGradScaler = None # type: ignore[misc,assignment]


class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Fully Sharded Native Training."""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
Expand Down
58 changes: 56 additions & 2 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
Expand All @@ -38,9 +39,11 @@
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
Expand All @@ -51,6 +54,7 @@
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
MixedPrecision = None # type: ignore[misc,assignment]
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]
Expand Down Expand Up @@ -201,6 +205,26 @@ def _configure_launcher(self) -> None:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True

def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
"""Wraps the model into a
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
assert self.lightning_module is not None
if any(isinstance(mod, FullyShardedDataParallel) for _, mod in self.lightning_module.named_modules()):
if "auto_wrap_policy" in self.kwargs:
self.kwargs.pop("auto_wrap_policy")
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
)

def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
Expand All @@ -215,9 +239,21 @@ def setup(self, trainer: "pl.Trainer") -> None:
assert self.lightning_module is not None
self.lightning_module._device = self.root_device

assert isinstance(self.model, pl.LightningModule)
self.model = _LightningModuleWrapperBase(self.model)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
if is_overridden("configure_sharded_model", self.lightning_module):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
rank_zero_info(
"You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
)
else:
self.model = self._setup_model(self.model)
justusschock marked this conversation as resolved.
Show resolved Hide resolved
self.barrier()
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if trainer.state.fn == TrainerFn.FITTING:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

self.setup_precision_plugin()

def model_to_device(self) -> None:
Expand Down Expand Up @@ -273,6 +309,24 @@ def reduce(
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
# we don't need precision context since casting is done by FSDP
# read `mixed_precision``` docstring here: https://pytorch.org/docs/stable/fsdp.html
carmocca marked this conversation as resolved.
Show resolved Hide resolved
assert self.model is not None
return self.model(*args, **kwargs)

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
return self.model(*args, **kwargs)

def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
return self.model(*args, **kwargs)

def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
return self.model(*args, **kwargs)

def _determine_device_ids(self) -> List[int]:
return [self.root_device.index]

Expand Down
99 changes: 81 additions & 18 deletions tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
Expand Down Expand Up @@ -78,38 +86,73 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
self._assert_layer_fsdp_instance()

def on_test_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_validation_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
assert isinstance(self.layer.module[0], FullyShardedDataParallel)
assert isinstance(self.layer.module[2], FullyShardedDataParallel)
# root should not be resharding
assert self.layer.reshard_after_forward is False
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[0].reshard_after_forward is True
assert self.layer.module[2].reshard_after_forward is True

precision = torch.float16 if self.precision == 16 else torch.bfloat16
assert self.layer.mixed_precision.param_dtype == precision
assert self.layer.mixed_precision.reduce_dtype == precision
assert self.layer.mixed_precision.buffer_dtype == precision

for layer_num in [0, 2]:
assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[layer_num].reshard_after_forward is True

assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


class TestFSDPModelAutoWrapped(BoringModel):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

def configure_optimizers(self):
return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
self._assert_layer_fsdp_instance()

def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
self._assert_layer_fsdp_instance()

def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, torch.nn.Sequential)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)

precision = torch.float16 if self.precision == 16 else torch.bfloat16
for layer_num in [0, 2]:
assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer[layer_num].reshard_after_forward

assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
Expand Down Expand Up @@ -140,18 +183,32 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
@pytest.mark.parametrize(
"model, strategy",
[
(TestFSDPModel(), "fsdp_native"),
(TestFSDPModelAutoWrapped(), DDPFullyShardedNativeStrategy),
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

model = TestFSDPModel()
ck = ModelCheckpoint(save_last=True)

if not isinstance(strategy, str):
strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)

trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy="fsdp_native",
devices=1,
strategy=strategy,
precision=16,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
limit_predict_batches=2,
callbacks=[ck],
)
_run_multiple_stages(trainer, model)
Expand All @@ -172,6 +229,12 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
# provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
trainer.test(ckpt_path=model_path)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
Expand Down