From 8c6119fbcedccbc17300df1680f41ac30b4b1c79 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 6 Sep 2022 00:37:26 +0530 Subject: [PATCH] Add auto wrapping support for `DDPFullyShardedStrategy` (#14383) --- .../advanced/model_parallel.rst | 105 +++++++++++------ docs/source-pytorch/extensions/strategy.rst | 4 +- src/pytorch_lightning/CHANGELOG.md | 3 + .../callbacks/stochastic_weight_avg.py | 7 +- .../strategies/fully_sharded.py | 97 ++++++++++++---- src/pytorch_lightning/strategies/strategy.py | 2 +- .../callbacks/test_stochastic_weight_avg.py | 1 + ..._ddp_fully_sharded_with_full_state_dict.py | 106 +++++++++++++++--- 8 files changed, 248 insertions(+), 77 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 50ae2cd2827d0..757b7dffa4580 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -1,7 +1,8 @@ .. _model-parallel: +################################## Train 1 trillion+ parameter models -================================== +################################## When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training strategies to support these cases and offer substantial improvements in memory usage. @@ -19,8 +20,9 @@ Check out this amazing video explaining model parallelism and how it works behin allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen> +********************************************* Choosing an Advanced Distributed GPU Strategy -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +********************************************* If you would like to stick with PyTorch DDP, see :ref:`ddp-optimizations`. @@ -29,7 +31,7 @@ Unlike :class:`~torch.nn.parallel.DistributedDataParallel` (DDP) where the maxim There are many considerations when choosing a strategy as described below. In addition, check out the visualization of various strategy benchmarks using `minGPT `__ `here `__. Pre-training vs Fine-tuning -""""""""""""""""""""""""""" +=========================== When fine-tuning, we often use a magnitude less data compared to pre-training a model. This is important when choosing a distributed strategy as usually for pre-training, **we are compute-bound**. This means we cannot sacrifice throughput as much as if we were fine-tuning, because in fine-tuning the data requirement is smaller. @@ -45,7 +47,7 @@ For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion pa But for **fine-tuning** a model, you can reach 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-3-offload` on a **single GPU**. This does come with a significant throughput hit, which needs to be weighed accordingly. When Shouldn't I use an Optimized Distributed Strategy? -""""""""""""""""""""""""""""""""""""""""""""""""""""""" +======================================================= Sharding techniques help when model sizes are fairly large; roughly 500M+ parameters is where we've seen benefits. However, in the following cases, we recommend sticking to ordinary distributed strategies * When your model is small (ResNet50 of around 80M Parameters), unless you are using unusually large batch sizes or inputs. @@ -55,8 +57,10 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame .. _sharded-training: -Sharded Training -^^^^^^^^^^^^^^^^ +************************** +FairScale Sharded Training +************************** + Lightning integration of optimizer sharded training provided by `FairScale `_. The technique can be found within `DeepSpeed ZeRO `_ and `ZeRO-2 `_, @@ -94,7 +98,7 @@ Internally we re-initialize your optimizers and shard them across your machines .. _fully-sharded-training: FairScale Fully Sharded Training -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +================================ .. warning:: FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue `_ if you run into any problems. @@ -104,7 +108,7 @@ FairScale Fully Sharded Training Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort. Shard Parameters to Reach 10+ Billion Parameters -"""""""""""""""""""""""""""""""""""""""""""""""" +------------------------------------------------ To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this. @@ -114,9 +118,37 @@ To reach larger parameter sizes and to be memory efficient, we have to shard par This is a limitation of Fully Sharded Training that will be resolved in the future. Enabling Module Sharding for Maximum Memory Efficiency -"""""""""""""""""""""""""""""""""""""""""""""""""""""" +------------------------------------------------------ + +Auto Wrapping +^^^^^^^^^^^^^ + +Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The +simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't +have to ``wrap`` layers manually as in the case of manual wrapping. + +.. note:: + While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else + PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your + ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future. + +.. code-block:: python + + class MyModel(BoringModel): + def configure_optimizers(self): + return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2) -To activate parameter sharding, you must wrap your model using the ``wrap`` or ``auto_wrap`` functions. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly. + + model = MyModel() + trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16) + trainer.fit(model) + + +Manual Wrapping +^^^^^^^^^^^^^^^ + +Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate +parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly. When not using Fully Sharded Training these wrap functions are a no-op. That means once the changes have been made, there is no need to remove the changes for other strategies. @@ -164,7 +196,7 @@ Here's an example using both ``wrap`` and ``auto_wrap`` to create your model: self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block) def configure_optimizers(self): - return torch.optim.AdamW(self.model.parameters()) + return torch.optim.AdamW(self.model.parameters(), lr=1e-2) model = MyModel() @@ -178,8 +210,8 @@ Here's an example using both ``wrap`` and ``auto_wrap`` to create your model: .. _fairscale-activation-checkpointing: -FairScale Activation Checkpointing -"""""""""""""""""""""""""""""""""" +Activation Checkpointing +------------------------ Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations. @@ -208,8 +240,9 @@ This saves memory when training larger models, however it requires wrapping modu .. _fully-sharded-native-training: +****************************** PyTorch Fully Sharded Training -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +****************************** PyTorch has it's own version of `FSDP `_ which is upstreamed from their `fairscale `__ project. It was introduced in their `v1.11.0 release `_ but it is recommended to use it with PyTorch v1.12 or more and that's what @@ -217,7 +250,8 @@ Lightning supports. The API is pretty similar to that of FairScale. Auto Wrapping -""""""""""""" +============= + Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't have to ``wrap`` layers manually as in the case of manual wrapping. @@ -233,7 +267,7 @@ Read more `here `_ if you run into any issues. @@ -343,7 +378,7 @@ If you run into an issue with the install or later in training, ensure that the .. _deepspeed-zero-stage-1: DeepSpeed ZeRO Stage 1 -"""""""""""""""""""""" +====================== `DeepSpeed ZeRO Stage 1 `_ partitions your optimizer states (Stage 1) across your GPUs to reduce memory. @@ -361,7 +396,7 @@ It is recommended to skip Stage 1 and use Stage 2, which comes with larger memor .. _deepspeed-zero-stage-2: DeepSpeed ZeRO Stage 2 -"""""""""""""""""""""" +====================== `DeepSpeed ZeRO Stage 2 `_ partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the strategy described in a few sections below. @@ -382,7 +417,7 @@ As a result, benefits can also be seen on a single GPU. Do note that the default .. _deepspeed-zero-stage-2-offload: DeepSpeed ZeRO Stage 2 Offload -"""""""""""""""""""""""""""""" +------------------------------ Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. @@ -452,7 +487,7 @@ For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM c .. _deepspeed-zero-stage-3: DeepSpeed ZeRO Stage 3 -"""""""""""""""""""""" +====================== DeepSpeed ZeRO Stage 3 shards the optimizer states, gradients and the model parameters (also optionally activations). Sharding model parameters and activations comes with an increase in distributed communication, however allows you to scale your models massively from one GPU to multiple GPUs. **The DeepSpeed team report the ability to fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs.** For more information we suggest checking the `DeepSpeed ZeRO-3 Offload documentation `__. @@ -511,7 +546,7 @@ You can also use the Lightning Trainer to run predict or evaluate with DeepSpeed Shard Model Instantly to Reduce Initialization Time/Memory -"""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +---------------------------------------------------------- When instantiating really large models, it is sometimes necessary to shard the model layers instantly. @@ -550,7 +585,7 @@ This reduces the time taken to initialize very large models, as well as ensure w .. _deepspeed-zero-stage-3-offload: DeepSpeed ZeRO Stage 3 Offload -"""""""""""""""""""""""""""""" +------------------------------ DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to reduce memory usage as ZeRO Stage 2 does, however additionally allows you to offload the parameters as well for even more memory saving. @@ -584,7 +619,7 @@ DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to re DeepSpeed Infinity (NVMe Offloading) -"""""""""""""""""""""""""""""""""""" +------------------------------------ Additionally, DeepSpeed supports offloading to NVMe drives for even larger models, utilizing the large memory space found in NVMes. DeepSpeed `reports `__ the ability to fine-tune 1 Trillion+ parameters using NVMe Offloading on one 8 GPU machine. Below shows how to enable this, assuming the NVMe drive is mounted in a directory called ``/local_nvme``. @@ -621,7 +656,7 @@ When offloading to NVMe you may notice that the speed is slow. There are paramet .. _deepspeed-activation-checkpointing: DeepSpeed Activation Checkpointing -"""""""""""""""""""""""""""""""""" +---------------------------------- Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. @@ -697,7 +732,7 @@ This saves memory when training larger models, however requires using a checkpoi .. _deepspeed-zero-stage-3-tips: DeepSpeed ZeRO Stage 3 Tips -""""""""""""""""""""""""""" +--------------------------- Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning. @@ -709,7 +744,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig .. _deepspeed-zero-stage-3-single-file: Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3 -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +----------------------------------------------------------- After training using ZeRO Stage 3, you'll notice that your checkpoints are a directory of sharded model and optimizer states. If you'd like to collate a single file from the checkpoint directory please use the below command, which handles all the Lightning states additionally when collating the file. @@ -728,7 +763,7 @@ After training using ZeRO Stage 3, you'll notice that your checkpoints are a dir This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the ``trainer.fit(ckpt_path=)`` call. Ensure to keep the sharded checkpoint directory if this is required. Custom DeepSpeed Config -""""""""""""""""""""""" +======================= In some cases you may want to define your own DeepSpeed Config, to access all parameters defined. We've exposed most of the important parameters, however, there may be debugging parameters to enable. Also, DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file that is supported. @@ -801,12 +836,13 @@ You can use also use an environment variable via your PyTorch Lightning script: .. _ddp-optimizations: +***************** DDP Optimizations -^^^^^^^^^^^^^^^^^ +***************** When Using DDP Strategies, Set find_unused_parameters=False -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +=========================================================== By default, we have set ``find_unused_parameters=True`` for compatibility reasons that have been observed in the past (refer to the `discussion `_ for more details). When enabled, it can result in a performance hit and can be disabled in most cases. Read more about it `here `_. @@ -836,7 +872,7 @@ When enabled, it can result in a performance hit and can be disabled in most cas DDP Static Graph -"""""""""""""""" +================ `DDP static graph `__ assumes that your model employs the same set of used/unused parameters in every iteration, so that it can deterministically know the flow of @@ -854,7 +890,7 @@ training and apply special optimizations during runtime. When Using DDP on a Multi-node Cluster, Set NCCL Parameters -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +=========================================================== `NCCL `__ is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs. There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue `__. In the issue, we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2. @@ -875,7 +911,7 @@ NCCL parameters can be adjusted via environment variables. Gradients as Bucket View -"""""""""""""""""""""""" +======================== Enabling ``gradient_as_bucket_view=True`` in the ``DDPStrategy`` will make gradients views point to different offsets of the ``allreduce`` communication buckets. See :class:`~torch.nn.parallel.DistributedDataParallel` for more information. @@ -894,8 +930,9 @@ This can reduce peak memory usage and throughput as saved memory will be equal t trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(gradient_as_bucket_view=True)) trainer.fit(model) + DDP Communication Hooks -""""""""""""""""""""""" +======================= DDP Communication hooks is an interface to control how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This allows you to enable performance improving communication hooks when using multiple nodes. diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index ed39f68d45e23..21a6e8a8814b2 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -83,10 +83,10 @@ The below table lists all relevant strategies available in Lightning with their - Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. ` * - ddp_sharded - :class:`~pytorch_lightning.strategies.DDPShardedStrategy` - - Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. ` + - Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. ` * - ddp_sharded_spawn - :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy` - - Optimizer sharded training provided by FairScale. :ref:`Learn more. ` + - Optimizer sharded training provided by FairScale. :ref:`Learn more. ` * - ddp_spawn - :class:`~pytorch_lightning.strategies.DDPSpawnStrategy` - Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. ` diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f517d56fb0465..0d5d55d3324ec 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208)) +- Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383)) + + ### Changed diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 90e2c62a7962d..51cbceb7f9fb6 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy +from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig @@ -144,6 +145,9 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)): + raise MisconfigurationException("SWA does not currently support sharded models.") + # copy the model before moving it to accelerator device. with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) @@ -155,9 +159,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") - if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)): - raise MisconfigurationException("SWA does not currently support sharded models.") - if isinstance(self._swa_epoch_start, float): self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 6f7ca3b34b03d..a364d7d19a679 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -18,6 +18,7 @@ import torch import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -26,16 +27,28 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep +from pytorch_lightning.utilities.rank_zero import rank_zero_info +from pytorch_lightning.utilities.types import STEP_OUTPUT if _FAIRSCALE_AVAILABLE: from fairscale.nn import default_auto_wrap_policy, enable_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel +else: + FullyShardedDataParallel = None log = logging.getLogger(__name__) +class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase): + def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override] + # this is required because with FSDP lightning_module is empty because weights are sharded. + # So we need to call self.trainer.model.state_dict (wrapped version) and use this wraper to + # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. + return self._forward_module.state_dict(*args, **kwargs) + + class DDPFullyShardedStrategy(DDPStrategy): strategy_name = "ddp_fully_sharded" @@ -132,6 +145,25 @@ def process_group(self) -> Any: self._process_group = torch.distributed.new_group() return self._process_group + def lightning_module_state_dict(self) -> Dict[str, Any]: + """Returns model state.""" + assert self.model is not None + return self.model.state_dict() + + def connect(self, model: "pl.LightningModule") -> None: + """Called by the accelerator to connect the accelerator and the model with this plugin.""" + # TODO: Wait for this issue to resolve and remove this blocker + # https://github.com/facebookresearch/fairscale/issues/648 + # Also make sure to update the tests + if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0: + assert self.lightning_module is not None + raise MisconfigurationException( + f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not" + " supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that." + ) + + super().connect(model) + def setup_distributed(self) -> None: if not self.root_device.type == "cuda": raise MisconfigurationException( @@ -144,17 +176,46 @@ def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) if trainer.state.fn == TrainerFn.FITTING: - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - if self._layer_sync: assert self.model self.model = self._layer_sync.apply(self.model) - self.setup_precision_plugin() self.configure_ddp() + assert isinstance(self.model, pl.LightningModule) + self.model = _DDPFullyShardedStrategyModuleWrapper(self.model) + assert self.lightning_module is not None + if not is_overridden("configure_sharded_model", self.lightning_module): + self.model = self._setup_model(self.model) + self.setup_optimizers(self.lightning_module.trainer) + optimizers_to_device(self.optimizers, self.root_device) self.barrier() + self.setup_precision_plugin() + + def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: + """Wraps the model into a + :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.") + + rank_zero_info( + "When using FairScale FSDP auto-wrap, make sure to initalize your model using trainer else" + " you will get an error.\ntorch.optim.Optimizer(self.trainer.model.parameters(), ...)" + ) + + return FullyShardedDataParallel( + module=model, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, + flatten_parameters=self.flatten_parameters, + mixed_precision=(self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)), + reshard_after_forward=self.reshard_after_forward, + fp32_reduce_scatter=self.fp32_reduce_scatter, + compute_dtype=self.compute_dtype, + bucket_cap_mb=self.bucket_cap_mb, + state_dict_device=self.state_dict_device, + ) + @contextlib.contextmanager def model_sharded_context(self) -> Generator: log.detail(f"{self.__class__.__name__}: entered model_sharded_context.") @@ -190,10 +251,6 @@ def configure_ddp(self) -> None: # (TODO: need to figure out solution) self.model_to_device() - # setup optimizers after fully sharded has wrapped the lightning module - assert self.lightning_module - self.setup_optimizers(self.lightning_module.trainer) - def model_to_device(self) -> None: log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") # ensure we update the device type in the lightning module @@ -201,24 +258,22 @@ def model_to_device(self) -> None: self.lightning_module.to(self.root_device) def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - with self.precision_plugin.train_step_context(): - assert isinstance(self.model, TrainingStep) - return self.model.training_step(*args, **kwargs) + # we don't need precision context since casting is done by FSDP + # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html + assert self.model is not None + return self.model(*args, **kwargs) def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - with self.precision_plugin.val_step_context(): - assert isinstance(self.model, ValidationStep) - return self.model.validation_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - with self.precision_plugin.test_step_context(): - assert isinstance(self.model, TestStep) - return self.model.test_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - with self.precision_plugin.predict_step_context(): - assert isinstance(self.model, PredictStep) - return self.model.predict_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def post_training_step(self) -> None: pass diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 0d89529a8d115..0a10722166f8d 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -443,7 +443,7 @@ def handles_gradient_accumulation(self) -> bool: """Whether the plugin handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + def lightning_module_state_dict(self) -> Dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index a39a7a2145225..f18fce183f4cd 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -348,6 +348,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): [ pytest.param("fsdp", marks=RunIf(fairscale=True, min_cuda_gpus=1)), pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)), + pytest.param("fsdp_native", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")), ], ) def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str): diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index fe587877e84fb..88a07a78efecf 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -41,7 +41,7 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) -class TestFSDPModel(BoringModel): +class TestFSDPModelManualWrapped(BoringModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.layer: Optional[torch.nn.Module] = None @@ -69,16 +69,16 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) - def on_train_start(self) -> None: + def on_train_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_test_start(self) -> None: + def on_test_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_validation_start(self) -> None: + def on_validation_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_prediction_start(self) -> None: + def on_prediction_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() def _assert_layer_fsdp_instance(self) -> None: @@ -87,8 +87,8 @@ def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer.module[2], FullyShardedDataParallel) # Assert that the nested layers are set reshard_after_forward to True - assert self.layer.module[0].reshard_after_forward is True - assert self.layer.module[2].reshard_after_forward is True + assert self.layer.module[0].reshard_after_forward + assert self.layer.module[2].reshard_after_forward if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): assert self.layer.mixed_precision @@ -96,11 +96,40 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[2].mixed_precision +class TestFSDPModelAutoWrapped(BoringModel): + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + def on_train_batch_end(self, *_, **__) -> None: + self._assert_layer_fsdp_instance() + + def on_test_batch_end(self, *_, **__) -> None: + self._assert_layer_fsdp_instance() + + def on_validation_batch_end(self, *_, **__) -> None: + self._assert_layer_fsdp_instance() + + def on_prediction_batch_end(self, *_, **__) -> None: + self._assert_layer_fsdp_instance() + + def _assert_layer_fsdp_instance(self) -> None: + assert isinstance(self.trainer.model, FullyShardedDataParallel) + # `disable_reshard_on_root=True` (default) in FSDP which turns-off resharding + assert not self.trainer.model.reshard_after_forward + + if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): + assert self.trainer.model.mixed_precision + + @RunIf(min_cuda_gpus=1, standalone=True, fairscale=True) def test_fully_sharded_strategy_checkpoint(tmpdir): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" - model = TestFSDPModel() + model = TestFSDPModelManualWrapped() trainer = Trainer( default_root_dir=tmpdir, accelerator="gpu", @@ -115,18 +144,28 @@ def test_fully_sharded_strategy_checkpoint(tmpdir): @RunIf(min_cuda_gpus=2, standalone=True, fairscale=True) -def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir): +@pytest.mark.parametrize( + "model, strategy", + [ + (TestFSDPModelManualWrapped(), DDPFullyShardedStrategy(min_num_params=2)), + (TestFSDPModelAutoWrapped(), "fsdp"), + ], +) +def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir, model, strategy): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" - model = TestFSDPModel() ck = ModelCheckpoint(save_last=True) trainer = Trainer( default_root_dir=tmpdir, accelerator="gpu", devices=2, - strategy="fsdp", + strategy=strategy, precision=16, max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + limit_predict_batches=2, callbacks=[ck], enable_progress_bar=False, enable_model_summary=False, @@ -134,7 +173,7 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir): _run_multiple_stages(trainer, model) -def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped): # Use FullySharded to get the state dict for the sake of comparison model_state_dict = trainer.strategy.lightning_module_state_dict() @@ -153,19 +192,36 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.save_checkpoint(model_path, weights_only=True) - _assert_save_equality(trainer, model_path, cls=TestFSDPModel) + _assert_save_equality(trainer, model_path, cls=model.__class__) # Test entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() trainer.test(model) # model is wrapped, will not call configure_shared_model - # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap - trainer.test(ckpt_path=model_path) + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.test(model, ckpt_path=model_path) + + # Predict entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.predict(model, ckpt_path=model_path) @RunIf(min_cuda_gpus=1, standalone=True, fairscale=True) def test_fsdp_gradient_clipping_raises(tmpdir): """Test to ensure that an exception is raised when clipping gradients by value with FSDP.""" - model = BoringModel() + model = TestFSDPModelManualWrapped() trainer = Trainer( default_root_dir=tmpdir, strategy="fsdp", @@ -182,3 +238,21 @@ def test_fsdp_gradient_clipping_raises(tmpdir): MisconfigurationException, match="gradient_clip_algorithm='norm'` is currently not supported for `FullySharded" ): trainer.fit(model) + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) +def test_fsdp_rewrap_limitation(tmpdir): + trainer = Trainer( + default_root_dir=tmpdir, + accelerator="gpu", + devices=1, + max_steps=1, + limit_val_batches=0, + limit_test_batches=1, + strategy="fsdp", + ) + model = TestFSDPModelAutoWrapped() + trainer.fit(model) + + with pytest.raises(MisconfigurationException, match="Using the same instance of model .* not supported"): + trainer.test(model)