Standalone Lite: DDP Spawn Strategy Family #14675

Merged · 70 commits merged on Sep 15, 2022

Changes from all commits

Commits (70)
fe59302
add accelerator implementations to lite
awaelchli Sep 7, 2022
7271f94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
b6de11f
fix imports
awaelchli Sep 7, 2022
2ef04e6
rename registry argument
awaelchli Sep 7, 2022
9bbaf4f
fix test
awaelchli Sep 7, 2022
48bc1e8
fix tests
awaelchli Sep 7, 2022
0cf9651
Merge branch 'master' into lite/accelerators3
awaelchli Sep 7, 2022
dc09055
remove duplicated test
awaelchli Sep 7, 2022
6a14975
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
e6d619c
fix tests
awaelchli Sep 7, 2022
9055717
deprecation
awaelchli Sep 7, 2022
f016626
deprecations
awaelchli Sep 7, 2022
084bc6f
flake8
awaelchli Sep 7, 2022
9c19b48
fixes
awaelchli Sep 8, 2022
3d09dac
add mps to runif
awaelchli Sep 8, 2022
7a5a740
fix tests
awaelchli Sep 8, 2022
de78087
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2022
48ef646
Apply suggestions from code review
awaelchli Sep 9, 2022
6d60b96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
4e018c4
remove more
awaelchli Sep 9, 2022
983a6d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
2220350
local import
awaelchli Sep 9, 2022
cfce27e
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
4ba5809
undo device stats :(
awaelchli Sep 9, 2022
231d8c3
fix import
awaelchli Sep 9, 2022
6e1f03a
stupid typehints
awaelchli Sep 9, 2022
1505eb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
334e3cf
Merge branch 'master' into lite/accelerators
Borda Sep 9, 2022
e832e67
more refactors :(
awaelchli Sep 9, 2022
a90ef22
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
8bf889b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
1195cec
fix
awaelchli Sep 11, 2022
f4dd9a5
Merge branch 'master' into lite/accelerators3
awaelchli Sep 12, 2022
c1f029e
rename init_device to setup_device
awaelchli Sep 12, 2022
4cc08fe
remove unused import
awaelchli Sep 12, 2022
9b8572d
make uppercase to differentiate from class
awaelchli Sep 12, 2022
06bf069
trick test after moving import locally
awaelchli Sep 12, 2022
6dbe465
Merge branch 'lite/accelerators3' into lite/strategy-base
awaelchli Sep 12, 2022
be60f9a
add base classes and registry
awaelchli Sep 12, 2022
f325117
reg
awaelchli Sep 12, 2022
6a8812d
registry
awaelchli Sep 12, 2022
90466e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
b829e90
Merge branch 'master' into lite/strategy-base
awaelchli Sep 12, 2022
81d1b1a
tests
awaelchli Sep 12, 2022
b8de59f
update to other branches
awaelchli Sep 12, 2022
2dcabd6
resolve todo(lite)
awaelchli Sep 12, 2022
60a7479
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
9479bad
add very basic unit tests
awaelchli Sep 12, 2022
ab62924
fix name assignment
awaelchli Sep 12, 2022
a2b00bd
Merge branch 'lite/strategy-base' into lite/strategies-spawn
awaelchli Sep 12, 2022
f917cb4
add spawn family
awaelchli Sep 12, 2022
e12f142
merge
awaelchli Sep 12, 2022
20808b8
tests
awaelchli Sep 12, 2022
ae55eff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2022
690a87f
fixes
awaelchli Sep 13, 2022
dbf0730
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 14, 2022
cbd26d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2022
fb2522b
update
awaelchli Sep 14, 2022
07a2956
remove deprecated pg backend logic
awaelchli Sep 15, 2022
3685ba1
wip
awaelchli Sep 15, 2022
d4f3a54
integrate changes from #11073
awaelchli Sep 15, 2022
7d4f5b6
rename TPUSpawnStrategy to XLAStrategy
awaelchli Sep 15, 2022
2fd9d73
add back missing method
awaelchli Sep 15, 2022
5236eef
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 15, 2022
b5dd25d
isort
awaelchli Sep 15, 2022
0b473a6
Merge branch 'master' into lite/strategies-spawn
awaelchli Sep 15, 2022
78c96b1
import
awaelchli Sep 15, 2022
d7e5db9
Apply suggestions from code review
awaelchli Sep 15, 2022
04f3f78
Update src/lightning_lite/strategies/ddp_spawn.py
awaelchli Sep 15, 2022
185dd64
made sharded implementations identical
awaelchli Sep 15, 2022
3 changes: 3 additions & 0 deletions src/lightning_lite/strategies/__init__.py
@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.strategies.ddp import DDPStrategy # noqa: F401
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
from lightning_lite.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401
from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401
from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401
from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401
from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401
from lightning_lite.strategies.single_tpu import SingleTPUStrategy # noqa: F401
from lightning_lite.strategies.strategy import Strategy # noqa: F401
from lightning_lite.strategies.xla import XLAStrategy # noqa: F401

STRATEGY_REGISTRY = _StrategyRegistry()
_STRATEGIES_BASE_MODULE = "lightning_lite.strategies"
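The `STRATEGY_REGISTRY` created above is populated (presumably via `_call_register_strategies`) by the `register_strategies` classmethods added in the files below: each strategy registers its aliases, the class itself, and the constructor defaults those aliases imply. The snippet below is a minimal sketch of a registry with that shape, not the actual `_StrategyRegistry` implementation (which is not part of this diff); the `get`/`available_strategies` methods and the `_SketchStrategyRegistry`/`_DummyStrategy` names are assumptions for illustration.

```python
from typing import Any, List, Type


class _SketchStrategyRegistry(dict):
    """Toy registry: maps an alias to a class plus default constructor kwargs."""

    def register(self, name: str, cls: Type, description: str = "", **init_kwargs: Any) -> None:
        # Mirrors the call shape used in the `register_strategies` classmethods below.
        self[name] = {"cls": cls, "description": description, "init_kwargs": init_kwargs}

    def get(self, name: str, **override_kwargs: Any) -> Any:
        entry = self[name]
        return entry["cls"](**{**entry["init_kwargs"], **override_kwargs})

    def available_strategies(self) -> List[str]:
        return list(self)


class _DummyStrategy:
    def __init__(self, start_method: str = "spawn", find_unused_parameters: bool = True) -> None:
        self.start_method = start_method
        self.find_unused_parameters = find_unused_parameters


if __name__ == "__main__":
    registry = _SketchStrategyRegistry()
    registry.register("ddp_spawn", _DummyStrategy, description="spawn", start_method="spawn")
    registry.register(
        "ddp_notebook_find_unused_parameters_false",
        _DummyStrategy,
        description="fork, no unused-parameter search",
        start_method="fork",
        find_unused_parameters=False,
    )
    strategy = registry.get("ddp_notebook_find_unused_parameters_false")
    print(strategy.start_method, strategy.find_unused_parameters)  # fork False
```
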
214 changes: 214 additions & 0 deletions src/lightning_lite/strategies/ddp_spawn.py
@@ -0,0 +1,214 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed
from torch import Tensor
from torch.distributed.constants import default_pg_timeout
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import Literal

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
from lightning_lite.plugins.precision import Precision
from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_lite.strategies.parallel import ParallelStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from lightning_lite.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
"ddp_fork",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook",
"ddp_notebook_find_unused_parameters_false",
)


class DDPSpawnStrategy(ParallelStrategy):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""

def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
**kwargs: Any,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._start_method = start_method
self._ddp_kwargs = kwargs
self._local_rank = 0

@property
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks are related to num_nodes; when resetting num_nodes, the world ranks need to be reset as well
self._num_nodes = num_nodes

@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
def distributed_sampler_kwargs(self) -> Dict[str, int]:
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

@property
def local_rank(self) -> int:
return self._local_rank

def _configure_launcher(self) -> None:
self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

def setup_environment(self) -> None:
self._setup_distributed()
super().setup_environment()

def setup_module(self, module: Module) -> Module:
return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)

def module_to_device(self, module: Module) -> None:
if self.root_device.type == "cuda":
# TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device()
# set the device on the spawned subprocesses
torch.cuda.set_device(self.root_device)
module.to(self.root_device)

def reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.

Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.

Return:
the reduced value, except when the input was not a tensor, in which case the output remains unchanged
"""
if isinstance(tensor, Tensor):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def barrier(self, *args: Any, **kwargs: Any) -> None:
if not distributed_available():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
else:
torch.distributed.barrier()

def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not distributed_available():
return obj
obj = [obj]
if self.global_rank != src:
obj = [None] # type: ignore[list-item]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
entries = (
("ddp_spawn", "spawn"),
("ddp_fork", "fork"),
("ddp_notebook", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `start_method` '{start_method}'",
start_method=start_method,
)

entries = (
("ddp_spawn_find_unused_parameters_false", "spawn"),
("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook_find_unused_parameters_false", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `find_unused_parameters` as False and `start_method` '{start_method}'",
find_unused_parameters=False,
start_method=start_method,
)

def _setup_distributed(self) -> None:
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
init_dist_connection(
self.cluster_environment,
self._process_group_backend,
self.global_rank,
self.world_size,
timeout=self._timeout,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)

def _set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def _determine_ddp_device_ids(self) -> Optional[List[int]]:
if self.root_device.type == "cpu":
return None
return [self.root_device.index]
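To make the division of labour concrete, here is a small standalone sketch, written directly against torch.distributed and torch.multiprocessing, of the launch / wrap / reduce cycle that `DDPSpawnStrategy` packages up. It is not code from this PR: Lite additionally handles the cluster environment, precision and checkpoint plugins, NCCL device ids, and the fork/notebook aliases. The `start_method` argument of `mp.start_processes` is the same knob that the `ddp_spawn`, `ddp_fork`, and `ddp_notebook` aliases select.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


def _worker(local_rank: int, world_size: int) -> None:
    # The strategy's _setup_distributed / init_dist_connection step, reduced to its core.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=local_rank, world_size=world_size)  # "nccl" on CUDA devices

    # setup_module: wrap the module so gradients are all-reduced on backward.
    model = DistributedDataParallel(torch.nn.Linear(4, 4))  # device_ids stays None on CPU
    model(torch.randn(2, 4)).sum().backward()

    # reduce(..., reduce_op="mean"): sum across processes, then divide by the world size.
    value = torch.tensor([float(local_rank)])
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size

    dist.destroy_process_group()


if __name__ == "__main__":
    # start_method="spawn" corresponds to "ddp_spawn"; "fork" to "ddp_fork"/"ddp_notebook".
    mp.start_processes(_worker, args=(2,), nprocs=2, start_method="spawn")
```
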
97 changes: 91 additions & 6 deletions src/lightning_lite/strategies/fairscale.py
@@ -23,6 +23,7 @@

from lightning_lite.accelerators import Accelerator
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision
from lightning_lite.strategies import DDPSpawnStrategy
from lightning_lite.strategies.ddp import DDPStrategy
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.imports import _IS_WINDOWS
@@ -76,7 +77,7 @@ def setup_module_and_optimizers(
The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
"""
optimizers = self._reinit_optimizers_with_oss(optimizers)
optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

@@ -107,16 +108,100 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description=cls.__class__.__name__,
)


class DDPSpawnShardedStrategy(DDPSpawnStrategy):
"""Optimizer and gradient sharded training provided by FairScale with Spawn."""

_REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M

def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
**kwargs: Any,
) -> None:
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
process_group_backend=process_group_backend,
timeout=timeout,
**kwargs,
)
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

Return:
The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
"""
optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

@contextmanager
def block_backward_sync(self, module: Module) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.

This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(module, ShardedDataParallel):
with module.no_sync():
yield None
else:
yield None

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"ddp_sharded_spawn_find_unused_parameters_false",
cls,
description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
"ddp_sharded_spawn",
cls,
description=cls.__class__.__name__,
)

def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.
zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
optimizers[x] = zero_optimizer
del optimizer
return optimizers


def _reinit_optimizers_with_oss(
optimizers: List[Optimizer], precision_plugin: Precision, num_nodes: int
) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
is_fp16 = precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.
zero_optimizer.broadcast_fp16 = is_fp16 and num_nodes > 1
optimizers[x] = zero_optimizer
del optimizer
return optimizers
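The deleted method above and the new module-level `_reinit_optimizers_with_oss` helper implement the same re-wrapping pattern, now shared by `DDPShardedStrategy` and `DDPSpawnShardedStrategy`. The sketch below shows that pattern in isolation; it assumes fairscale is installed, must run inside an initialized process group (OSS partitions parameters across ranks), and uses a plain `is_fp16` flag in place of the `Precision` plugin check.

```python
import os
from typing import List

import torch
import torch.distributed as dist
from fairscale.optim import OSS
from torch.optim import Optimizer


def reinit_with_oss(optimizers: List[Optimizer], is_fp16: bool, num_nodes: int) -> List[Optimizer]:
    for i, optimizer in enumerate(optimizers):
        if isinstance(optimizer, OSS):
            continue  # already sharded
        # Re-create the optimizer as a sharded OSS wrapper with identical hyperparameters.
        sharded = OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
        # Compressing shard broadcasts to fp16 only pays off across nodes.
        sharded.broadcast_fp16 = is_fp16 and num_nodes > 1
        optimizers[i] = sharded
    return optimizers


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)  # OSS needs an initialized group

    model = torch.nn.Linear(4, 4)
    optimizers = reinit_with_oss([torch.optim.SGD(model.parameters(), lr=0.1)], is_fp16=False, num_nodes=1)
    assert isinstance(optimizers[0], OSS)

    dist.destroy_process_group()
```
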
@@ -118,8 +118,7 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()
# TODO(lite): Update worker setup once DDPSpawn strategy is in Lite
self._strategy._worker_setup(process_idx)
self._strategy._local_rank = process_idx
results = function(*args, **kwargs)

if self._strategy.local_rank == 0:
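The hunk above (presumably the multiprocessing launcher; its file header is missing from this view) replaces the old `_worker_setup` call with a direct assignment of the process index to the strategy. Below is a minimal sketch of that hand-off using hypothetical `_ToyStrategy`/`_ToyLauncher` classes that are not part of Lite; result collection through a return queue and global-state snapshots are omitted.

```python
from typing import Callable

import torch.multiprocessing as mp


class _ToyStrategy:
    def __init__(self) -> None:
        self._local_rank = 0  # overwritten inside each spawned worker

    @property
    def local_rank(self) -> int:
        return self._local_rank


class _ToyLauncher:
    def __init__(self, strategy: _ToyStrategy, start_method: str = "spawn") -> None:
        self._strategy = strategy
        self._start_method = start_method

    def launch(self, function: Callable[[], None], nprocs: int) -> None:
        # `function` must be picklable (e.g. defined at module level) for the spawn start method.
        mp.start_processes(self._wrapping_function, args=(function,), nprocs=nprocs, start_method=self._start_method)

    def _wrapping_function(self, process_idx: int, function: Callable[[], None]) -> None:
        # Only the launcher knows the worker's index, so it sets the local rank
        # on (its copy of) the strategy before running the user code.
        self._strategy._local_rank = process_idx
        function()
```
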
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/launchers/xla.py
@@ -86,8 +86,7 @@ def _wrapping_function(
return_queue: SimpleQueue,
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
# TODO(lite): Update worker setup once TPUSpawn strategy is in Lite
self._strategy._worker_setup(process_idx)
self._strategy._local_rank = process_idx
Contributor:

@awaelchli I believe this is not correct, as this no longer sets XLAStrategy._launched=True, and #14926 then fails with:

  File "/home/runner/work/lightning/tests/tests_lite/strategies/test_xla.py", line 17, in broadcast_on_tpu_fn
    result = strategy.broadcast(obj)
  File "/home/runner/work/lightning/src/lightning_lite/strategies/xla.py", line 146, in broadcast
    data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
  File "/home/runner/work/lightning/src/lightning_lite/strategies/xla.py", line 72, in root_device
    raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")

I'm a bit confused about the best way to do this.

Should we call strategy.setup_environment()?

If yes:

  • That call does not set the local rank. Should it? Or do we still set the local rank manually?
  • Should the same be done for the other Lite launchers?

Also, why did Lite remove _worker_setup? Having the logic differ between the PL and Lite strategies is confusing.

This blocks #14926

Contributor Author:

This was done in reaction to your comment: #11073 (comment)
I think it is correct; otherwise many tests for ddp spawn would fail, and tpu spawn is not fundamentally different with regard to local-rank handling.

I commented on #14926 that perhaps all the test is missing is a strategy.setup_environment() call.

Contributor Author:

My motivation with #11073 was always to simplify these things so that such questions wouldn't come up in the first place. But nobody wants to merge it, lol; I have already posted it in the waiting-PRs list three times over the last five months or so.

That call does not set the local rank. Should it? Or do we still set the local rank manually?

For the multiprocessing launcher, the local rank information can only come from the launcher directly, so the answer here is no.

Should the same be done for the other Lite launchers?

If #11073 lands, both code paths would be identical in this regard.

Also, why did Lite remove _worker_setup? Having the logic differ between the PL and Lite strategies is confusing.

If #11073 lands, both code paths would be identical in this regard.

results = function(*args, **kwargs)

if self._strategy.local_rank == 0: