Add auto wrapping support for DDPFullyShardedStrategy #14383

Merged: 15 commits, Sep 5, 2022
105 changes: 71 additions & 34 deletions docs/source-pytorch/advanced/model_parallel.rst

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/strategy.rst
@@ -83,10 +83,10 @@ The below table lists all relevant strategies available in Lightning with their
- Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Fully Sharded Training>`
* - ddp_sharded
- :class:`~pytorch_lightning.strategies.DDPShardedStrategy`
- Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:Sharded Training>`
- Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
* - ddp_sharded_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy`
- Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:Sharded Training>`
- Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
* - ddp_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnStrategy`
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))


- Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383))



### Changed

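For context, a minimal sketch of what this entry enables (illustrative, not taken from the diff; it assumes a CUDA device and a FairScale installation, and the tiny model, data, and hyperparameters are made up): with the "fsdp" strategy alias, layers defined in `__init__` are sharded by the strategy itself, so neither manual wrapping nor a `configure_sharded_model` override is needed, and the optimizer is built from `self.trainer.model.parameters()` as the strategy now advises.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class AutoWrapModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Plain layers; FairScale FSDP auto-wraps them during strategy setup.
        self.net = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.net(x).sum()

    def configure_optimizers(self):
        # Reference the wrapped model's parameters, not `self.parameters()`.
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="fsdp", max_epochs=1)
    trainer.fit(AutoWrapModel(), train_dataloaders=data)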
7 changes: 4 additions & 3 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -25,6 +25,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -144,6 +145,9 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

# copy the model before moving it to accelerator device.
with pl_module._prevent_trainer_and_dataloaders_deepcopy():
self._average_model = deepcopy(pl_module)
@@ -155,9 +159,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

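The practical effect of moving the check into `setup` is that the incompatibility surfaces before the callback deep-copies the module or anything is moved to the device. Roughly, mirroring the test at the bottom of this diff (a sketch; the `BoringModel` demo module and the `fast_dev_run` flag are assumptions for brevity, not part of the change):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# SWA is rejected for sharded strategies ("fsdp", "fsdp_native", "deepspeed") as soon as setup runs.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    strategy="fsdp",
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    fast_dev_run=True,
)
with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
    trainer.fit(BoringModel())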
94 changes: 73 additions & 21 deletions src/pytorch_lightning/strategies/fully_sharded.py
@@ -18,6 +18,7 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -26,16 +27,25 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel
else:
FullyShardedDataParallel = None

log = logging.getLogger(__name__)


class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase):
def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override]
return self._forward_module.state_dict(*args, **kwargs)


class DDPFullyShardedStrategy(DDPStrategy):

strategy_name = "ddp_fully_sharded"
@@ -132,6 +142,25 @@ def process_group(self) -> Any:
self._process_group = torch.distributed.new_group()
return self._process_group

def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Returns model state."""
assert self.model is not None
return self.model.state_dict()

def connect(self, model: "pl.LightningModule") -> None:
"""Called by the accelerator to connect the accelerator and the model with this plugin."""
# TODO: Wait for this issue to be resolved and remove this blocker
# https://github.com/facebookresearch/fairscale/issues/648
# Also make sure to update the tests
if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0:
assert self.lightning_module is not None
raise MisconfigurationException(
f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not"
" supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that."
)

super().connect(model)
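In other words, the guard above targets the reuse pattern its error message describes: once auto-wrap has flattened a module's parameters, passing the same instance to a second entry point finds an empty parameter list and is rejected. A sketch of the rejected and the supported flow (`MyModel` is a placeholder LightningModule, and default checkpointing is assumed):

model = MyModel()
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="fsdp")
trainer.fit(model)
trainer.test(model)  # raises MisconfigurationException: the instance was already auto-wrapped

# Supported: pass a freshly initialized module, e.g. restored from the best checkpoint.
trainer.test(MyModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path))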

def setup_distributed(self) -> None:
if not self.root_device.type == "cuda":
raise MisconfigurationException(
@@ -144,17 +173,46 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)

if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)

if self._layer_sync:
assert self.model
self.model = self._layer_sync.apply(self.model)

self.setup_precision_plugin()
self.configure_ddp()
assert isinstance(self.model, pl.LightningModule)
self.model = _DDPFullyShardedStrategyModuleWrapper(self.model)
assert self.lightning_module is not None
if not is_overridden("configure_sharded_model", self.lightning_module):
self.model = self._setup_model(self.model)
self.setup_optimizers(self.lightning_module.trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.barrier()

self.setup_precision_plugin()

def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.")

rank_zero_info(
"When using FairScale FSDP auto-wrap, make sure to initialize your optimizer with the wrapped model's"
" parameters, otherwise you will get an error:"
"\ntorch.optim.Optimizer(self.trainer.model.parameters(), ...)"
)

return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=(self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)),
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
state_dict_device=self.state_dict_device,
)
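In LightningModule terms, the hint above amounts to the following (a sketch; the class skeleton and learning rate are illustrative): once auto-wrap has flattened the parameters, `self.trainer.model` is the FSDP-wrapped module, and that is what the optimizer should own.

import torch
import pytorch_lightning as pl


class Model(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # Build the optimizer from the FSDP-wrapped model, not from `self.parameters()`.
        return torch.optim.SGD(self.trainer.model.parameters(), lr=1e-3)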

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
@@ -190,35 +248,29 @@ def configure_ddp(self) -> None:
# (TODO: need to figure out solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
assert self.lightning_module
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
assert self.lightning_module
self.lightning_module.to(self.root_device)

def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
assert isinstance(self.model, TrainingStep)
return self.model.training_step(*args, **kwargs)
# we don't need precision context since casting is done by FSDP
# read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
assert self.model is not None
return self.model(*args, **kwargs)

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
assert isinstance(self.model, ValidationStep)
return self.model.validation_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)

def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
assert isinstance(self.model, TestStep)
return self.model.test_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)

def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
assert isinstance(self.model, PredictStep)
return self.model.predict_step(*args, **kwargs)
assert self.model is not None
return self.model(*args, **kwargs)

def post_training_step(self) -> None:
pass
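The rewritten `*_step` hooks work because `self.model` is now the FSDP-wrapped `_DDPFullyShardedStrategyModuleWrapper`: calling it runs FSDP's forward (which also applies the half-precision casting that the dropped precision contexts used to provide), and the inner wrapper routes to the hook matching the trainer's current stage. A simplified stand-in for that routing, under the assumption that it mirrors `_LightningModuleWrapperBase` (not the actual source):

from torch import nn

from pytorch_lightning.trainer.states import RunningStage


class _StepRoutingWrapper(nn.Module):
    """Simplified sketch: forward() dispatches to the LightningModule hook for the current stage."""

    def __init__(self, pl_module):
        super().__init__()
        self._forward_module = pl_module

    def forward(self, *args, **kwargs):
        stage = self._forward_module.trainer.state.stage
        if stage == RunningStage.TRAINING:
            return self._forward_module.training_step(*args, **kwargs)
        if stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING):
            return self._forward_module.validation_step(*args, **kwargs)
        if stage == RunningStage.TESTING:
            return self._forward_module.test_step(*args, **kwargs)
        if stage == RunningStage.PREDICTING:
            return self._forward_module.predict_step(*args, **kwargs)
        return self._forward_module(*args, **kwargs)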
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -443,7 +443,7 @@ def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return False

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Returns model state."""
assert self.lightning_module is not None
return self.lightning_module.state_dict()

@@ -348,6 +348,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
[
pytest.param("fsdp", marks=RunIf(fairscale=True, min_cuda_gpus=1)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
pytest.param("fsdp_native", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")),
],
)
def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):