diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py
index 51bac19afbc24..ecd751e4d26d5 100644
--- a/src/lightning_lite/strategies/xla.py
+++ b/src/lightning_lite/strategies/xla.py
@@ -156,20 +156,22 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
 
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index fd6b0c9693003..05545e7e00f27 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -289,20 +289,22 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
         self.checkpoint_io.remove_checkpoint(filepath)
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
 
     def teardown(self) -> None:
         super().teardown()
diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py
index 11d41cb72cd16..14d3f2f713422 100644
--- a/tests/tests_lite/strategies/test_xla.py
+++ b/tests/tests_lite/strategies/test_xla.py
@@ -17,6 +17,7 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
 from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
 from tests_lite.helpers.runif import RunIf
@@ -113,3 +114,24 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatc
 
     with pytest.raises(TypeError, match="TPUs do not currently support"):
         XLAStrategy().process_dataloader(dataloader)
+
+
+def tpu_all_gather_fn(strategy):
+    for sync_grads in [True, False]:
+        tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
+        result = strategy.all_gather(tensor, sync_grads=sync_grads)
+        summed = result.sum()
+        assert torch.equal(summed, torch.tensor(8.0))
+        summed.backward()
+        if sync_grads:
+            assert torch.equal(tensor.grad, torch.tensor(1.0))
+        else:
+            # As gradients are not synced, the original tensor will not have gradients.
+            assert tensor.grad is None
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_tpu_all_gather():
+    """Test the all_gather operation on TPU."""
+    xla_launch(tpu_all_gather_fn)