Added a check to validate that wrapped FSDP models are used while initializing optimizers #15301

Merged · 20 commits · Nov 8, 2022
Changes from 10 commits
6 changes: 6 additions & 0 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -352,6 +352,12 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. note::
When initializing optimizers inside the ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, otherwise
PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and
``lightning_module.parameters()`` will return a generator with no parameters (see the sketch below for the recommended pattern).
This inconvenience will be addressed in the future.
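
For illustration (not part of the original diff), here is a minimal sketch of a ``configure_optimizers`` hook that follows this advice. It assumes ``BoringModel`` is importable from ``pytorch_lightning.demos.boring_classes``; the subclass name, optimizer and learning rate are hypothetical placeholders:

.. code-block:: python

    import torch
    from pytorch_lightning.demos.boring_classes import BoringModel


    class MyFSDPModel(BoringModel):  # hypothetical example module
        def configure_optimizers(self):
            # Reference the wrapped model's parameters; with auto-wrap,
            # `self.parameters()` would return no (flat) parameters here.
            return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-3)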


.. code-block:: python

model = BoringModel()
6 changes: 6 additions & 0 deletions src/lightning_lite/strategies/fairscale.py
@@ -196,3 +196,9 @@ def no_backward_sync(self, module: Module) -> Generator:
)
with module.no_sync():
yield None


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from fairscale.nn.misc.flatten_params_wrapper import FlatParameter

return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
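
For context, here is a rough usage sketch of this helper (not part of the PR). It assumes fairscale is installed and that `FlattenParamsWrapper` flattens all of a module's parameters by default:

import torch
from fairscale.nn.misc import FlattenParamsWrapper

from lightning_lite.strategies.fairscale import _optimizer_has_flat_params

# An optimizer over an unwrapped module only sees ordinary nn.Parameter objects.
plain = torch.nn.Linear(4, 4)
assert not _optimizer_has_flat_params(torch.optim.SGD(plain.parameters(), lr=0.1))

# After flattening, the same check finds FlatParameter objects and returns True.
flat = FlattenParamsWrapper(torch.nn.Linear(4, 4))
assert _optimizer_has_flat_params(torch.optim.SGD(flat.parameters(), lr=0.1))
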
20 changes: 20 additions & 0 deletions src/lightning_lite/strategies/fsdp_native.py
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from torch.distributed.fsdp import FlatParameter

return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
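
A minimal sketch (not part of the PR) of the negative case this helper guards against; it needs no process group. The positive case requires an initialized process group and an FSDP-wrapped module, so it is only indicated in comments:

import torch

from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params

# Parameters taken from an unwrapped module are plain nn.Parameter objects,
# so the check returns False and the strategy can fail fast with a clear error.
unwrapped = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(unwrapped.parameters(), lr=1e-3)
assert not _optimizer_has_flat_params(optimizer)

# With torch.distributed initialized, wrapping the module first would instead
# expose FlatParameter objects and make the check return True, e.g.:
#   sharded = torch.distributed.fsdp.FullyShardedDataParallel(unwrapped)
#   torch.optim.Adam(sharded.parameters(), lr=1e-3)
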
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an error message when attempting to launch processes with `python -i` and an interactive-incompatible strategy ([#15293](https://github.com/Lightning-AI/lightning/pull/15293))


- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15319](https://github.com/Lightning-AI/lightning/pull/15319))


### Changed

- The `NeptuneLogger` now uses `neptune.init_run` instead of the deprecated `neptune.init` to initialize a run ([#15393](https://github.com/Lightning-AI/lightning/pull/15393))
16 changes: 14 additions & 2 deletions src/pytorch_lightning/strategies/fully_sharded.py
@@ -13,13 +13,14 @@
# limitations under the License.
import contextlib
import logging
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, Generator, Iterable, List, Optional

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -186,6 +187,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
if not is_overridden("configure_sharded_model", self.lightning_module):
self.model = self._setup_model(self.model)
self.setup_optimizers(self.lightning_module.trainer)
_validate_optimizers(self.optimizers)
_optimizers_to_device(self.optimizers, self.root_device)
self.barrier()

@@ -288,3 +290,13 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
cls,
description=f"{cls.__class__.__name__}",
)


def _validate_optimizers(optimizers: Iterable[Optimizer]) -> None:
for optimizer in optimizers:
if not _optimizer_has_flat_params(optimizer):
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
" `configure_optimizers()` hook."
)
21 changes: 20 additions & 1 deletion src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -13,13 +13,15 @@
# limitations under the License.
import contextlib
import logging
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Union

import torch
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params
from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
@@ -215,6 +217,12 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
del self.kwargs["auto_wrap_policy"]

log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

rank_zero_info(
    "When using PyTorch FSDP auto-wrap, make sure to initialize your optimizers with the wrapped model's"
    " parameters, else you will get an error.\ntorch.optim.Optimizer(self.trainer.model.parameters(), ...)"
)

return FullyShardedDataParallel(
module=model,
process_group=self.process_group,
@@ -251,6 +259,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.barrier()

self.setup_optimizers(trainer)
_validate_optimizers(self.optimizers)
_optimizers_to_device(self.optimizers, self.root_device)

self.setup_precision_plugin()
@@ -371,3 +380,13 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
cpu_offload=CPUOffload(offload_params=True),
)
cls._registered_strategies.append("fsdp_native_full_shard_offload")


def _validate_optimizers(optimizers: Iterable[Optimizer]) -> None:
for optimizer in optimizers:
if not _optimizer_has_flat_params(optimizer):
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
" `configure_optimizers()` hook."
)
2 changes: 2 additions & 0 deletions tests/tests_pytorch/models/test_restore.py
@@ -391,6 +391,8 @@ def test_callbacks_references_fit_ckpt_path(tmpdir):
@RunIf(min_cuda_gpus=2)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
"""Verify `test()` on pretrained model."""
seed_everything(7)

dm = ClassifDataModule()
model = CustomClassificationModelDP(lr=0.1)

154 changes: 84 additions & 70 deletions tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -18,46 +18,6 @@
from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
with pytest.raises(
MisconfigurationException,
match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
"but GPU accelerator is not used.",
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


class TestFSDPModel(BoringModel):
def __init__(self):
super().__init__()
@@ -154,6 +114,80 @@ def _assert_layer_fsdp_instance(self) -> None:
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)

# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load it and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load it and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
# Use FullySharded to get the state dict for the sake of comparison
model_state_dict = trainer.strategy.lightning_module_state_dict()

if trainer.is_global_zero:
saved_model = cls.load_from_checkpoint(ckpt_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
assert torch.equal(ddp_param.float().cpu(), shard_param)


def custom_auto_wrap_policy(
module,
recurse,
unwrapped_params: int,
min_num_params: int = int(1e8),
) -> bool:
return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
with pytest.raises(
MisconfigurationException,
match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
"but GPU accelerator is not used.",
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
"""Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -214,35 +248,15 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra
_run_multiple_stages(trainer, model)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)

# Test entry point
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load it and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)

# Predict entry point
trainer.predict(model) # model is wrapped, will not call `configure_sharded_model`

# provide model path, will create a new unwrapped model, load it and then call `configure_sharded_model` to wrap
trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
def test_invalid_parameters_in_optimizer(tmpdir):
    class CustomModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    trainer = Trainer(strategy="fsdp_native", accelerator="gpu", devices=1)
    model = CustomModel()

    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)