diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index f148632d630ee..414c63c841813 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -380,16 +380,17 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) return seed_everything(seed=seed, workers=workers) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - # apply sharded context to prevent OOM - run_method = partial(self._run_with_strategy_setup, run_method) + # wrap the real run method with setup logic + run_method = partial(self._run_with_setup, run_method) if self._strategy.launcher is not None: return self._strategy.launcher.launch(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) - def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() + # apply sharded context to prevent OOM with self._strategy.module_sharded_context(), _replace_dunder_methods( DataLoader, "dataset" ), _replace_dunder_methods(BatchSampler): diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py index f485c30a5f433..48e2338f637c6 100644 --- a/src/lightning_lite/strategies/xla.py +++ b/src/lightning_lite/strategies/xla.py @@ -94,7 +94,7 @@ def is_distributed(self) -> bool: def _configure_launcher(self) -> None: self._launcher = _XLALauncher(self) - def setup_environment(self) -> None: + def _setup_distributed(self) -> None: self._launched = True self._set_world_ranks() rank_zero_only.rank = self.global_rank diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index c6868311cd1b3..5d816afcfe238 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300)) +- Aligned DDP and DDPSpawn strategies in setting up the environment ([#11073](https://github.com/Lightning-AI/lightning/pull/11073)) + + - Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798)) * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor` * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 15dba4c98877b..43dd98129ae9b 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -196,13 +196,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: def setup_distributed(self) -> None: log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() - - # determine which process we are and world size self.set_world_ranks() - - # set warning rank rank_zero_only.rank = self.global_rank - self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 81c6fccb1dc67..793c8155d01cb 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -133,6 +133,10 @@ def process_group_backend(self) -> Optional[str]: def _configure_launcher(self) -> None: self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method) + def setup_environment(self) -> None: + self.setup_distributed() + super().setup_environment() + def setup(self, trainer: "pl.Trainer") -> None: assert self.cluster_environment is not None @@ -160,16 +164,9 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) - def set_world_ranks(self, process_idx: int = 0) -> None: - self._local_rank = process_idx - if self.cluster_environment is None: - return - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - rank_zero_only.rank = self.cluster_environment.global_rank() - - def _worker_setup(self, process_idx: int) -> None: - self.set_world_ranks(process_idx) + def setup_distributed(self) -> None: + log.detail(f"{self.__class__.__name__}: setting up distributed...") + self.set_world_ranks() rank_zero_only.rank = self.global_rank self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None @@ -181,6 +178,13 @@ def _worker_setup(self, process_idx: int) -> None: timeout=self._timeout, ) + def set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() + def _get_process_group_backend(self) -> str: return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 6e91cc3d54a4f..a3a6a7998e546 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -45,7 +45,7 @@ from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT log = logging.getLogger(__name__) @@ -348,12 +348,9 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option def setup_distributed(self) -> None: reset_seed() - - # determine which process we are and world size self.set_world_ranks() - + rank_zero_only.rank = self.global_rank self._init_deepspeed_distributed() - if not self._config_initialized: self._format_config() self._config_initialized = True diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index de41b8ff2b455..3ae5e5e75b6be 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -132,7 +132,7 @@ def _wrapping_function( ) -> None: if global_states: global_states.restore() - self._strategy._worker_setup(process_idx) + self._strategy._local_rank = process_idx results = function(*args, **kwargs) if trainer is not None: diff --git a/src/pytorch_lightning/strategies/launchers/xla.py b/src/pytorch_lightning/strategies/launchers/xla.py index 80e48371277e5..5a2996636285c 100644 --- a/src/pytorch_lightning/strategies/launchers/xla.py +++ b/src/pytorch_lightning/strategies/launchers/xla.py @@ -103,7 +103,7 @@ def _wrapping_function( return_queue: SimpleQueue, global_states: Optional[_GlobalStateSnapshot] = None, ) -> None: - self._strategy._worker_setup(process_idx) + self._strategy._local_rank = process_idx results = function(*args, **kwargs) if trainer is not None: diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 7e6ccfd0d82d8..d220b719ccbfb 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -212,9 +212,9 @@ def reduce( return output - def _worker_setup(self, process_idx: int) -> None: + def setup_distributed(self) -> None: self._launched = True - self.set_world_ranks(process_idx) + self.set_world_ranks() rank_zero_only.rank = self.global_rank def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py index 7fb22206c45c6..7c1d347970b43 100644 --- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py @@ -154,10 +154,9 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): assert not temp_file.exists() -@RunIf(min_cuda_gpus=1) @mock.patch("torch.distributed.init_process_group") def test_ddp_spawn_strategy_set_timeout(mock_init_process_group): - """Tests with ddp strategy.""" + """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" test_timedelta = timedelta(seconds=30) model = BoringModel() ddp_spawn_strategy = DDPSpawnStrategy(timeout=test_timedelta) @@ -170,7 +169,6 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group): trainer.strategy.connect(model) trainer.lightning_module.trainer = trainer trainer.strategy.setup_environment() - trainer.strategy._worker_setup(0) process_group_backend = trainer.strategy._get_process_group_backend() global_rank = trainer.strategy.cluster_environment.global_rank() diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 2665eb7c3e370..cf16e7716ee43 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -236,10 +236,9 @@ def node_rank(self): assert ddp_strategy.launcher is None -@RunIf(min_cuda_gpus=1) @mock.patch("torch.distributed.init_process_group") def test_ddp_strategy_set_timeout(mock_init_process_group): - """Tests with ddp strategy.""" + """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" test_timedelta = timedelta(seconds=30) model = BoringModel() ddp_strategy = DDPStrategy(timeout=test_timedelta)