diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py index 8ced098e3a8dd..f9cf74e30e4c0 100644 --- a/src/lightning_lite/strategies/__init__.py +++ b/src/lightning_lite/strategies/__init__.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from lightning_lite.strategies.ddp import DDPStrategy # noqa: F401 +from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401 from lightning_lite.strategies.deepspeed import DeepSpeedStrategy # noqa: F401 from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401 +from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401 from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401 from lightning_lite.strategies.single_tpu import SingleTPUStrategy # noqa: F401 from lightning_lite.strategies.strategy import Strategy # noqa: F401 +from lightning_lite.strategies.xla import XLAStrategy # noqa: F401 STRATEGY_REGISTRY = _StrategyRegistry() _STRATEGIES_BASE_MODULE = "lightning_lite.strategies" diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py new file mode 100644 index 0000000000000..3e8b48b2a6b43 --- /dev/null +++ b/src/lightning_lite/strategies/ddp_spawn.py @@ -0,0 +1,214 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import timedelta +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed +from torch import Tensor +from torch.distributed.constants import default_pg_timeout +from torch.nn import Module +from torch.nn.parallel.distributed import DistributedDataParallel +from typing_extensions import Literal + +from lightning_lite.accelerators.accelerator import Accelerator +from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment +from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.plugins.precision import Precision +from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher +from lightning_lite.strategies.parallel import ParallelStrategy +from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.rank_zero import rank_zero_only + +_DDP_FORK_ALIASES = ( + "ddp_fork", + "ddp_fork_find_unused_parameters_false", + "ddp_notebook", + "ddp_notebook_find_unused_parameters_false", +) + + +class DDPSpawnStrategy(ParallelStrategy): + """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training + finishes.""" + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + start_method: Literal["spawn", "fork", "forkserver"] = "spawn", + **kwargs: Any, + ): + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + self._num_nodes = 1 + self._process_group_backend: Optional[str] = process_group_backend + self._timeout: Optional[timedelta] = timeout + self._start_method = start_method + self._ddp_kwargs = kwargs + self._local_rank = 0 + + @property + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def num_nodes(self) -> int: + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int) -> None: + # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks + self._num_nodes = num_nodes + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def distributed_sampler_kwargs(self) -> Dict[str, int]: + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + @property + def local_rank(self) -> int: + return self._local_rank + + def _configure_launcher(self) -> None: + self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method) + + def setup_environment(self) -> None: + self._setup_distributed() + super().setup_environment() + + def setup_module(self, module: Module) -> Module: + return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs) + + def module_to_device(self, module: Module) -> None: + if self.root_device.type == "cuda": + # TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device() + # set the device on the spawned subprocesses + torch.cuda.set_device(self.root_device) + module.to(self.root_device) + + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def barrier(self, *args: Any, **kwargs: Any) -> None: + if not distributed_available(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self._determine_ddp_device_ids()) + else: + torch.distributed.barrier() + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not distributed_available(): + return obj + obj = [obj] + if self.global_rank != src: + obj = [None] # type: ignore[list-item] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + entries = ( + ("ddp_spawn", "spawn"), + ("ddp_fork", "fork"), + ("ddp_notebook", "fork"), + ) + for name, start_method in entries: + strategy_registry.register( + name, + cls, + description=f"DDP strategy with `start_method` '{start_method}'", + start_method=start_method, + ) + + entries = ( + ("ddp_spawn_find_unused_parameters_false", "spawn"), + ("ddp_fork_find_unused_parameters_false", "fork"), + ("ddp_notebook_find_unused_parameters_false", "fork"), + ) + for name, start_method in entries: + strategy_registry.register( + name, + cls, + description=f"DDP strategy with `find_unused_parameters` as False and `start_method` '{start_method}'", + find_unused_parameters=False, + start_method=start_method, + ) + + def _setup_distributed(self) -> None: + self._set_world_ranks() + rank_zero_only.rank = self.global_rank + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + init_dist_connection( + self.cluster_environment, + self._process_group_backend, + self.global_rank, + self.world_size, + timeout=self._timeout, + ) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device) + + def _set_world_ranks(self, process_idx: int = 0) -> None: + self._local_rank = process_idx + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() + + def _determine_ddp_device_ids(self) -> Optional[List[int]]: + if self.root_device.type == "cpu": + return None + return [self.root_device.index] diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index b2c630a4dbd44..7c39f94e66969 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -23,6 +23,7 @@ from lightning_lite.accelerators import Accelerator from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision +from lightning_lite.strategies import DDPSpawnStrategy from lightning_lite.strategies.ddp import DDPStrategy from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.imports import _IS_WINDOWS @@ -76,7 +77,7 @@ def setup_module_and_optimizers( The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`. """ - optimizers = self._reinit_optimizers_with_oss(optimizers) + optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes) model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers @@ -107,16 +108,100 @@ def register_strategies(cls, strategy_registry: Dict) -> None: description=cls.__class__.__name__, ) + +class DDPSpawnShardedStrategy(DDPSpawnStrategy): + """Optimizer and gradient sharded training provided by FairScale with Spawn.""" + + _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + **kwargs: Any, + ) -> None: + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + process_group_backen=process_group_backend, + timeout=timeout, + **kwargs, + ) + super().__init__() + if "reduce_buffer_size" not in self._ddp_kwargs: + # For multi-node training, enabling bucketing will improve performance. + self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 + + def setup_module_and_optimizers( + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: + """Wraps the model and optimizers with fairscale components. + + Return: + The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module + and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`. + """ + optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes) + model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) + return model, optimizers + + @contextmanager + def block_backward_sync(self, module: Module) -> Generator: + """Blocks syncing gradients behaviour on backwards pass. + + This is useful for skipping sync when accumulating gradients, reducing communication overhead + Returns: context manager with sync behaviour off + """ + if isinstance(module, ShardedDataParallel): + with module.no_sync(): + yield None + else: + yield None + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + strategy_registry.register( + "ddp_sharded_spawn_find_unused_parameters_false", + cls, + description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False", + find_unused_parameters=False, + ) + strategy_registry.register( + "ddp_sharded_spawn", + cls, + description=cls.__class__.__name__, + ) + def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]: for x, optimizer in enumerate(optimizers): if not isinstance(optimizer, OSS): optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) - is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF) - # For multi-node training, compressing the model shards in fp16 before broadcasting - # improves performance. When using PyTorch AMP, it will not degrade - # the model performance. - zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1 optimizers[x] = zero_optimizer del optimizer return optimizers + + +def _reinit_optimizers_with_oss( + optimizers: List[Optimizer], precision_plugin: Precision, num_nodes: int +) -> List["OSS"]: + for x, optimizer in enumerate(optimizers): + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) + is_fp16 = precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF) + # For multi-node training, compressing the model shards in fp16 before broadcasting + # improves performance. When using PyTorch AMP, it will not degrade + # the model performance. + zero_optimizer.broadcast_fp16 = is_fp16 and num_nodes > 1 + optimizers[x] = zero_optimizer + del optimizer + return optimizers diff --git a/src/lightning_lite/strategies/launchers/multiprocessing.py b/src/lightning_lite/strategies/launchers/multiprocessing.py index ca47efe030302..d416efee56185 100644 --- a/src/lightning_lite/strategies/launchers/multiprocessing.py +++ b/src/lightning_lite/strategies/launchers/multiprocessing.py @@ -118,8 +118,7 @@ def _wrapping_function( ) -> None: if global_states: global_states.restore() - # TODO(lite): Update worker setup once DDPSpawn strategy is in Lite - self._strategy._worker_setup(process_idx) + self._strategy._local_rank = process_idx results = function(*args, **kwargs) if self._strategy.local_rank == 0: diff --git a/src/lightning_lite/strategies/launchers/xla.py b/src/lightning_lite/strategies/launchers/xla.py index 6580fd4a01d0e..60342b344097c 100644 --- a/src/lightning_lite/strategies/launchers/xla.py +++ b/src/lightning_lite/strategies/launchers/xla.py @@ -86,8 +86,7 @@ def _wrapping_function( return_queue: SimpleQueue, global_states: Optional[_GlobalStateSnapshot] = None, ) -> None: - # TODO(lite): Update worker setup once TPUSpawn strategy is in Lite - self._strategy._worker_setup(process_idx) + self._strategy._local_rank = process_idx results = function(*args, **kwargs) if self._strategy.local_rank == 0: diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py new file mode 100644 index 0000000000000..d11e05099b850 --- /dev/null +++ b/src/lightning_lite/strategies/xla.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader + +from lightning_lite.accelerators import Accelerator +from lightning_lite.plugins.environments import XLAEnvironment +from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO +from lightning_lite.plugins.precision import Precision +from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy +from lightning_lite.strategies.launchers.xla import _XLALauncher +from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.utilities import _TPU_AVAILABLE +from lightning_lite.utilities.apply_func import apply_to_collection +from lightning_lite.utilities.data import has_len +from lightning_lite.utilities.distributed import ReduceOp +from lightning_lite.utilities.rank_zero import rank_zero_only +from lightning_lite.utilities.types import _PATH + +if _TPU_AVAILABLE: + import torch_xla.core.xla_env_vars as xenv + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous + from torch_xla.distributed.parallel_loader import MpDeviceLoader +else: + xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 + + +class XLAStrategy(DDPSpawnStrategy): + """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` + method.""" + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + **_: Any, + ) -> None: + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=XLAEnvironment(), + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + start_method="fork", + ) + self._checkpoint_io: Optional[CheckpointIO] + self._launched = False + + @property + def root_device(self) -> torch.device: + if not self._launched: + raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + return xm.xla_device() + + @property + def checkpoint_io(self) -> CheckpointIO: + if self._checkpoint_io is None: + self._checkpoint_io = XLACheckpointIO() + return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + self._checkpoint_io = io + + @property + def distributed_sampler_kwargs(self) -> Dict[str, int]: + return dict(num_replicas=self.world_size, rank=self.global_rank) + + @property + def is_distributed(self) -> bool: + # HOST_WORLD_SIZE is not set outside the xmp.spawn process + return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 + + def _configure_launcher(self) -> None: + self._launcher = _XLALauncher(self) + + def setup_environment(self) -> None: + self._launched = True + self._set_world_ranks() + rank_zero_only.rank = self.global_rank + + def setup_module(self, module: Module) -> Module: + return module + + def module_to_device(self, module: Module) -> None: + module.to(self.root_device) + + def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: + XLAStrategy._validate_dataloader(dataloader) + dataloader = MpDeviceLoader(dataloader, self.root_device) + # Mimic interface to torch.utils.data.DataLoader + dataloader.dataset = dataloader._loader.dataset + return dataloader + + def reduce( + self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None + ) -> Tensor: + if not isinstance(output, Tensor): + output = torch.tensor(output, device=self.root_device) + + invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if invalid_reduce_op or invalid_reduce_op_str: + raise ValueError( + "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" + f" {reduce_op}" + ) + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output + + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + if isinstance(tensor, Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return xm.all_gather(tensor) + + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin + """ + # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0 + self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) + + def remove_checkpoint(self, filepath: _PATH) -> None: + """Remove checkpoint filepath from the filesystem. + + Args: + filepath: Path to checkpoint + """ + if self.local_rank == 0: + self.checkpoint_io.remove_checkpoint(filepath) + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + # TODO(lite): Deprecate the name "tpu_spawn" through the connector + strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__) + strategy_registry.register("xla", cls, description=cls.__class__.__name__) + + @staticmethod + def _validate_dataloader(dataloaders: DataLoader) -> None: + def check_has_len(dataloader: DataLoader) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." + " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." + ) + + apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index 166b28a5c948f..26fa3e1e230d0 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -3,12 +3,12 @@ from typing import Any, List, Optional, Tuple, Union import torch +from lightning_utilities.core.rank_zero import rank_zero_deprecation from torch import Tensor from torch.nn import functional as F from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE -from lightning_lite.utilities.rank_zero import rank_zero_deprecation from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info if _TPU_AVAILABLE: diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index 627837b4524b7..93c0071d9cd47 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -53,6 +53,16 @@ def test_available_strategies_in_registry(): "deepspeed_stage_3", "deepspeed_stage_3_offload", "deepspeed_stage_3_offload_nvme", - "dp", + "ddp_sharded_spawn_find_unused_parameters_false", + "ddp_sharded_spawn", + "ddp_spawn", + "ddp_fork", + "ddp_notebook", + "ddp_spawn_find_unused_parameters_false", + "ddp_fork_find_unused_parameters_false", + "ddp_notebook_find_unused_parameters_false", "single_tpu", + "tpu_spawn", + "xla", + "dp", }