Deprecate the FairScale integration (#16353)
carmocca authored Jan 17, 2023
1 parent fce54a4 commit 0f4f809
Showing 20 changed files with 251 additions and 387 deletions.
172 changes: 8 additions & 164 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -6,8 +6,6 @@ Train 1 trillion+ parameter models

When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training strategies to support these cases and offer substantial improvements in memory usage.

In many cases these strategies are some flavour of model parallelism however we only introduce concepts at a high level to get you started. Refer to the `FairScale documentation <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`_ for more information about model parallelism.

Note that some of the extreme memory-saving configurations will affect the speed of training. In most cases, this speed/memory trade-off can be adjusted.

Some of these memory-efficient strategies rely on offloading onto other forms of memory, such as CPU RAM or NVMe. This means you can even see memory benefits on a **single GPU**, using a strategy such as :ref:`deepspeed-zero-stage-3-offload`.
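
For instance, a minimal sketch of enabling offloading on a single GPU (the ``Trainer`` arguments are illustrative, assuming the ``deepspeed_stage_3_offload`` strategy alias and a ``LightningModule`` defined elsewhere):

.. code-block:: python

    from pytorch_lightning import Trainer

    # ZeRO Stage 3 with CPU offloading on a single GPU
    trainer = Trainer(accelerator="gpu", devices=1, strategy="deepspeed_stage_3_offload", precision=16)
    trainer.fit(model)  # `model` is any LightningModule defined elsewhere
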
@@ -40,7 +38,7 @@ Overall:

* When **fine-tuning** a model, use advanced memory-efficient strategies such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
* When **pre-training** a model, use simpler optimizations such as :ref:`sharded-training`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded-training`, scaling the number of GPUs to reach larger parameter sizes
* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant
* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` as the throughput degradation is not significant

For example, when using 128 GPUs, you can **pre-train** large 10 to 20 billion parameter models using :ref:`deepspeed-zero-stage-2` without taking the performance hit that comes with more advanced optimized multi-GPU strategies.
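
A minimal sketch of such a pre-training setup (device and node counts are illustrative, assuming the ``deepspeed_stage_2`` strategy alias):

.. code-block:: python

    from pytorch_lightning import Trainer

    # ZeRO Stage 2 across 128 GPUs (16 nodes x 8 GPUs each)
    trainer = Trainer(accelerator="gpu", devices=8, num_nodes=16, strategy="deepspeed_stage_2", precision=16)
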

@@ -153,11 +151,10 @@ Here's an example of changing the placement policy to "cpu".
.. _sharded-training:

**************************
FairScale Sharded Training
**************************
****************
Sharded Training
****************

Lightning integration of optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
The technique can be found within `DeepSpeed ZeRO <https://arxiv.org/abs/1910.02054>`_ and
`ZeRO-2 <https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/>`_,
however, the implementation is built from the ground up to be PyTorch-compatible and standalone.
@@ -171,178 +168,25 @@ these benefits in multi-GPU setups are almost free and throughput scales well wi

It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models is beneficial (500M+ parameter models).
A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful.
Use :ref:`fairscale-activation-checkpointing` to see even more benefit at the cost of some throughput.

To use Sharded Training, you need to first install FairScale using the command below.

.. code-block:: bash

    pip install fairscale

.. code-block:: python

    # train using Sharded DDP
    trainer = Trainer(strategy="ddp_sharded")

Sharded Training can work across all DDP variants by adding the ``--strategy ddp_sharded`` flag on the command line of a PyTorch Lightning script.
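
As a rough sketch of such a script (the model and argument parsing are illustrative, using the 1.x ``Trainer`` argparse helpers):

.. code-block:: python

    # train.py -- run e.g. `python train.py --strategy ddp_sharded`
    from argparse import ArgumentParser

    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)  # exposes Trainer flags such as --strategy
    args = parser.parse_args()

    model = MyLightningModule()  # hypothetical LightningModule defined elsewhere
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)
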

Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.

----

.. _fully-sharded-training:

FairScale Fully Sharded Training
================================

.. warning::
    FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue <https://github.com/Lightning-AI/lightning/issues>`_ if you run into any problems.

`Fully Sharded <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`_ shards optimizer state, gradients, and parameters across data parallel workers. This allows you to fit much larger models onto multiple GPUs into memory.

Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort.

Shard Parameters to Reach 10+ Billion Parameters
------------------------------------------------

To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this.

.. note::
    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
    This is a limitation of Fully Sharded Training that will be resolved in the future.

Enabling Module Sharding for Maximum Memory Efficiency
------------------------------------------------------

Auto Wrapping
^^^^^^^^^^^^^

Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. note::
    While initializing the optimizers inside the ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
    PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
    ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future.

.. code-block:: python

    class MyModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2)


    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)

Manual Wrapping
^^^^^^^^^^^^^^^

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded Training these wrap functions are a no-op. That means once the changes have been made, there is no need to remove the changes for other strategies.

``auto_wrap`` recursively wraps :class:`~torch.nn.Module` within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information :class:`here <fairscale.nn.fsdp>`).

``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.

``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.

Here's an example using both ``wrap`` and ``auto_wrap`` to create your model:

.. code-block:: python

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
            self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def configure_sharded_model(self):
            # modules are sharded across processes
            # as soon as they are wrapped with `wrap` or `auto_wrap`.
            # During the forward/backward passes, weights get synced across processes
            # and de-allocated once computation is complete, saving memory.

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(self.linear_layer)

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(self.block)

            # For best memory efficiency,
            # add FairScale activation checkpointing
            final_block = auto_wrap(checkpoint_wrapper(self.final_block))
            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters(), lr=1e-2)


    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)
    trainer.test()
    trainer.predict()

----

.. _fairscale-activation-checkpointing:

Activation Checkpointing
------------------------

Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations.

FairScale's checkpointing wrapper also handles batch-norm layers correctly, unlike the PyTorch implementation, ensuring that stats are tracked correctly despite the multiple forward passes.

This saves memory when training larger models, however it requires wrapping modules you'd like to use activation checkpointing on. See :class:`here <fairscale.nn.checkpoint.checkpoint_wrapper>` for more information.

.. warning::

    Do not wrap the entire model with activation checkpointing. This is not the intended use of activation checkpointing, and will lead to failures as seen in `this discussion <https://github.com/Lightning-AI/lightning/discussions/9144>`_.

.. code-block:: python

    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from fairscale.nn import checkpoint_wrapper


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # Wrap layers using checkpoint_wrapper
            self.block_1 = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
            self.block_2 = nn.Linear(32, 2)

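As a minimal sketch, training the checkpointed model above then proceeds as with any other strategy (the ``Trainer`` arguments are illustrative, mirroring the earlier examples):

.. code-block:: python

    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)
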
----

.. _fully-sharded-native-training:

******************************
PyTorch Fully Sharded Training
******************************
**********************
Fully Sharded Training
**********************

PyTorch has its own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_ which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ but it is recommended to use it with PyTorch v1.12 or later, and that's what
Lightning supports. The API is pretty similar to that of FairScale.
Lightning supports.
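
A minimal sketch of enabling the native strategy, assuming a ``LightningModule`` defined elsewhere (the remaining ``Trainer`` arguments are illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer

    model = MyLightningModule()  # hypothetical LightningModule defined elsewhere
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
    trainer.fit(model)
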


Auto Wrapping
11 changes: 1 addition & 10 deletions docs/source-pytorch/extensions/strategy.rst
@@ -80,16 +80,7 @@ The below table lists all relevant strategies available in Lightning with their
- Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. <https://www.colossalai.org/>`__
* - fsdp_native
- :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
- Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. <advanced/model_parallel:PyTorch Fully Sharded Training>`
* - fsdp
- :class:`~pytorch_lightning.strategies.DDPFullyShardedStrategy`
- Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Fully Sharded Training>`
* - ddp_sharded
- :class:`~pytorch_lightning.strategies.DDPShardedStrategy`
- Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
* - ddp_sharded_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy`
- Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
- Strategy for Fully Sharded Data Parallel. :ref:`Learn more. <advanced/model_parallel:Fully Sharded Training>`
* - ddp_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnStrategy`
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
9 changes: 8 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -51,12 +51,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Deprecated the `Trainer.amp_backend` property
* Deprecated the `Trainer(amp_level=...)` argument
* Deprecated the `pytorch_lightning.plugins.ApexMixedPrecisionPlugin` class
* Deprecates the `pytorch_lightning.utilities.enum.sAMPType` enum
* Deprecates the `pytorch_lightning.utilities.enums.AMPType` enum
* Deprecates the `DeepSpeedPrecisionPlugin(amp_type=..., amp_level=...)` arguments
- `horovod` deprecation ([#16141](https://github.com/PyTorchLightning/pytorch-lightning/pull/16141))
* Deprecated `Trainer(strategy="horovod")`
* Deprecated the `HorovodStrategy` class
- Deprecated `pytorch_lightning.lite.LightningLite` in favor of `lightning.fabric.Fabric` ([#16314](https://github.com/Lightning-AI/lightning/pull/16314))
- `FairScale` deprecation (in favor of PyTorch's FSDP implementation) ([#16353](https://github.com/PyTorchLightning/pytorch-lightning/pull/16353))
* Deprecated the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
* Deprecated the `pytorch_lightning.plugins.precision.fully_sharded_native_amp.FullyShardedNativeMixedPrecisionPlugin` class
* Deprecated the `pytorch_lightning.plugins.precision.sharded_native_amp.ShardedNativeMixedPrecisionPlugin` class
* Deprecated the `pytorch_lightning.strategies.fully_sharded.DDPFullyShardedStrategy` class
* Deprecated the `pytorch_lightning.strategies.sharded.DDPShardedStrategy` class
* Deprecated the `pytorch_lightning.strategies.sharded_spawn.DDPSpawnShardedStrategy` class


### Removed
7 changes: 7 additions & 0 deletions src/pytorch_lightning/overrides/fairscale.py
@@ -41,6 +41,13 @@ def __init__(
        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
    ) -> None:
        rank_zero_deprecation(
            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
            " the native version by default."
        )
        self._validate_init_arguments(pl_module, forward_module)
        super().__init__(forward_module=(pl_module or forward_module))

Expand Down
src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -15,11 +15,22 @@

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Native AMP for Fully Sharded Training."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
            " the native version by default."
        )
        super().__init__(*args, **kwargs)

    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html
        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
Expand Down
8 changes: 8 additions & 0 deletions src/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -18,6 +18,7 @@
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
@@ -32,6 +33,13 @@ class ShardedNativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    def __init__(
        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None
    ) -> None:
        rank_zero_deprecation(
            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
            " the native version by default."
        )
        if not _FAIRSCALE_AVAILABLE:
            raise MisconfigurationException(
                "You have asked for sharded AMP but you have not installed it."
Expand Down
9 changes: 8 additions & 1 deletion src/pytorch_lightning/strategies/fully_sharded.py
@@ -28,6 +28,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
@@ -117,7 +118,13 @@ def __init__(
                If ``False``, this will default to ``compute_device``.
                (Default: True).
        """

        rank_zero_deprecation(
            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
            " the native version by default."
        )
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,