FSDP (native) support for LightningLite #14967

Merged · Nov 21, 2022 · 108 commits

Changes from 102 commits

Commits
57774ed
wip
awaelchli Oct 1, 2022
043783e
wip precision
awaelchli Oct 1, 2022
0caf973
fsdp
awaelchli Oct 1, 2022
a2130b9
fsdp support in lite
awaelchli Oct 1, 2022
80d24fe
typing fixes
awaelchli Oct 1, 2022
cc65718
imports
awaelchli Oct 1, 2022
9d91f86
import fixes
awaelchli Oct 1, 2022
b535621
fix test
awaelchli Oct 1, 2022
8e85f69
more tests
awaelchli Oct 2, 2022
5385a19
integration tests
awaelchli Oct 2, 2022
de24f12
debug
awaelchli Oct 2, 2022
1d39715
Merge branch 'lite/fsdp-debug' into lite/fsdp
awaelchli Oct 2, 2022
9051a13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2022
782bd0a
fix autowrap policy
awaelchli Oct 2, 2022
5f9d1a1
Merge remote-tracking branch 'origin/lite/fsdp' into lite/fsdp
awaelchli Oct 2, 2022
34251a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2022
c8eb2b1
debug
awaelchli Oct 2, 2022
27e0151
Merge remote-tracking branch 'origin/lite/fsdp' into lite/fsdp
awaelchli Oct 2, 2022
c9dd26e
debug
awaelchli Oct 2, 2022
4832c9e
simplify
awaelchli Oct 3, 2022
01e9b56
Merge branch 'master' into lite/fsdp
Borda Oct 5, 2022
70437da
support individual setup of model and optimizer
awaelchli Oct 19, 2022
50981a3
error messaging
awaelchli Oct 19, 2022
8e51e33
Merge branch 'master' into lite/fsdp-debug3
awaelchli Oct 19, 2022
c61a2b7
Merge branch 'lite/individual-setup' into lite/fsdp-debug3
awaelchli Oct 19, 2022
4eadd24
test
awaelchli Oct 19, 2022
dec4f9c
debug
awaelchli Oct 19, 2022
d5f1c9e
debug
awaelchli Oct 19, 2022
559187b
debug
awaelchli Oct 19, 2022
e286dd9
debug
awaelchli Oct 19, 2022
230dc03
Merge branch 'master' into lite/individual-setup
awaelchli Oct 20, 2022
b4613ec
wip
awaelchli Oct 20, 2022
281e26c
wip
awaelchli Oct 20, 2022
9d6971b
update structure
awaelchli Oct 20, 2022
20bc93b
tests
awaelchli Oct 21, 2022
7dc0aa5
error messages
awaelchli Oct 22, 2022
0190b50
test errors
awaelchli Oct 22, 2022
1e3e7b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2022
472b605
mypy
awaelchli Oct 22, 2022
b1fe35b
Merge branch 'lite/individual-setup' of github.com:Lightning-AI/light…
awaelchli Oct 22, 2022
3630d05
add changelog
awaelchli Oct 22, 2022
3509abe
Merge branch 'lite/individual-setup' into lite/fsdp-debug3
awaelchli Oct 23, 2022
8d035fa
messaging
awaelchli Oct 23, 2022
fadb2b6
debug
awaelchli Oct 23, 2022
39e7c09
debug
awaelchli Oct 23, 2022
7bc9421
fix
awaelchli Oct 23, 2022
9d086f2
udpate test
awaelchli Oct 23, 2022
ac29054
remove done todo
awaelchli Oct 23, 2022
5e1f433
missing err message
awaelchli Oct 23, 2022
a900207
tests
awaelchli Oct 23, 2022
a41eef9
flake
awaelchli Oct 23, 2022
38fbd24
Merge branch 'master' into lite/individual-setup
awaelchli Oct 23, 2022
9cd50b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2022
564ab17
docstrings
awaelchli Oct 23, 2022
11e7feb
Merge branch 'lite/individual-setup' of github.com:Lightning-AI/light…
awaelchli Oct 23, 2022
6afc0ae
doc fix
awaelchli Oct 23, 2022
8b999b1
support python < 3.10
awaelchli Oct 23, 2022
3036840
validation
awaelchli Oct 26, 2022
d02d71c
Merge branch 'lite/fsdp-debug3' into lite/fsdp
awaelchli Oct 26, 2022
f25674e
Merge branch 'lite/fsdp' into lite/fsdp-debug3
awaelchli Oct 26, 2022
4bb5d56
debug
awaelchli Oct 26, 2022
7f944cc
debug
awaelchli Oct 26, 2022
070b828
update
awaelchli Oct 26, 2022
a63b3c9
debug
awaelchli Oct 26, 2022
3e20de2
validate
awaelchli Oct 26, 2022
cf9b92f
revert
awaelchli Oct 26, 2022
2f54259
Merge branch 'lite/fsdp-debug3' into lite/fsdp
awaelchli Oct 26, 2022
47dae76
Merge branch 'lite/individual-setup' into lite/fsdp
awaelchli Oct 26, 2022
8601ad5
Merge branch 'master' into lite/individual-setup
awaelchli Oct 26, 2022
103096a
Merge branch 'lite/individual-setup' into lite/fsdp
awaelchli Oct 26, 2022
fc34be3
x
awaelchli Oct 26, 2022
3088273
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2022
d49ac0a
debug
awaelchli Oct 26, 2022
7bb7bb8
Merge remote-tracking branch 'origin/lite/fsdp' into lite/fsdp
awaelchli Oct 26, 2022
469cc2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2022
c138716
debug
awaelchli Oct 26, 2022
23aa1ba
Merge remote-tracking branch 'origin/lite/fsdp' into lite/fsdp
awaelchli Oct 26, 2022
c8f5a67
debug
awaelchli Oct 26, 2022
304bac0
simplify
awaelchli Oct 26, 2022
28e64b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2022
ef7fb0e
typo
awaelchli Oct 26, 2022
7126b1e
Merge remote-tracking branch 'origin/lite/fsdp' into lite/fsdp
awaelchli Oct 26, 2022
f44c9b0
Merge branch 'master' into lite/fsdp
awaelchli Nov 11, 2022
9bf3531
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 11, 2022
b96496f
Merge branch 'master' into lite/fsdp
awaelchli Nov 13, 2022
0750bfe
changelog
awaelchli Nov 13, 2022
b96e140
fix setup_module call
awaelchli Nov 13, 2022
2b3f618
fix
awaelchli Nov 13, 2022
a1eed5a
fix test
awaelchli Nov 13, 2022
612dc3d
update
awaelchli Nov 13, 2022
6b4fc35
Merge branch 'master' into lite/fsdp
awaelchli Nov 20, 2022
190c5d2
update
awaelchli Nov 20, 2022
3c67190
fix
awaelchli Nov 20, 2022
239a674
fix duplicate import
awaelchli Nov 20, 2022
ad770e6
add no_backward_sync for FSDP
awaelchli Nov 20, 2022
1c19b96
fix
awaelchli Nov 20, 2022
1d2fa56
fix literal import
awaelchli Nov 20, 2022
fac33ac
fix
awaelchli Nov 20, 2022
56f2109
manual wrap
awaelchli Nov 20, 2022
17bd714
avoid double wrap
awaelchli Nov 20, 2022
a97c56e
fix mypy
awaelchli Nov 20, 2022
fac22b4
revert original test
awaelchli Nov 20, 2022
ffaba94
skip import on torch <1.12
awaelchli Nov 20, 2022
133bb9c
torch compatibility
awaelchli Nov 20, 2022
07d4998
fix
awaelchli Nov 20, 2022
0ece17d
revert comments in pytorch tests
awaelchli Nov 21, 2022
6bb5ee6
Merge branch 'master' into lite/fsdp
awaelchli Nov 21, 2022
e47d3c8
fix copy-paste error in docstring
awaelchli Nov 21, 2022
3 changes: 3 additions & 0 deletions src/lightning_lite/CHANGELOG.md
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185))


- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967))


### Changed

- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
15 changes: 13 additions & 2 deletions src/lightning_lite/connector.py
@@ -40,6 +40,7 @@
    TorchElasticEnvironment,
)
from lightning_lite.plugins.precision.double import DoublePrecision
from lightning_lite.plugins.precision.fsdp import FSDPPrecision
from lightning_lite.strategies import (
    DDPShardedStrategy,
    DDPSpawnShardedStrategy,
@@ -53,6 +54,7 @@
    XLAStrategy,
)
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import _determine_root_gpu_device
from lightning_lite.utilities.imports import _IS_INTERACTIVE
@@ -417,6 +419,13 @@ def _check_strategy_and_fallback(self) -> None:
                f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this"
                f" platform. We recommend `Lite(strategy='ddp_spawn')` instead."
            )
        if (
            strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
        ) and self._accelerator_flag not in ("cuda", "gpu"):
            raise ValueError(
                "You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
                " to continue or select a different strategy."
            )
        if strategy_flag:
            self._strategy_flag = strategy_flag

@@ -465,9 +474,11 @@ def _check_and_init_precision(self) -> Precision:
                if self._precision_input == 16
                else "Using bfloat16 Automatic Mixed Precision (AMP)"
            )

            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
            return NativeMixedPrecision(self._precision_input, device)

            if isinstance(self.strategy, FSDPStrategy):
                return FSDPPrecision(precision=self._precision_input, device=device)
            return NativeMixedPrecision(precision=self._precision_input, device=device)

        raise RuntimeError("No precision set")

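For context, a minimal sketch of how these connector changes surface to a user. This is illustrative and not part of the diff; the `"fsdp"` string alias and the top-level `lightning_lite` import path are assumptions.

```python
# Illustrative sketch only: selecting the native FSDP strategy in Lite.
# Assumes "fsdp" is among the aliases registered via `_FSDP_ALIASES`.
from lightning_lite import LightningLite  # import path assumed


class FSDPDemo(LightningLite):
    def run(self) -> None:
        ...  # training code goes here


# Passing accelerator="cpu" here would now raise the ValueError added above; with a
# CUDA accelerator and precision=16 or "bf16", the connector picks `FSDPPrecision`.
FSDPDemo(strategy="fsdp", accelerator="cuda", devices=2, precision=16).run()
```
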
11 changes: 9 additions & 2 deletions src/lightning_lite/lite.py
@@ -35,6 +35,7 @@
    DDPShardedStrategy,
    DDPSpawnShardedStrategy,
    DeepSpeedStrategy,
    FSDPStrategy,
    SingleDeviceStrategy,
    Strategy,
    XLAStrategy,
@@ -593,14 +594,20 @@ def _prepare_run_method(self) -> None:
        # wrap the run method, so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))

    @staticmethod
    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
    def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
        if isinstance(module, _LiteModule):
            raise ValueError("A model should be passed only once to the `setup` method.")

        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
            raise ValueError("An optimizer should be passed only once to the `setup` method.")

        if isinstance(self._strategy, FSDPStrategy):
            raise RuntimeError(
                f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
                " Create and set up the model first through `model = self.setup_model(model)`. Then create the"
                " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
            )

    def _validate_setup_module(self, module: nn.Module) -> None:
        if isinstance(module, _LiteModule):
            raise ValueError("A model should be passed only once to the `setup_module` method.")
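
A sketch of the two-step setup flow this new check enforces under FSDP. Method names follow the `setup_module()` / `setup_optimizers()` changelog entry above; treat the exact call sequence as an assumption rather than the definitive API.

```python
# Illustrative sketch of the separate setup that the FSDP check requires.
import torch
from lightning_lite import LightningLite  # import path assumed


class FSDPTwoStep(LightningLite):
    def run(self) -> None:
        model = torch.nn.Linear(32, 2)

        # 1) Set up (wrap/shard) the module first ...
        model = self.setup_module(model)

        # 2) ... then build the optimizer over the sharded parameters and set it up.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        optimizer = self.setup_optimizers(optimizer)

        # Calling the combined `self.setup(model, optimizer)` under FSDP now raises
        # the RuntimeError added in this hunk.
```
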
2 changes: 2 additions & 0 deletions src/lightning_lite/plugins/__init__.py
@@ -17,6 +17,7 @@
from lightning_lite.plugins.io.xla import XLACheckpointIO
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning_lite.plugins.precision.double import DoublePrecision
from lightning_lite.plugins.precision.fsdp import FSDPPrecision
from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.tpu import TPUPrecision
@@ -33,4 +34,5 @@
    "NativeMixedPrecision",
    "TPUPrecision",
    "TPUBf16Precision",
    "FSDPPrecision",
]
2 changes: 2 additions & 0 deletions src/lightning_lite/plugins/precision/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning_lite.plugins.precision.double import DoublePrecision
from lightning_lite.plugins.precision.fsdp import FSDPPrecision
from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.tpu import TPUPrecision
@@ -25,4 +26,5 @@
    "Precision",
    "TPUPrecision",
    "TPUBf16Precision",
    "FSDPPrecision",
]
59 changes: 59 additions & 0 deletions src/lightning_lite/plugins/precision/fsdp.py
@@ -0,0 +1,59 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, TYPE_CHECKING

import torch
from typing_extensions import Literal

from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if TYPE_CHECKING:
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


class FSDPPrecision(NativeMixedPrecision):
    """AMP for Fully Sharded Data Parallel training."""

    def __init__(
        self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
    ) -> None:
        if not _TORCH_GREATER_EQUAL_1_12:
            raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")

        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

        super().__init__(
            precision=precision,
            device=device,
            scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None),
        )

    @property
    def mixed_precision_config(self) -> "MixedPrecision":
        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

        if self.precision == PrecisionType.HALF:
            dtype = torch.float16
        elif self.precision == PrecisionType.BFLOAT:
            dtype = torch.bfloat16
        else:
            raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
        return MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=dtype,
            buffer_dtype=dtype,
        )
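
To illustrate how the `mixed_precision_config` property is typically consumed, a hedged sketch follows. The exact wiring inside `FSDPStrategy` is not shown in this diff; the snippet assumes torch >= 1.12, an initialized process group, and a CUDA device.

```python
# Illustrative sketch: feeding FSDPPrecision's config into a torch FSDP wrapper.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from lightning_lite.plugins.precision.fsdp import FSDPPrecision

precision = FSDPPrecision(precision=16, device="cuda")
module = torch.nn.Linear(32, 2).cuda()

# `mixed_precision_config` returns a torch `MixedPrecision` dataclass with
# param/reduce/buffer dtypes set to float16 (or bfloat16 for precision="bf16").
wrapped = FSDP(module, mixed_precision=precision.mixed_precision_config)
```
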
1 change: 1 addition & 0 deletions src/lightning_lite/strategies/__init__.py
@@ -17,6 +17,7 @@
from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401
from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401
from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401
from lightning_lite.strategies.fsdp import FSDPStrategy # noqa: F401
from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401
from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401
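
With `FSDPStrategy` exported here, it can also be passed to Lite as an instance rather than a string alias; a minimal sketch, with constructor arguments omitted because they are not shown in this diff. Instantiating `LightningLite` directly assumes `run()` is no longer abstract, per the changelog entry for #14992.

```python
# Illustrative sketch: passing an FSDPStrategy instance directly. The connector's
# `isinstance(self._strategy_flag, FSDPStrategy)` check above covers this path too.
from lightning_lite import LightningLite  # import path assumed
from lightning_lite.strategies import FSDPStrategy

lite = LightningLite(strategy=FSDPStrategy(), accelerator="cuda", devices=2, precision=16)
```
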
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/ddp.py
@@ -92,8 +92,7 @@ def num_processes(self) -> int:

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs
        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

    @property
    def process_group_backend(self) -> Optional[str]:
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/ddp_spawn.py
@@ -99,8 +99,7 @@ def num_processes(self) -> int:

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs
        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

    @property
    def process_group_backend(self) -> Optional[str]:
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/deepspeed.py
@@ -297,8 +297,7 @@ def zero_stage_3(self) -> bool:

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs
        return dict(num_replicas=self.world_size, rank=self.global_rank)

    @property
    def model(self) -> "deepspeed.DeepSpeedEngine":
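
The `distributed_sampler_kwargs` properties touched above are typically used to construct a `DistributedSampler`; a minimal, illustrative sketch (not code from this PR):

```python
# Illustrative sketch: how distributed_sampler_kwargs-style values feed a sampler.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 32))
kwargs = dict(num_replicas=4, rank=0)  # shape of what the strategy property returns
sampler = DistributedSampler(dataset, **kwargs)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=8)
```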