Lite: Flatten XLAStrategy (#15838)
awaelchli authored Dec 9, 2022
1 parent 5595166 commit 3c9b7cb
Showing 3 changed files with 27 additions and 8 deletions.
src/lightning_lite/CHANGELOG.md (2 additions & 1 deletion)
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))


-
+ - The `XLAStrategy` now inherits from `ParallelStrategy` instead of `DDPSpawnStrategy` ([#15838](https://github.com/Lightning-AI/lightning/issues/15838))



### Deprecated
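For orientation (not part of the diff), here is a minimal usage sketch of the flattened strategy from the user's side. It assumes the `LightningLite` constructor accepts `accelerator`, `devices`, and `strategy` as in this release, and that the `"xla"` alias registered by `XLAStrategy.register_strategies` (see the diff below) resolves to this class; treat the exact argument names as assumptions rather than documented API.

```python
# Hypothetical usage sketch (not part of this commit); assumes a TPU host
# with torch_xla installed and the public LightningLite API of this release.
from lightning_lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        # Each spawned process sees its own XLA device and global rank.
        print(f"global rank {self.global_rank} runs on {self.device}")


if __name__ == "__main__":
    # "xla" (and the legacy "tpu_spawn" alias) is the name registered by
    # XLAStrategy.register_strategies.
    Lite(accelerator="tpu", devices=8, strategy="xla").run()
```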
src/lightning_lite/strategies/launchers/xla.py (5 additions & 3 deletions)
@@ -18,14 +18,15 @@
from torch.multiprocessing import get_context

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
- from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+ from lightning_lite.strategies.launchers.base import _Launcher
+ from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot
from lightning_lite.utilities.apply_func import move_data_to_device

if TYPE_CHECKING:
from lightning_lite.strategies import XLAStrategy


- class _XLALauncher(_MultiProcessingLauncher):
+ class _XLALauncher(_Launcher):
r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
end.
@@ -44,7 +45,8 @@ class _XLALauncher(_MultiProcessingLauncher):
def __init__(self, strategy: "XLAStrategy") -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- super().__init__(strategy=strategy, start_method="fork")
+ self._strategy = strategy
+ self._start_method = "fork"

@property
def is_interactive_compatible(self) -> bool:
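The `launch` method itself is outside this hunk, so it is not shown above. Purely as an illustration of the mechanism the launcher is built on, the sketch below fans a function out across TPU cores with `torch_xla.distributed.xla_multiprocessing.spawn` using the same `"fork"` start method the constructor now stores in `_start_method`; the helper name and argument handling here are assumptions, not code from this commit.

```python
# Illustrative only: spawning a function over XLA devices with torch_xla's
# multiprocessing spawn and the "fork" start method.
from typing import Any, Callable

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int, message: str) -> None:
    # Each spawned process receives an ordinal index and its own XLA device.
    print(f"process {index}: {message} on {xm.xla_device()}")


def launch_on_xla(function: Callable[..., Any], *args: Any, nprocs: int = 8) -> None:
    # start_method="fork" mirrors the _start_method set in the constructor above.
    xmp.spawn(function, args=args, nprocs=nprocs, start_method="fork")


if __name__ == "__main__":
    launch_on_xla(_worker, "hello from the XLA launcher sketch")
```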
src/lightning_lite/strategies/xla.py (20 additions & 4 deletions)
@@ -26,7 +26,7 @@
from lightning_lite.plugins.io.checkpoint_io import CheckpointIO
from lightning_lite.plugins.io.xla import XLACheckpointIO
from lightning_lite.plugins.precision import Precision
- from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
+ from lightning_lite.strategies import ParallelStrategy
from lightning_lite.strategies.launchers.xla import _XLALauncher
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities.apply_func import apply_to_collection
@@ -38,7 +38,7 @@
from torch_xla.distributed.parallel_loader import MpDeviceLoader


- class XLAStrategy(DDPSpawnStrategy):
+ class XLAStrategy(ParallelStrategy):
"""Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
method."""

@@ -55,11 +55,11 @@ def __init__(
cluster_environment=XLAEnvironment(),
checkpoint_io=checkpoint_io,
precision=precision,
- start_method="fork",
)
self._checkpoint_io: Optional[CheckpointIO]
self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call
self._launched = False
+ self._local_rank = 0

@property
def root_device(self) -> torch.device:
@@ -69,6 +69,14 @@ def root_device(self) -> torch.device:

return xm.xla_device()

+ @property
+ def num_processes(self) -> int:
+     return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
+ @property
+ def local_rank(self) -> int:
+     return self._local_rank

@property
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
@@ -93,10 +101,11 @@ def is_distributed(self) -> bool:
def _configure_launcher(self) -> None:
self._launcher = _XLALauncher(self)

- def _setup_distributed(self) -> None:
+ def setup_environment(self) -> None:
self._launched = True
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
+ super().setup_environment()

def setup_module(self, module: Module) -> Module:
return module
@@ -201,6 +210,13 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__)
strategy_registry.register("xla", cls, description=cls.__class__.__name__)

+ def _set_world_ranks(self) -> None:
+     if self.cluster_environment is None:
+         return
+     self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+     self.cluster_environment.set_world_size(self.num_processes)
+     rank_zero_only.rank = self.cluster_environment.global_rank()

@staticmethod
def _validate_dataloader(dataloaders: DataLoader) -> None:
def check_has_len(dataloader: DataLoader) -> None:
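The added `_set_world_ranks` makes the rank bookkeeping explicit: the global rank is `node_rank * num_processes + local_rank`, and the world size is taken from the number of local XLA devices. A small worked sketch of that arithmetic (plain Python, no TPU needed; the single-node, 8-core numbers are only an example):

```python
# Worked example of the rank arithmetic in _set_world_ranks above.
num_processes = 8  # e.g. len(parallel_devices) on a single TPU v3-8 host
node_rank = 0      # single-host example

for local_rank in range(num_processes):
    global_rank = node_rank * num_processes + local_rank
    world_size = num_processes  # what set_world_size() receives in this commit
    print(f"local rank {local_rank} -> global rank {global_rank} / world size {world_size}")
```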
