Add auto wrapping support for DDPFullyShardedStrategy (#14383)
rohitgr7 authored Sep 5, 2022
1 parent 7f148b2 commit 8c6119f
Showing 8 changed files with 248 additions and 77 deletions.
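
The sketch below is not part of the commit; it is a minimal, hypothetical example of how the new auto-wrapping path is intended to be used, inferred from the check added in `connect()` and the hint logged in `_setup_model()`: the `LightningModule` does not override `configure_sharded_model`, and the optimizer is built from `self.trainer.model.parameters()` so it sees the FSDP-wrapped parameters. The model layout, dataloader, and trainer arguments are illustrative only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class AutoWrapModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        # plain layers and no `configure_sharded_model` override, so the strategy
        # auto-wraps the whole module with FairScale FSDP during `setup()`
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        # build the optimizer from the wrapped model's parameters,
        # as the info message added in `_setup_model` advises
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)


if __name__ == "__main__":
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    # "fsdp" selects the FairScale fully sharded strategy, matching the test parametrization below
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision=16, max_epochs=1)
    trainer.fit(AutoWrapModel(), train_loader)
```
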
105 changes: 71 additions & 34 deletions docs/source-pytorch/advanced/model_parallel.rst

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/strategy.rst
@@ -83,10 +83,10 @@ The below table lists all relevant strategies available in Lightning with their
- Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Fully Sharded Training>`
* - ddp_sharded
- :class:`~pytorch_lightning.strategies.DDPShardedStrategy`
- Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:Sharded Training>`
- Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
* - ddp_sharded_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy`
- Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:Sharded Training>`
- Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
* - ddp_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnStrategy`
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))


- Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383))



### Changed

7 changes: 4 additions & 3 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -25,6 +25,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -144,6 +145,9 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

# copy the model before moving it to accelerator device.
with pl_module._prevent_trainer_and_dataloaders_deepcopy():
self._average_model = deepcopy(pl_module)
@@ -155,9 +159,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

97 changes: 76 additions & 21 deletions src/pytorch_lightning/strategies/fully_sharded.py
@@ -18,6 +18,7 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -26,16 +27,28 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel
else:
FullyShardedDataParallel = None

log = logging.getLogger(__name__)


class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase):
def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override]
# this is required because with FSDP the lightning_module's weights are sharded, so its own
# state_dict is empty. We need to call self.trainer.model.state_dict (the wrapped version) and use
# this wrapper to avoid the extra `_forward_module.` key prefix, since we want `layer.weight`-style keys.
return self._forward_module.state_dict(*args, **kwargs)
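
# --- Illustration (not part of this diff) -------------------------------------------
# Why the override above is needed: nesting a module under an attribute called
# `_forward_module` prefixes every state_dict key with `_forward_module.`, while
# checkpoints should keep plain `layer.weight`-style keys. A minimal sketch:
#
#     import torch.nn as nn
#
#     class _Wrapper(nn.Module):  # stand-in for the wrapper above
#         def __init__(self, module: nn.Module) -> None:
#             super().__init__()
#             self._forward_module = module
#
#     layer = nn.Linear(2, 2)
#     list(layer.state_dict())            # ['weight', 'bias']
#     list(_Wrapper(layer).state_dict())  # ['_forward_module.weight', '_forward_module.bias']
# -------------------------------------------------------------------------------------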


class DDPFullyShardedStrategy(DDPStrategy):

strategy_name = "ddp_fully_sharded"
@@ -132,6 +145,25 @@ def process_group(self) -> Any:
self._process_group = torch.distributed.new_group()
return self._process_group

def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Returns model state."""
assert self.model is not None
return self.model.state_dict()

def connect(self, model: "pl.LightningModule") -> None:
"""Called by the accelerator to connect the accelerator and the model with this plugin."""
# TODO: Wait for this issue to resolve and remove this blocker
# https://github.com/facebookresearch/fairscale/issues/648
# Also make sure to update the tests
if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0:
assert self.lightning_module is not None
raise MisconfigurationException(
f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not"
" supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that."
)

super().connect(model)
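
# --- Illustration (not part of this diff) -------------------------------------------
# The zero-parameter check above guards against passing a model instance whose parameters
# were already consumed by a previous FSDP auto-wrap run (see the fairscale issue linked
# above). A sketched scenario, assuming the first run flattened the parameters:
#
#     model = AutoWrapModel()
#     trainer.fit(model)    # auto-wrap flattens the parameters into the FSDP wrapper
#     trainer.test(model)   # `model` now has no parameters of its own -> raised here
#
# Re-initializing the LightningModule, as the error message says, avoids this.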

def setup_distributed(self) -> None:
if not self.root_device.type == "cuda":
raise MisconfigurationException(
@@ -144,17 +176,46 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if self._layer_sync:
assert self.model
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()
self.configure_ddp()
assert isinstance(self.model, pl.LightningModule)
self.model = _DDPFullyShardedStrategyModuleWrapper(self.model)
assert self.lightning_module is not None
if not is_overridden("configure_sharded_model", self.lightning_module):
self.model = self._setup_model(self.model)
self.setup_optimizers(self.lightning_module.trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.barrier()

self.setup_precision_plugin()

def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.")

rank_zero_info(
"When using FairScale FSDP auto-wrap, make sure to initialize your optimizer using the trainer's"
" model, otherwise you will get an error: `torch.optim.Optimizer(self.trainer.model.parameters(), ...)`"
)

return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=(self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)),
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
state_dict_device=self.state_dict_device,
)
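
# --- Illustration (not part of this diff) -------------------------------------------
# The auto-wrap above only runs when `configure_sharded_model` is NOT overridden.
# A hypothetical module taking the manual-wrap path instead, wrapping its own
# sub-modules inside the strategy's `model_sharded_context` (names are made up;
# `wrap` comes from the same `fairscale.nn` package imported at the top of this file):
#
#     from fairscale.nn import wrap
#
#     class ManuallyWrappedModel(pl.LightningModule):
#         def __init__(self) -> None:
#             super().__init__()
#             self.block = torch.nn.Linear(32, 2)
#
#         def configure_sharded_model(self) -> None:
#             # executed inside `model_sharded_context`, so `wrap` creates an FSDP unit
#             self.block = wrap(self.block)
#
#         def configure_optimizers(self) -> torch.optim.Optimizer:
#             return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
# -------------------------------------------------------------------------------------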

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
@@ -190,35 +251,29 @@ def configure_ddp(self) -> None:
# (TODO: need to figure out solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
assert self.lightning_module
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
assert self.lightning_module
self.lightning_module.to(self.root_device)

def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
assert isinstance(self.model, TrainingStep)
return self.model.training_step(*args, **kwargs)
# we don't need precision context since casting is done by FSDP
# read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
assert self.model is not None
return self.model(*args, **kwargs)

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
assert isinstance(self.model, ValidationStep)
return self.model.validation_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)

def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
assert isinstance(self.model, TestStep)
return self.model.test_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)

def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
assert isinstance(self.model, PredictStep)
return self.model.predict_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)
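
# --- Illustration (not part of this diff) -------------------------------------------
# Why `self.model(*args, **kwargs)` now covers every step: `self.model` is the
# FSDP-wrapped `_DDPFullyShardedStrategyModuleWrapper`, and the forward inherited from
# `_LightningModuleWrapperBase` dispatches to the step matching the trainer's state.
# A rough, simplified sketch of that dispatch (assumed shape, not the exact source):
#
#     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
#         trainer = self._forward_module.trainer
#         if trainer.training:
#             return self._forward_module.training_step(*inputs, **kwargs)
#         if trainer.testing:
#             return self._forward_module.test_step(*inputs, **kwargs)
#         if trainer.sanity_checking or trainer.validating:
#             return self._forward_module.validation_step(*inputs, **kwargs)
#         if trainer.predicting:
#             return self._forward_module.predict_step(*inputs, **kwargs)
#         return self._forward_module(*inputs, **kwargs)
# -------------------------------------------------------------------------------------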

def post_training_step(self) -> None:
pass
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -443,7 +443,7 @@ def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return False

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Returns model state."""
assert self.lightning_module is not None
return self.lightning_module.state_dict()
@@ -348,6 +348,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
[
pytest.param("fsdp", marks=RunIf(fairscale=True, min_cuda_gpus=1)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
pytest.param("fsdp_native", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")),
],
)
def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):