
Make gradients available for all_gather on TPU #15003

Merged · 19 commits · Dec 8, 2022 · Changes from 7 commits
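In user terms, this PR makes `sync_grads=True` actually keep the autograd graph when `all_gather` runs on TPU, where the flag was previously ignored. A minimal sketch of a `LightningModule` that relies on this behavior (the `encoder` attribute and `contrastive_loss` function are hypothetical, not part of this PR):

```python
import torch
from pytorch_lightning import LightningModule


class GatherModel(LightningModule):
    def training_step(self, batch, batch_idx):
        embeddings = self.encoder(batch)  # (batch, dim); hypothetical encoder
        # Previously on TPU the gathered tensor was detached from the graph;
        # with this PR, sync_grads=True routes through xf.all_gather so that
        # gradients flow back to the local `embeddings`.
        gathered = self.all_gather(embeddings, sync_grads=True)  # (world_size, batch, dim)
        return contrastive_loss(gathered)  # hypothetical loss over all devices' embeddings
```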
10 changes: 6 additions & 4 deletions src/lightning_lite/strategies/xla.py
@@ -156,20 +156,22 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)

+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
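The dispatch works because `torch_xla.core.functions.all_gather` wraps the collective in a `torch.autograd.Function`, whereas `xm.all_gather` returns a tensor with no graph attached. A simplified, generic sketch of that idea written against plain `torch.distributed` (illustrative only; the real torch_xla implementation differs in details such as the gather dimension and the XLA collectives it calls):

```python
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """Sketch of a gradient-aware all_gather (illustrative, not the torch_xla code)."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.rank = dist.get_rank()
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)  # the collective itself does not track gradients
        return torch.stack(gathered, dim=0)  # (world_size, batch, ...)

    @staticmethod
    def backward(ctx, grad_output):
        # Hand back only the slice of the incoming gradient that corresponds
        # to this process's own input. (Lightning's GPU-side AllGatherGrad
        # additionally all-reduces grad_output first, summing cross-rank terms.)
        return grad_output[ctx.rank]
```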
10 changes: 6 additions & 4 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -289,20 +289,22 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
         self.checkpoint_io.remove_checkpoint(filepath)

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)

+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def teardown(self) -> None:
         super().teardown()
19 changes: 19 additions & 0 deletions tests/tests_lite/strategies/test_xla.py
@@ -17,6 +17,7 @@
 from unittest.mock import Mock

 import pytest
+import torch
 from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
 from tests_lite.helpers.runif import RunIf
@@ -113,3 +114,21 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch):

     with pytest.raises(TypeError, match="TPUs do not currently support"):
         XLAStrategy().process_dataloader(dataloader)
+
+
+def tpu_all_gather_fn(strategy):
+    tensor = torch.tensor(1, device=strategy.root_device)
+    result = strategy.all_gather(tensor, sync_grads=False)
+    assert result.sum() == 8
+
+    tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
+    result = strategy.all_gather(tensor, sync_grads=True)
+    result.sum().backward()
+    assert torch.equal(tensor.grad, torch.tensor(1.0))
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_tpu_all_gather():
+    """Test the all_gather operation on TPU."""
+    xla_launch(tpu_all_gather_fn)
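The expected values follow from an 8-core TPU: the gather stacks eight copies of the scalar, so the sum is 8, while the backward pass of the differentiable path hands back only the local slice of the gradient, hence `tensor.grad == 1.0`. A single-process emulation of that slice-only backward (assuming the semantics sketched above; `t_local` stands in for the current core's tensor):

```python
import torch

# Emulate a gather across 8 cores: one differentiable local slice plus
# seven constants standing in for the other cores' values.
t_local = torch.tensor(1.0, requires_grad=True)
others = [torch.tensor(1.0) for _ in range(7)]
gathered = torch.stack([t_local, *others])  # shape (8,)

assert gathered.sum() == 8
gathered.sum().backward()
assert torch.equal(t_local.grad, torch.tensor(1.0))  # only the local slice contributes
```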
36 changes: 36 additions & 0 deletions tests/tests_pytorch/accelerators/test_tpu.py
@@ -323,3 +323,39 @@ def test_trainer_config_device_ids(devices, expected_device_ids):
     trainer = Trainer(accelerator="tpu", devices=devices)
     assert trainer.device_ids == expected_device_ids
     assert trainer.num_devices == len(expected_device_ids)
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_all_gather(tmpdir):
+    nb_devices = 8
+    values_per_dim = 2
+    expected_tensor_dims = [nb_devices, values_per_dim, values_per_dim]
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(values_per_dim, values_per_dim, requires_grad=True, device=self.device)
+            tensor_with_grads = self.all_gather(tensor, sync_grads=True)
+            tensor_wo_grads = self.all_gather(tensor, sync_grads=False)
+            assert tensor_with_grads.shape == torch.Size(expected_tensor_dims)
+            assert tensor_wo_grads.shape == torch.Size(expected_tensor_dims)
+
+            loss = tensor_with_grads.sum() + tensor_wo_grads.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="tpu",
+        devices=nb_devices,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+    assert model.training_step_called
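In the combined loss above, only the `sync_grads=True` term carries gradient; the `sync_grads=False` term shifts the loss value but contributes nothing to the backward pass. A single-process emulation of that split (assuming, as above, that `xm.all_gather` returns a detached tensor while the differentiable path keeps the local slice attached):

```python
import torch

tensor = torch.rand(2, 2, requires_grad=True)
# Stand-ins: `with_grads` keeps the local slice in the graph,
# `wo_grads` is fully detached, like xm.all_gather's output.
with_grads = torch.stack([tensor] + [torch.rand(2, 2) for _ in range(7)])
wo_grads = with_grads.detach()

loss = with_grads.sum() + wo_grads.sum()
loss.backward()
# Only the sync_grads=True term contributes: d(loss)/d(tensor) is all ones.
assert torch.equal(tensor.grad, torch.ones(2, 2))
```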