Align ddp and ddp-spawn strategies in setting up the environment #11073

Merged 32 commits on Sep 29, 2022. The changes shown below are from 27 of the 32 commits.

Commits (32):
- 9e53c90 align ddp and ddp_spawn in setting up the environment (awaelchli, Dec 15, 2021)
- 7a48d5c Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Jan 4, 2022)
- 454f5c7 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 4, 2022)
- 07bd00e align (awaelchli, Jan 4, 2022)
- a9c137d remove (awaelchli, Jan 4, 2022)
- 66f2e8c move (awaelchli, Jan 4, 2022)
- 7dd4056 Merge remote-tracking branch 'origin/refactor/spawn/setup-environment… (awaelchli, Jan 4, 2022)
- 85eea33 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 4, 2022)
- cf46f1c Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Jun 30, 2022)
- 3c93c0a resolve merge errors (awaelchli, Jun 30, 2022)
- 02a4465 notebook (awaelchli, Jun 30, 2022)
- 4f08082 notebook (awaelchli, Jun 30, 2022)
- 601c2cd update tests (awaelchli, Jun 30, 2022)
- c9a1093 Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Jul 24, 2022)
- ada9a5b fix (awaelchli, Jul 25, 2022)
- a585729 changelog (awaelchli, Jul 25, 2022)
- c037ad5 types (awaelchli, Jul 25, 2022)
- a7cff0e Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Aug 25, 2022)
- 5ec39ad Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Aug 28, 2022)
- 39f5f72 reset (awaelchli, Aug 28, 2022)
- 151755c imports (awaelchli, Aug 28, 2022)
- cd9bf08 fix xla launcher (awaelchli, Aug 28, 2022)
- 048339d Merge branch 'master' into refactor/spawn/setup-environment (awaelchli, Sep 19, 2022)
- a28affc update (awaelchli, Sep 19, 2022)
- e9bc6ee update (awaelchli, Sep 19, 2022)
- 9cbfa0a Update src/pytorch_lightning/CHANGELOG.md (awaelchli, Sep 19, 2022)
- 058c897 Merge branch 'master' into refactor/spawn/setup-environment (carmocca, Sep 29, 2022)
- eb69b9e merge master for the 100th time (awaelchli, Sep 29, 2022)
- 9bfd40b move it as early as possible (awaelchli, Sep 29, 2022)
- dd41040 Merge branch 'master' into refactor/spawn/setup-environment (carmocca, Sep 29, 2022)
- 13680c9 Revert "move it as early as possible" (awaelchli, Sep 29, 2022)
- 243d8ac Use correct hook (carmocca, Sep 29, 2022)

Files changed:
src/lightning_lite/lite.py (7 changes: 4 additions & 3 deletions)

@@ -380,16 +380,17 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
         return seed_everything(seed=seed, workers=workers)
 
     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        # apply sharded context to prevent OOM
-        run_method = partial(self._run_with_strategy_setup, run_method)
+        # wrap the real run method with setup logic
+        run_method = partial(self._run_with_setup, run_method)
 
         if self._strategy.launcher is not None:
             return self._strategy.launcher.launch(run_method, *args, **kwargs)
         else:
             return run_method(*args, **kwargs)
 
-    def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
+    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
+        # apply sharded context to prevent OOM
         with self._strategy.module_sharded_context(), _replace_dunder_methods(
             DataLoader, "dataset"
         ), _replace_dunder_methods(BatchSampler):
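For readers unfamiliar with the launcher pattern above: the diff keeps environment setup inside the wrapped run method, so it executes in the (possibly spawned) worker process rather than in the parent. Below is a minimal, self-contained sketch of that "wrap, then launch" flow; `ToyLauncher`, `ToyStrategy`, and `ToyLite` are hypothetical stand-ins, not Lightning classes.

```python
# Minimal sketch of the "wrap, then launch" pattern used by `_run_impl`.
from functools import partial
from typing import Any, Callable, Optional


class ToyLauncher:
    def launch(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
        # a real launcher would spawn worker processes; here we just call the function
        return fn(*args, **kwargs)


class ToyStrategy:
    def __init__(self, use_launcher: bool) -> None:
        self.launcher: Optional[ToyLauncher] = ToyLauncher() if use_launcher else None

    def setup_environment(self) -> None:
        print("setting up environment (ranks, process group, ...)")


class ToyLite:
    def __init__(self, strategy: ToyStrategy) -> None:
        self._strategy = strategy

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # wrap the real run method with setup logic
        run_method = partial(self._run_with_setup, run_method)
        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        return run_method(*args, **kwargs)

    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # environment setup happens inside the (possibly spawned) worker
        self._strategy.setup_environment()
        return run_method(*args, **kwargs)


if __name__ == "__main__":
    lite = ToyLite(ToyStrategy(use_launcher=True))
    print(lite._run_impl(lambda x: x * 2, 21))  # prints the setup message, then 42
```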
src/pytorch_lightning/CHANGELOG.md (4 changes: 3 additions & 1 deletion)

@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))
 
 
+- Aligned DDP and DDPSpawn strategies in setting up the environment ([#11073](https://github.com/Lightning-AI/lightning/pull/11073))
+
+
 - Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
   * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
   * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument
@@ -105,7 +108,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827))
 
-
 ### Deprecated
 
 - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
src/pytorch_lightning/strategies/ddp.py (5 changes: 0 additions & 5 deletions)

@@ -196,13 +196,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
     def setup_distributed(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
-        # set warning rank
         rank_zero_only.rank = self.global_rank
-
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
src/pytorch_lightning/strategies/ddp_spawn.py (24 changes: 14 additions & 10 deletions)

@@ -133,6 +133,10 @@ def process_group_backend(self) -> Optional[str]:
     def _configure_launcher(self) -> None:
         self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
 
+    def setup_environment(self) -> None:
+        self.setup_distributed()
+        super().setup_environment()
+
     def setup(self, trainer: "pl.Trainer") -> None:
 
         assert self.cluster_environment is not None
@@ -160,16 +164,9 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
 
-    def set_world_ranks(self, process_idx: int = 0) -> None:
-        self._local_rank = process_idx
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
-
-    def _worker_setup(self, process_idx: int) -> None:
-        self.set_world_ranks(process_idx)
+    def setup_distributed(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
@@ -181,6 +178,13 @@ def _worker_setup(self, process_idx: int) -> None:
             timeout=self._timeout,
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)
 
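The core of the alignment is that `DDPSpawnStrategy` now exposes the same `setup_environment()` entry point as `DDPStrategy`: set the world ranks, then initialize the process group. The sketch below imitates that call order with simplified, hypothetical class and attribute names (single process, gloo backend); it is not the Lightning implementation.

```python
# Hedged sketch of the flow that DDP and DDPSpawn now share:
# setup_environment() -> setup_distributed() -> set_world_ranks() -> init process group.
import os
from datetime import timedelta

import torch.distributed as dist


class SketchSpawnStrategy:
    def __init__(self, num_nodes: int = 1, num_processes: int = 1) -> None:
        self.num_nodes = num_nodes
        self.num_processes = num_processes
        self.node_rank = 0
        self._local_rank = 0  # set by the launcher in each spawned worker
        self.global_rank = 0
        self.world_size = 1

    def set_world_ranks(self) -> None:
        # global rank derives from the node rank and the launcher-provided local rank
        self.global_rank = self.node_rank * self.num_processes + self._local_rank
        self.world_size = self.num_nodes * self.num_processes

    def setup_distributed(self) -> None:
        self.set_world_ranks()
        dist.init_process_group(
            "gloo",
            rank=self.global_rank,
            world_size=self.world_size,
            timeout=timedelta(seconds=30),
        )

    def setup_environment(self) -> None:
        # mirrors the PR: distributed setup happens here, before anything else
        self.setup_distributed()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    strategy = SketchSpawnStrategy()
    strategy.setup_environment()
    print("initialized:", dist.is_initialized(), "rank:", dist.get_rank())
    dist.destroy_process_group()
```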
src/pytorch_lightning/strategies/deepspeed.py (7 changes: 2 additions & 5 deletions)

@@ -45,7 +45,7 @@
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 log = logging.getLogger(__name__)
@@ -348,12 +348,9 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
 
     def setup_distributed(self) -> None:
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
+        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
-
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
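The import change brings in `rank_zero_only`, and `setup_distributed` now sets `rank_zero_only.rank = self.global_rank`, so rank-gated utilities know this process's global rank once distributed setup has run. As a rough illustration of the idea, here is a toy re-implementation of such a gate; it is not Lightning's actual `rank_zero_only`.

```python
# Toy rank-zero gate: the decorated function only runs when the patched rank is 0.
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable) -> Callable:
    @wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        # the attribute is patched by the strategy once the global rank is known
        if getattr(rank_zero_only, "rank", 0) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped


@rank_zero_only
def log_once(msg: str) -> None:
    print(msg)


rank_zero_only.rank = 0  # what setup_distributed does on the main process
log_once("only rank 0 prints this")

rank_zero_only.rank = 3  # on a non-zero rank the call becomes a no-op
log_once("this is suppressed")
```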
@@ -132,7 +132,7 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
src/pytorch_lightning/strategies/launchers/xla.py (2 changes: 1 addition & 1 deletion)

@@ -103,7 +103,7 @@ def _wrapping_function(
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
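Both launcher diffs replace the call to `_worker_setup(process_idx)` with a plain assignment of `_local_rank`, leaving the rest of the distributed setup to the strategy's `setup_environment()` inside the worker. The sketch below illustrates that division of labor with hypothetical stand-in classes, not the actual Lightning launcher or strategy.

```python
# Sketch of the new split of responsibilities between a spawn launcher and its strategy.
from typing import Any, Callable


class SketchStrategy:
    def __init__(self) -> None:
        self._local_rank = 0

    def setup_environment(self) -> None:
        # rank computation and process-group init would happen here, using _local_rank
        print(f"setting up environment for local rank {self._local_rank}")


class SketchLauncher:
    def __init__(self, strategy: SketchStrategy) -> None:
        self._strategy = strategy

    def _wrapping_function(self, process_idx: int, function: Callable, *args: Any) -> Any:
        # before: a _worker_setup(process_idx) call did rank + process-group setup here
        # now: only remember the local rank; setup happens inside `function` itself
        self._strategy._local_rank = process_idx
        return function(*args)


def train(strategy: SketchStrategy) -> str:
    strategy.setup_environment()  # runs inside the worker, same as non-spawn DDP
    return "done"


launcher = SketchLauncher(SketchStrategy())
print(launcher._wrapping_function(0, train, launcher._strategy))
```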
src/pytorch_lightning/strategies/tpu_spawn.py (4 changes: 2 additions & 2 deletions)

@@ -212,9 +212,9 @@ def reduce(
 
         return output
 
-    def _worker_setup(self, process_idx: int) -> None:
+    def setup_distributed(self) -> None:
         self._launched = True
-        self.set_world_ranks(process_idx)
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py (4 changes: 1 addition & 3 deletions)

@@ -154,10 +154,9 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
     assert not temp_file.exists()
 
 
-@RunIf(min_cuda_gpus=1)
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
-    """Tests with ddp strategy."""
+    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
     test_timedelta = timedelta(seconds=30)
     model = BoringModel()
     ddp_spawn_strategy = DDPSpawnStrategy(timeout=test_timedelta)
@@ -170,7 +169,6 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     trainer.strategy.connect(model)
     trainer.lightning_module.trainer = trainer
     trainer.strategy.setup_environment()
-    trainer.strategy._worker_setup(0)
 
     process_group_backend = trainer.strategy._get_process_group_backend()
     global_rank = trainer.strategy.cluster_environment.global_rank()
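The test above drops the `@RunIf(min_cuda_gpus=1)` requirement and the explicit `_worker_setup(0)` call, relying on `setup_environment()` alone to reach `init_process_group`. A stripped-down, standalone version of the same mock-and-assert pattern (with a hypothetical `setup_distributed` helper, not the Lightning test) might look like this:

```python
# Sketch: patch the process-group init and check that a custom timeout reaches it.
from datetime import timedelta
from unittest import mock

import torch.distributed


def setup_distributed(timeout: timedelta) -> None:
    # stands in for strategy.setup_environment() eventually calling init_process_group
    torch.distributed.init_process_group("gloo", rank=0, world_size=1, timeout=timeout)


@mock.patch("torch.distributed.init_process_group")
def test_timeout_is_forwarded(mock_init_process_group) -> None:
    test_timedelta = timedelta(seconds=30)
    setup_distributed(timeout=test_timedelta)
    mock_init_process_group.assert_called_with(
        "gloo", rank=0, world_size=1, timeout=test_timedelta
    )


test_timeout_is_forwarded()
print("timeout assertion passed")
```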
tests/tests_pytorch/strategies/test_ddp_strategy.py (3 changes: 1 addition & 2 deletions)

@@ -236,10 +236,9 @@ def node_rank(self):
     assert ddp_strategy.launcher is None
 
 
-@RunIf(min_cuda_gpus=1)
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_strategy_set_timeout(mock_init_process_group):
-    """Tests with ddp strategy."""
+    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
     test_timedelta = timedelta(seconds=30)
     model = BoringModel()
     ddp_strategy = DDPStrategy(timeout=test_timedelta)