Standalone Lite: DDP Spawn Strategy Family #14675

Merged
merged 70 commits into from
Sep 15, 2022

Changes from 57 commits
Commits
70 commits
fe59302
add accelerator implementations to lite
awaelchli Sep 7, 2022
7271f94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
b6de11f
fix imports
awaelchli Sep 7, 2022
2ef04e6
rename registry argument
awaelchli Sep 7, 2022
9bbaf4f
fix test
awaelchli Sep 7, 2022
48bc1e8
fix tests
awaelchli Sep 7, 2022
0cf9651
Merge branch 'master' into lite/accelerators3
awaelchli Sep 7, 2022
dc09055
remove duplicated test
awaelchli Sep 7, 2022
6a14975
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
e6d619c
fix tests
awaelchli Sep 7, 2022
9055717
deprecation
awaelchli Sep 7, 2022
f016626
deprecations
awaelchli Sep 7, 2022
084bc6f
flake8
awaelchli Sep 7, 2022
9c19b48
fixes
awaelchli Sep 8, 2022
3d09dac
add mps to runif
awaelchli Sep 8, 2022
7a5a740
fix tests
awaelchli Sep 8, 2022
de78087
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2022
48ef646
Apply suggestions from code review
awaelchli Sep 9, 2022
6d60b96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
4e018c4
remove more
awaelchli Sep 9, 2022
983a6d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
2220350
local import
awaelchli Sep 9, 2022
cfce27e
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
4ba5809
undo device stats :(
awaelchli Sep 9, 2022
231d8c3
fix import
awaelchli Sep 9, 2022
6e1f03a
stupid typehints
awaelchli Sep 9, 2022
1505eb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
334e3cf
Merge branch 'master' into lite/accelerators
Borda Sep 9, 2022
e832e67
more refactors :(
awaelchli Sep 9, 2022
a90ef22
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
8bf889b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
1195cec
fix
awaelchli Sep 11, 2022
f4dd9a5
Merge branch 'master' into lite/accelerators3
awaelchli Sep 12, 2022
c1f029e
rename init_device to setup_device
awaelchli Sep 12, 2022
4cc08fe
remove unused import
awaelchli Sep 12, 2022
9b8572d
make uppercase to differentiate from class
awaelchli Sep 12, 2022
06bf069
trick test after moving import locally
awaelchli Sep 12, 2022
6dbe465
Merge branch 'lite/accelerators3' into lite/strategy-base
awaelchli Sep 12, 2022
be60f9a
add base classes and registry
awaelchli Sep 12, 2022
f325117
reg
awaelchli Sep 12, 2022
6a8812d
registry
awaelchli Sep 12, 2022
90466e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
b829e90
Merge branch 'master' into lite/strategy-base
awaelchli Sep 12, 2022
81d1b1a
tests
awaelchli Sep 12, 2022
b8de59f
update to other branches
awaelchli Sep 12, 2022
2dcabd6
resolve todo(lite)
awaelchli Sep 12, 2022
60a7479
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
9479bad
add very basic unit tests
awaelchli Sep 12, 2022
ab62924
fix name assignment
awaelchli Sep 12, 2022
a2b00bd
Merge branch 'lite/strategy-base' into lite/strategies-spawn
awaelchli Sep 12, 2022
f917cb4
add spawn family
awaelchli Sep 12, 2022
e12f142
merge
awaelchli Sep 12, 2022
20808b8
tests
awaelchli Sep 12, 2022
ae55eff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
690a87f
fixes
awaelchli Sep 13, 2022
dbf0730
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 14, 2022
cbd26d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2022
fb2522b
update
awaelchli Sep 14, 2022
07a2956
remove deprecated pg backend logic
awaelchli Sep 15, 2022
3685ba1
wip
awaelchli Sep 15, 2022
d4f3a54
integrate changes from #11073
awaelchli Sep 15, 2022
7d4f5b6
rename TPUSpawnStrategy to XLAStrategy
awaelchli Sep 15, 2022
2fd9d73
add back missing method
awaelchli Sep 15, 2022
5236eef
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 15, 2022
b5dd25d
isort
awaelchli Sep 15, 2022
0b473a6
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 15, 2022
78c96b1
import
awaelchli Sep 15, 2022
d7e5db9
Apply suggestions from code review
awaelchli Sep 15, 2022
04f3f78
Update src/lightning_lite/strategies/ddp_spawn.py
awaelchli Sep 15, 2022
185dd64
made sharded implementations identical
awaelchli Sep 15, 2022
3 changes: 3 additions & 0 deletions src/lightning_lite/strategies/__init__.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401
from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401
from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401
from lightning_lite.strategies.strategy import Strategy # noqa: F401
from lightning_lite.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401

STRATEGY_REGISTRY = _StrategyRegistry()
_STRATEGIES_BASE_MODULE = "lightning_lite.strategies"
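
For readers skimming the diff, the registry introduced here boils down to a mapping from string aliases (such as "ddp_spawn") to a strategy class plus pre-bound constructor kwargs. The following is a hypothetical toy re-implementation for illustration only, not the actual `_StrategyRegistry` code; the class and method names are made up:

# Toy illustration of the alias-to-class registry pattern behind STRATEGY_REGISTRY.
# Hypothetical sketch only -- not the actual _StrategyRegistry implementation.
from typing import Any, Dict, Type


class ToyStrategyRegistry:
    def __init__(self) -> None:
        self._registry: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, cls: Type, description: str = "", **init_params: Any) -> None:
        # Each alias remembers the class and the kwargs it should be constructed with.
        self._registry[name] = {"cls": cls, "description": description, "init_params": init_params}

    def get(self, name: str, **kwargs: Any) -> Any:
        entry = self._registry[name]
        # Explicit kwargs override the kwargs pre-bound at registration time.
        return entry["cls"](**{**entry["init_params"], **kwargs})

Under this reading, registering "ddp_spawn" with start_method="spawn" means that looking up "ddp_spawn" constructs the strategy with that argument pre-filled, which mirrors how register_strategies() in ddp_spawn.py below populates STRATEGY_REGISTRY.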
218 changes: 218 additions & 0 deletions src/lightning_lite/strategies/ddp_spawn.py
@@ -0,0 +1,218 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed
from torch import Tensor
from torch.distributed.constants import default_pg_timeout
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import Literal

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
from lightning_lite.plugins.precision import Precision
from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_lite.strategies.parallel import ParallelStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities.distributed import (
_get_process_group_backend_from_env,
distributed_available,
get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from lightning_lite.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
"ddp_fork",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook",
"ddp_notebook_find_unused_parameters_false",
)


class DDPSpawnStrategy(ParallelStrategy):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""

def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
**kwargs: Any,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._start_method = start_method
self._ddp_kwargs = kwargs
self._local_rank = 0

@property
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks depend on num_nodes; when resetting num_nodes, the world ranks need to be reset as well
self._num_nodes = num_nodes

@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
def distributed_sampler_kwargs(self) -> Dict[str, int]:
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

@property
def local_rank(self) -> int:
return self._local_rank

def _configure_launcher(self) -> None:
self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

def setup_module(self, module: Module) -> Module:
return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)

def module_to_device(self, module: Module) -> None:
if self.root_device.type == "cuda":
# TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device()
# set the device on the spawned subprocesses
torch.cuda.set_device(self.root_device)
module.to(self.root_device)

def reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.

Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.

Return:
reduced value, except when the input was not a tensor, in which case the output remains unchanged
"""
if isinstance(tensor, Tensor):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def barrier(self, *args: Any, **kwargs: Any) -> None:
if not distributed_available():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
else:
torch.distributed.barrier()

def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not distributed_available():
return obj
obj = [obj]
if self.global_rank != src:
obj = [None] # type: ignore[list-item]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
entries = (
("ddp_spawn", "spawn"),
("ddp_fork", "fork"),
("ddp_notebook", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `start_method` '{start_method}'",
start_method=start_method,
)

entries = (
("ddp_spawn_find_unused_parameters_false", "spawn"),
("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook_find_unused_parameters_false", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `find_unused_parameters` as False and `start_method` '{start_method}'",
find_unused_parameters=False,
start_method=start_method,
)

def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def _worker_setup(self, process_idx: int) -> None:
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
init_dist_connection(
self.cluster_environment,
self._process_group_backend,
self.global_rank,
self.world_size,
timeout=self._timeout,
)

def _get_process_group_backend(self) -> str:
return (
self._process_group_backend
or _get_process_group_backend_from_env()
or get_default_process_group_backend_for_device(self.root_device)
)

def _determine_ddp_device_ids(self) -> Optional[List[int]]:
if self.root_device.type == "cpu":
return None
return [self.root_device.index]
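
For readers unfamiliar with the PyTorch mechanics that DDPSpawnStrategy wraps, here is a minimal, self-contained sketch of the spawn-based DDP pattern (plain PyTorch, not the Lite API; the gloo backend, port, world size, and toy model are illustrative assumptions):

# Minimal sketch of the spawn-based DDP pattern that DDPSpawnStrategy encapsulates.
# Plain PyTorch only; backend, port, world size, and the toy model are illustrative.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


def _worker(rank: int, world_size: int) -> None:
    # Roughly what _worker_setup() does: establish the process group for this rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Roughly what setup_module() does: wrap the module in DistributedDataParallel.
    model = torch.nn.Linear(4, 4)
    ddp_model = DistributedDataParallel(model)

    loss = ddp_model(torch.randn(2, 4)).sum()
    loss.backward()  # gradients are all-reduced across the spawned processes here

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # DDPSpawnStrategy delegates this step to its _MultiProcessingLauncher.
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)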
84 changes: 84 additions & 0 deletions src/lightning_lite/strategies/fairscale.py
@@ -0,0 +1,84 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Tuple

from lightning_utilities.core.imports import module_available
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
from lightning_lite.utilities.imports import _IS_WINDOWS

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS

else:
OSS = ShardedDataParallel = object


class DDPSpawnShardedStrategy(DDPSpawnStrategy):
"""Optimizer sharded training provided by FairScale."""

def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

Return:
The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
"""
optimizers = self._reinit_optimizers_with_oss(optimizers)
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

@contextmanager
def block_backward_sync(self, module: Module) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(module, ShardedDataParallel):
with module.no_sync():
yield None
else:
yield None

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"ddp_sharded_spawn_find_unused_parameters_false",
cls,
description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
"ddp_sharded_spawn",
cls,
description=f"{cls.__class__.__name__}",
)

def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
optimizers[x] = zero_optimizer
del optimizer
return optimizers
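
As a usage note on block_backward_sync(): the pattern it enables is the standard no_sync() gradient-accumulation trick. A minimal sketch with plain PyTorch DDP follows; fairscale's ShardedDataParallel exposes the same no_sync() context manager, which is what the strategy relies on. ddp_model, loss_fn, and batches are placeholders assumed to exist.

# Sketch of the gradient-accumulation pattern that block_backward_sync() enables.
# Shown with a plain DistributedDataParallel model for illustration.
from contextlib import nullcontext


def train_with_accumulation(ddp_model, loss_fn, batches, accumulate_every: int = 4) -> None:
    for i, (inputs, targets) in enumerate(batches):
        # Skip the gradient all-reduce on every micro-batch except the last of each window.
        maybe_no_sync = ddp_model.no_sync() if (i + 1) % accumulate_every != 0 else nullcontext()
        with maybe_no_sync:
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()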