Added a check to validate that wrapped FSDP models are used while initializing optimizers (#15301)



Co-authored-by: awaelchli <[email protected]>
rohitgr7 and awaelchli authored Nov 8, 2022
1 parent 18f7f2d commit 0886e63
Showing 8 changed files with 248 additions and 141 deletions.
6 changes: 6 additions & 0 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -352,6 +352,12 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. note::
When initializing the optimizers inside the ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, otherwise
PyTorch will raise an error. This is required because with auto-wrap the model layers are sharded, so calling
``parameters()`` on the LightningModule itself returns a generator with no parameters. This inconvenience will be addressed in the future.


.. code-block:: python
model = BoringModel()
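As a minimal sketch of the pattern the note above recommends (the subclass name and the optimizer settings are illustrative, not part of the documented example):

.. code-block:: python

    class MyAutoWrappedModel(BoringModel):
        def configure_optimizers(self):
            # Reference the wrapped model's parameters: with auto-wrap, the layers are
            # sharded, so ``self.parameters()`` would yield an empty parameter list.
            return torch.optim.Adam(self.trainer.model.parameters(), lr=1e-2)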
6 changes: 6 additions & 0 deletions src/lightning_lite/strategies/fairscale.py
@@ -196,3 +196,9 @@ def no_backward_sync(self, module: Module) -> Generator:
)
with module.no_sync():
yield None


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from fairscale.nn.misc.flatten_params_wrapper import FlatParameter

return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
20 changes: 20 additions & 0 deletions src/lightning_lite/strategies/fsdp_native.py
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from torch.distributed.fsdp import FlatParameter

return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
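A quick illustrative sketch of what this helper distinguishes (the module and optimizer below are arbitrary stand-ins): an optimizer built from an unwrapped module holds plain ``torch.nn.Parameter`` objects rather than FSDP ``FlatParameter``s, so the check returns ``False``, which is the condition the strategies below treat as a misconfiguration.

```python
import torch

from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params

# Parameters of a plain, unwrapped module are ordinary ``torch.nn.Parameter``
# objects, not FSDP ``FlatParameter``s, so the check reports False.
layer = torch.nn.Linear(4, 5)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)
assert not _optimizer_has_flat_params(optimizer)
```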
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))


### Changed

- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
24 changes: 17 additions & 7 deletions src/pytorch_lightning/strategies/fully_sharded.py
@@ -19,7 +19,7 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -28,7 +28,6 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
@@ -191,16 +190,27 @@ def setup(self, trainer: "pl.Trainer") -> None:

self.setup_precision_plugin()

def setup_optimizers(self, trainer: "pl.Trainer") -> None:
invalid_params_error = False
try:
super().setup_optimizers(trainer)
except ValueError as e:
if "optimizer got an empty parameter list" not in str(e):
raise
invalid_params_error = True

if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
" `configure_optimizers()` hook."
)

def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.")

rank_zero_info(
"When using FairScale FSDP auto-wrap, make sure to initialize your model using trainer: "
"`torch.optim.Optimizer(self.trainer.model.parameters(), ...)`"
)

return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
18 changes: 18 additions & 0 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -20,6 +20,7 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params
from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
@@ -215,6 +216,7 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
del self.kwargs["auto_wrap_policy"]

log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
@@ -255,6 +257,22 @@ def setup(self, trainer: "pl.Trainer") -> None:

self.setup_precision_plugin()

def setup_optimizers(self, trainer: "pl.Trainer") -> None:
invalid_params_error = False
try:
super().setup_optimizers(trainer)
except ValueError as e:
if "optimizer got an empty parameter list" not in str(e):
raise
invalid_params_error = True

if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
" `configure_optimizers()` hook."
)

def model_to_device(self) -> None:
pass

158 changes: 90 additions & 68 deletions tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -18,46 +18,6 @@
from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
with pytest.raises(
MisconfigurationException,
match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
"but GPU accelerator is not used.",
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


class TestFSDPModel(BoringModel):
def __init__(self):
super().__init__()
@@ -154,6 +114,80 @@ def _assert_layer_fsdp_instance(self) -> None:
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)

# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load the checkpoint, and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load the checkpoint, and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
model_state_dict = trainer.strategy.lightning_module_state_dict()

if trainer.is_global_zero:
saved_model = cls.load_from_checkpoint(ckpt_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
assert torch.equal(ddp_param.float().cpu(), shard_param)


def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
with pytest.raises(
MisconfigurationException,
match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
"but GPU accelerator is not used.",
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
"""Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -214,35 +248,23 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra
_run_multiple_stages(trainer, model)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)

# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load the checkpoint, and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
def test_invalid_parameters_in_optimizer(tmpdir):
trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`
class EmptyParametersModel(BoringModel):
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-2)

# provide model path, will create a new unwrapped model, load the checkpoint, and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)
model = EmptyParametersModel()
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
trainer.fit(model)

class NoFlatParametersModel(BoringModel):
def configure_optimizers(self):
layer = torch.nn.Linear(4, 5)
return torch.optim.Adam(layer.parameters(), lr=1e-2)

def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
model_state_dict = trainer.strategy.lightning_module_state_dict()

if trainer.is_global_zero:
saved_model = cls.load_from_checkpoint(ckpt_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
assert torch.equal(ddp_param.float().cpu(), shard_param)
model = NoFlatParametersModel()
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
trainer.fit(model)