diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md
index 74db65f889ce1..1ead9c14ecde0 100644
--- a/src/lightning_lite/CHANGELOG.md
+++ b/src/lightning_lite/CHANGELOG.md
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
 
--
+- The `XLAStrategy` now inherits from `ParallelStrategy` instead of `DDPSpawnStrategy` ([#15838](https://github.com/Lightning-AI/lightning/issues/15838))
+
 
 ### Deprecated
 
diff --git a/src/lightning_lite/strategies/launchers/xla.py b/src/lightning_lite/strategies/launchers/xla.py
index bcb770d942791..0f47b235160ce 100644
--- a/src/lightning_lite/strategies/launchers/xla.py
+++ b/src/lightning_lite/strategies/launchers/xla.py
@@ -18,14 +18,15 @@
 from torch.multiprocessing import get_context
 
 from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
-from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+from lightning_lite.strategies.launchers.base import _Launcher
+from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot
 from lightning_lite.utilities.apply_func import move_data_to_device
 
 if TYPE_CHECKING:
     from lightning_lite.strategies import XLAStrategy
 
 
-class _XLALauncher(_MultiProcessingLauncher):
+class _XLALauncher(_Launcher):
     r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
     end.
 
@@ -44,7 +45,8 @@ class _XLALauncher(_MultiProcessingLauncher):
     def __init__(self, strategy: "XLAStrategy") -> None:
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        super().__init__(strategy=strategy, start_method="fork")
+        self._strategy = strategy
+        self._start_method = "fork"
 
     @property
     def is_interactive_compatible(self) -> bool:
diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py
index ecd751e4d26d5..6d67fe2227ed6 100644
--- a/src/lightning_lite/strategies/xla.py
+++ b/src/lightning_lite/strategies/xla.py
@@ -26,7 +26,7 @@
 from lightning_lite.plugins.io.checkpoint_io import CheckpointIO
 from lightning_lite.plugins.io.xla import XLACheckpointIO
 from lightning_lite.plugins.precision import Precision
-from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
+from lightning_lite.strategies import ParallelStrategy
 from lightning_lite.strategies.launchers.xla import _XLALauncher
 from lightning_lite.strategies.strategy import TBroadcast
 from lightning_lite.utilities.apply_func import apply_to_collection
@@ -38,7 +38,7 @@
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
 
 
-class XLAStrategy(DDPSpawnStrategy):
+class XLAStrategy(ParallelStrategy):
     """Strategy for training multiple TPU devices using the
     :func:`torch_xla.distributed.xla_multiprocessing.spawn` method."""
 
@@ -55,11 +55,11 @@ def __init__(
             cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision=precision,
-            start_method="fork",
         )
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
+        self._local_rank = 0
 
     @property
     def root_device(self) -> torch.device:
@@ -69,6 +69,14 @@ def root_device(self) -> torch.device:
 
         return xm.xla_device()
 
+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
+    @property
+    def local_rank(self) -> int:
+        return self._local_rank
+
     @property
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
@@ -93,10 +101,11 @@ def is_distributed(self) -> bool:
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
-    def _setup_distributed(self) -> None:
+    def setup_environment(self) -> None:
         self._launched = True
         self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
+        super().setup_environment()
 
     def setup_module(self, module: Module) -> Module:
         return module
@@ -201,6 +210,13 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__)
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)
 
+    def _set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     @staticmethod
     def _validate_dataloader(dataloaders: DataLoader) -> None:
        def check_has_len(dataloader: DataLoader) -> None:
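
For context, here is a minimal, self-contained sketch of the rank bookkeeping that the new `XLAStrategy._set_world_ranks` performs. `_StubClusterEnvironment` and the example values are hypothetical stand-ins for `XLAEnvironment` and the detected TPU devices; only the arithmetic mirrors the diff above.

```python
# Hypothetical stand-in for the cluster environment; the real strategy uses
# `XLAEnvironment`. Only the rank arithmetic below mirrors the diff.
class _StubClusterEnvironment:
    def __init__(self) -> None:
        self._global_rank = 0
        self._world_size = 1

    def set_global_rank(self, rank: int) -> None:
        self._global_rank = rank

    def set_world_size(self, size: int) -> None:
        self._world_size = size

    def global_rank(self) -> int:
        return self._global_rank


def set_world_ranks(env: _StubClusterEnvironment, node_rank: int, num_processes: int, local_rank: int) -> int:
    # Global rank = offset of this node's first process + position within the node,
    # matching `_set_world_ranks` in the diff above.
    env.set_global_rank(node_rank * num_processes + local_rank)
    env.set_world_size(num_processes)
    return env.global_rank()


# e.g. the second process (local_rank=1) on a single 8-core TPU host
env = _StubClusterEnvironment()
assert set_world_ranks(env, node_rank=0, num_processes=8, local_rank=1) == 1
```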