From e693ace8bebc1be0af3de0b0785fb4b03aac30c8 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 17 Sep 2024 16:20:01 -0400 Subject: [PATCH] add unittest for dpp gather autograd compatibility --- tests/unittests/bases/test_ddp.py | 56 +++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index c057d0cbdf8..60a9b96948c 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -77,6 +77,60 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO assert (val == torch.ones_like(val)).all() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test only considers tensors of the same shape across different ranks. + + Note that this test only works for torch>=2.0. + """ + tensor = torch.ones(50, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test considers tensors of different shapes across different ranks. + + Note that this test only works for torch>=2.0. + """ + tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None: dummy = DummyMetricSum() dummy._reductions = {"x": torch.sum} @@ -97,6 +151,8 @@ def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> _test_ddp_sum_cat, _test_ddp_gather_uneven_tensors, _test_ddp_gather_uneven_tensors_multidim, + _test_ddp_gather_autograd_same_shape, + _test_ddp_gather_autograd_different_shape, _test_ddp_compositional_tensor, ], )