Align ddp and ddp-spawn strategies in setting up the environment (#11073)

Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Sep 29, 2022
1 parent 3a70e5d commit 822a7f5
Showing 11 changed files with 30 additions and 33 deletions.
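
For orientation before the per-file hunks: the change makes DDPSpawnStrategy set up the distributed environment the same way DDPStrategy does, with the launcher only recording the process index and the strategy resolving ranks itself inside `setup_environment()`. The following is a minimal, self-contained sketch of that call order pieced together from the hunks below; the classes and method bodies are simplified stand-ins, not the actual Lightning implementations.

# Simplified stand-ins for the Lightning classes touched in this commit.
# They only model the order of calls, not the real behavior.


class ParallelStrategy:
    def setup_environment(self) -> None:
        # the base class would set up the accelerator/device here
        print("base setup_environment: accelerator and device setup")


class DDPSpawnStrategy(ParallelStrategy):
    def __init__(self, num_nodes: int = 1, num_processes: int = 2) -> None:
        self.num_nodes = num_nodes
        self.num_processes = num_processes
        self.node_rank = 0
        self._local_rank = 0  # written by the launcher in each spawned process
        self.global_rank = 0

    def setup_environment(self) -> None:
        # mirrors DDPStrategy: do the distributed setup first, then defer to the base class
        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        self.set_world_ranks()
        print(f"init process group for global rank {self.global_rank}")

    def set_world_ranks(self) -> None:
        # no process_idx argument anymore; the launcher already stored the local rank
        self.global_rank = self.node_rank * self.num_processes + self._local_rank


def worker(process_idx: int, strategy: DDPSpawnStrategy) -> None:
    # roughly what the multiprocessing launcher does in each child process
    strategy._local_rank = process_idx
    strategy.setup_environment()


if __name__ == "__main__":
    worker(1, DDPSpawnStrategy())
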
7 changes: 4 additions & 3 deletions src/lightning_lite/lite.py
@@ -380,16 +380,17 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
         return seed_everything(seed=seed, workers=workers)
 
     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        # apply sharded context to prevent OOM
-        run_method = partial(self._run_with_strategy_setup, run_method)
+        # wrap the real run method with setup logic
+        run_method = partial(self._run_with_setup, run_method)
 
         if self._strategy.launcher is not None:
             return self._strategy.launcher.launch(run_method, *args, **kwargs)
         else:
             return run_method(*args, **kwargs)
 
-    def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
+    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
+        # apply sharded context to prevent OOM
         with self._strategy.module_sharded_context(), _replace_dunder_methods(
             DataLoader, "dataset"
         ), _replace_dunder_methods(BatchSampler):
2 changes: 1 addition & 1 deletion src/lightning_lite/strategies/xla.py
@@ -94,7 +94,7 @@ def is_distributed(self) -> bool:
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
-    def setup_environment(self) -> None:
+    def _setup_distributed(self) -> None:
         self._launched = True
         self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))
 
 
+- Aligned DDP and DDPSpawn strategies in setting up the environment ([#11073](https://github.com/Lightning-AI/lightning/pull/11073))
+
+
 - Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
   * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
   * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument
5 changes: 0 additions & 5 deletions src/pytorch_lightning/strategies/ddp.py
@@ -196,13 +196,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
     def setup_distributed(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
-        # set warning rank
         rank_zero_only.rank = self.global_rank
-
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
24 changes: 14 additions & 10 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -133,6 +133,10 @@ def process_group_backend(self) -> Optional[str]:
     def _configure_launcher(self) -> None:
         self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
 
+    def setup_environment(self) -> None:
+        self.setup_distributed()
+        super().setup_environment()
+
     def setup(self, trainer: "pl.Trainer") -> None:
 
         assert self.cluster_environment is not None
@@ -160,16 +164,9 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
 
-    def set_world_ranks(self, process_idx: int = 0) -> None:
-        self._local_rank = process_idx
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
-
-    def _worker_setup(self, process_idx: int) -> None:
-        self.set_world_ranks(process_idx)
+    def setup_distributed(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
@@ -181,6 +178,13 @@ def _worker_setup(self, process_idx: int) -> None:
             timeout=self._timeout,
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)
 
7 changes: 2 additions & 5 deletions src/pytorch_lightning/strategies/deepspeed.py
@@ -45,7 +45,7 @@
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 log = logging.getLogger(__name__)
@@ -348,12 +348,9 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
 
     def setup_distributed(self) -> None:
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
+        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
-
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -132,7 +132,7 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/launchers/xla.py
@@ -103,7 +103,7 @@ def _wrapping_function(
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
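
For context on the two launcher hunks above: `torch.multiprocessing.spawn` (and its XLA counterpart) passes the process index as the first argument to the wrapped function, so after this change the launcher only records that index on the strategy and leaves rank resolution and process-group init to the strategy's own `setup_environment()` later in the worker. A rough illustration of that contract; the `_wrapping_function` body here is a simplified stand-in, not the real launcher code.

import types

import torch.multiprocessing as mp


def _wrapping_function(process_idx: int, strategy: types.SimpleNamespace) -> None:
    # the launcher now only records the index handed to it by mp.spawn ...
    strategy._local_rank = process_idx
    # ... rank math and process-group init happen later, when the strategy's
    # setup_environment()/setup_distributed() runs inside this worker
    print(f"worker {process_idx} recorded local rank {strategy._local_rank}")


if __name__ == "__main__":
    mp.spawn(_wrapping_function, args=(types.SimpleNamespace(),), nprocs=2)
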
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -212,9 +212,9 @@ def reduce(
 
         return output
 
-    def _worker_setup(self, process_idx: int) -> None:
+    def setup_distributed(self) -> None:
         self._launched = True
-        self.set_world_ranks(process_idx)
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
4 changes: 1 addition & 3 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -154,10 +154,9 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
     assert not temp_file.exists()
 
 
-@RunIf(min_cuda_gpus=1)
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
-    """Tests with ddp strategy."""
+    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
     test_timedelta = timedelta(seconds=30)
     model = BoringModel()
     ddp_spawn_strategy = DDPSpawnStrategy(timeout=test_timedelta)
@@ -170,7 +169,6 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     trainer.strategy.connect(model)
     trainer.lightning_module.trainer = trainer
     trainer.strategy.setup_environment()
-    trainer.strategy._worker_setup(0)
 
     process_group_backend = trainer.strategy._get_process_group_backend()
     global_rank = trainer.strategy.cluster_environment.global_rank()
3 changes: 1 addition & 2 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -236,10 +236,9 @@ def node_rank(self):
     assert ddp_strategy.launcher is None
 
 
-@RunIf(min_cuda_gpus=1)
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_strategy_set_timeout(mock_init_process_group):
-    """Tests with ddp strategy."""
+    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
     test_timedelta = timedelta(seconds=30)
     model = BoringModel()
     ddp_strategy = DDPStrategy(timeout=test_timedelta)
