From cb0b12e6bd98b67ec472fdaeb8e8a9048a364c51 Mon Sep 17 00:00:00 2001 From: Steffen Kirres Date: Wed, 5 Oct 2022 16:13:42 +0200 Subject: [PATCH 1/9] Make gradients available for all_gather on TPU --- src/lightning_lite/strategies/xla.py | 16 +++-- src/pytorch_lightning/strategies/tpu_spawn.py | 16 +++-- tests/tests_pytorch/accelerators/test_tpu.py | 60 +++++++++++++++++++ 3 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py index 51bac19afbc24..e8895d09bcaf5 100644 --- a/src/lightning_lite/strategies/xla.py +++ b/src/lightning_lite/strategies/xla.py @@ -156,20 +156,26 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """ - Function to gather a tensor from several distributed processes + """Function to gather a tensor from several distributed processes. + Args: tensor: tensor of shape (batch, ...) group: not available with TPUs - sync_grads: not available with TPUs + sync_grads: flag that allows users to synchronize gradients for the all_gather operation Return: A tensor of shape (world_size, batch, ...) """ if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - import torch_xla.core.xla_model as xm - return xm.all_gather(tensor) + if sync_grads: + import torch_xla.core.functions as xf + + return xf.all_gather(tensor) + else: + import torch_xla.core.xla_model as xm + + return xm.all_gather(tensor) def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 88757f804f09e..aa14eca998376 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -289,20 +289,26 @@ def remove_checkpoint(self, filepath: _PATH) -> None: self.checkpoint_io.remove_checkpoint(filepath) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """ - Function to gather a tensor from several distributed processes + """Function to gather a tensor from several distributed processes. + Args: tensor: tensor of shape (batch, ...) group: not available with TPUs - sync_grads: not available with TPUs + sync_grads: flag that allows users to synchronize gradients for the all_gather operation Return: A tensor of shape (world_size, batch, ...) 
""" if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - import torch_xla.core.xla_model as xm - return xm.all_gather(tensor) + if sync_grads: + import torch_xla.core.functions as xf + + return xf.all_gather(tensor) + else: + import torch_xla.core.xla_model as xm + + return xm.all_gather(tensor) def teardown(self) -> None: super().teardown() diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 85ce3cac3a31c..eec5a0f5268e0 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -323,3 +323,63 @@ def test_trainer_config_device_ids(devices, expected_device_ids): trainer = Trainer(accelerator="tpu", devices=devices) assert trainer.device_ids == expected_device_ids assert trainer.num_devices == len(expected_device_ids) + + +@RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) +def test_all_gather_sync_grads(tmpdir): + class TestModel(BoringModel): + + training_step_called = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + tensor = torch.rand(2, 2, requires_grad=True, device=self.device) + gathered_tensor = self.all_gather(tensor, sync_grads=True) + assert gathered_tensor.shape == torch.Size([8, 2, 2]) + + loss = gathered_tensor.sum() + + return loss + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="tpu", + devices=8, + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model) + assert model.training_step_called + + +@RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) +def test_all_gather_no_grads(tmpdir): + class TestModel(BoringModel): + + training_step_called = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + tensor = torch.rand(2, 2, requires_grad=True, device=self.device) + gathered_tensor = self.all_gather(tensor, sync_grads=False) + assert gathered_tensor.shape == torch.Size([8, 2, 2]) + + loss = gathered_tensor.sum() + + return loss + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="tpu", + devices=8, + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model) + assert model.training_step_called From 8fd5eb83f1025d8bc41e51a31e65b1644f245b57 Mon Sep 17 00:00:00 2001 From: Steffen Kirres Date: Wed, 19 Oct 2022 15:47:01 +0200 Subject: [PATCH 2/9] Modify switch and tests --- src/lightning_lite/strategies/xla.py | 11 ++--- src/pytorch_lightning/strategies/tpu_spawn.py | 11 ++--- tests/tests_lite/strategies/test_xla.py | 15 ++++++ tests/tests_pytorch/accelerators/test_tpu.py | 46 +++++-------------- 4 files changed, 32 insertions(+), 51 deletions(-) diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py index e8895d09bcaf5..f55641c5e7c59 100644 --- a/src/lightning_lite/strategies/xla.py +++ b/src/lightning_lite/strategies/xla.py @@ -168,14 +168,9 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - if sync_grads: - import torch_xla.core.functions as xf - - return xf.all_gather(tensor) - else: - import torch_xla.core.xla_model as xm - - return xm.all_gather(tensor) + import torch_xla.core.functions as xf + import torch_xla.core.xla_model as xm + return xf.all_gather(tensor) if 
sync_grads else xm.all_gather(tensor) def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index aa14eca998376..bf88b5d9719ea 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -301,14 +301,9 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - if sync_grads: - import torch_xla.core.functions as xf - - return xf.all_gather(tensor) - else: - import torch_xla.core.xla_model as xm - - return xm.all_gather(tensor) + import torch_xla.core.functions as xf + import torch_xla.core.xla_model as xm + return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) def teardown(self) -> None: super().teardown() diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 11d41cb72cd16..524c35bee816a 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -17,6 +17,7 @@ from unittest.mock import Mock import pytest +import torch from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_lite.helpers.models import RandomDataset, RandomIterableDataset from tests_lite.helpers.runif import RunIf @@ -113,3 +114,17 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatc with pytest.raises(TypeError, match="TPUs do not currently support"): XLAStrategy().process_dataloader(dataloader) + + +def tpu_all_gather_fn(strategy): + for sync_grads in [True, False]: + tensor = torch.tensor(1, device=strategy.root_device) + result = strategy.all_gather(tensor, sync_grads=sync_grads) + assert result.sum() == 8 + + +@RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) +def test_tpu_all_gather(): + """Test the all_gather operation on TPU.""" + xla_launch(tpu_all_gather_fn) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index eec5a0f5268e0..32cc62fea6f65 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -327,48 +327,24 @@ def test_trainer_config_device_ids(devices, expected_device_ids): @RunIf(tpu=True, standalone=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) -def test_all_gather_sync_grads(tmpdir): - class TestModel(BoringModel): - - training_step_called = False - - def training_step(self, batch, batch_idx): - self.training_step_called = True - tensor = torch.rand(2, 2, requires_grad=True, device=self.device) - gathered_tensor = self.all_gather(tensor, sync_grads=True) - assert gathered_tensor.shape == torch.Size([8, 2, 2]) - - loss = gathered_tensor.sum() +def test_all_gather(tmpdir): + nb_devices = 8 + values_per_dim = 2 + expected_tensor_dims = [nb_devices, values_per_dim, values_per_dim] - return loss - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - accelerator="tpu", - devices=8, - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) - assert model.training_step_called - - -@RunIf(tpu=True, standalone=True) -@mock.patch.dict(os.environ, os.environ.copy(), clear=True) -def test_all_gather_no_grads(tmpdir): class TestModel(BoringModel): training_step_called = False def training_step(self, batch, 
batch_idx): self.training_step_called = True - tensor = torch.rand(2, 2, requires_grad=True, device=self.device) - gathered_tensor = self.all_gather(tensor, sync_grads=False) - assert gathered_tensor.shape == torch.Size([8, 2, 2]) + tensor = torch.rand(values_per_dim, values_per_dim, requires_grad=True, device=self.device) + tensor_with_grads = self.all_gather(tensor, sync_grads=True) + tensor_wo_grads = self.all_gather(tensor, sync_grads=False) + assert tensor_with_grads.shape == torch.Size(expected_tensor_dims) + assert tensor_wo_grads.shape == torch.Size(expected_tensor_dims) - loss = gathered_tensor.sum() + loss = tensor_with_grads.sum() + tensor_wo_grads.sum() return loss @@ -377,7 +353,7 @@ def training_step(self, batch, batch_idx): default_root_dir=tmpdir, fast_dev_run=True, accelerator="tpu", - devices=8, + devices=nb_devices, enable_progress_bar=False, enable_model_summary=False, ) From 8cf0877e5de4d50fd7a005381903154cb90e78aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 13:50:03 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/strategies/xla.py | 1 + src/pytorch_lightning/strategies/tpu_spawn.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py index f55641c5e7c59..ecd751e4d26d5 100644 --- a/src/lightning_lite/strategies/xla.py +++ b/src/lightning_lite/strategies/xla.py @@ -170,6 +170,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo import torch_xla.core.functions as xf import torch_xla.core.xla_model as xm + return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) def save_checkpoint( diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index bf88b5d9719ea..61d397c0223d8 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -303,6 +303,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo import torch_xla.core.functions as xf import torch_xla.core.xla_model as xm + return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) def teardown(self) -> None: From 2c339cf41b7457216fc642106909658d2e9ad2be Mon Sep 17 00:00:00 2001 From: stekiri Date: Mon, 31 Oct 2022 09:17:20 +0100 Subject: [PATCH 4/9] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- tests/tests_lite/strategies/test_xla.py | 12 ++++++++---- tests/tests_pytorch/accelerators/test_tpu.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 524c35bee816a..84ce76879c296 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -117,10 +117,14 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatc def tpu_all_gather_fn(strategy): - for sync_grads in [True, False]: - tensor = torch.tensor(1, device=strategy.root_device) - result = strategy.all_gather(tensor, sync_grads=sync_grads) - assert result.sum() == 8 + tensor = torch.tensor(1, device=strategy.root_device) + result = strategy.all_gather(tensor, sync_grads=False) + assert result.sum() == 8 + + tensor = torch.tensor(1.0, 
device=strategy.root_device , requires_grad=True) + result = strategy.all_gather(tensor, sync_grads=False) + result.sum().backward() + assert torch.equal(tensor.grad, torch.tensor(1.0)) @RunIf(tpu=True) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 32cc62fea6f65..75fe6f01267e8 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -325,7 +325,7 @@ def test_trainer_config_device_ids(devices, expected_device_ids): assert trainer.num_devices == len(expected_device_ids) -@RunIf(tpu=True, standalone=True) +@RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_all_gather(tmpdir): nb_devices = 8 From b623b0d5abb3d073a1396bea5c9f57322eb3d85c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Oct 2022 08:18:47 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/strategies/test_xla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 84ce76879c296..93b91bfc99c42 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -120,8 +120,8 @@ def tpu_all_gather_fn(strategy): tensor = torch.tensor(1, device=strategy.root_device) result = strategy.all_gather(tensor, sync_grads=False) assert result.sum() == 8 - - tensor = torch.tensor(1.0, device=strategy.root_device , requires_grad=True) + + tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True) result = strategy.all_gather(tensor, sync_grads=False) result.sum().backward() assert torch.equal(tensor.grad, torch.tensor(1.0)) From 371dd9a294e7e8627725a73de06e1b2fb80b3747 Mon Sep 17 00:00:00 2001 From: Steffen Kirres Date: Tue, 1 Nov 2022 16:23:28 +0100 Subject: [PATCH 6/9] Modify tests --- tests/tests_lite/strategies/test_xla.py | 19 ++++++++------ tests/tests_pytorch/accelerators/test_tpu.py | 26 ++++++++++---------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 93b91bfc99c42..14d3f2f713422 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -117,14 +117,17 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatc def tpu_all_gather_fn(strategy): - tensor = torch.tensor(1, device=strategy.root_device) - result = strategy.all_gather(tensor, sync_grads=False) - assert result.sum() == 8 - - tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True) - result = strategy.all_gather(tensor, sync_grads=False) - result.sum().backward() - assert torch.equal(tensor.grad, torch.tensor(1.0)) + for sync_grads in [True, False]: + tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True) + result = strategy.all_gather(tensor, sync_grads=sync_grads) + summed = result.sum() + assert torch.equal(summed, torch.tensor(8.0)) + summed.backward() + if sync_grads: + assert torch.equal(tensor.grad, torch.tensor(1.0)) + else: + # As gradients are not synced, the original tensor will not have gradients. 
+ assert tensor.grad is None @RunIf(tpu=True) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 75fe6f01267e8..96bd55af093c4 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -328,32 +328,32 @@ def test_trainer_config_device_ids(devices, expected_device_ids): @RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_all_gather(tmpdir): - nb_devices = 8 - values_per_dim = 2 - expected_tensor_dims = [nb_devices, values_per_dim, values_per_dim] - class TestModel(BoringModel): - training_step_called = False def training_step(self, batch, batch_idx): self.training_step_called = True - tensor = torch.rand(values_per_dim, values_per_dim, requires_grad=True, device=self.device) - tensor_with_grads = self.all_gather(tensor, sync_grads=True) - tensor_wo_grads = self.all_gather(tensor, sync_grads=False) - assert tensor_with_grads.shape == torch.Size(expected_tensor_dims) - assert tensor_wo_grads.shape == torch.Size(expected_tensor_dims) - loss = tensor_with_grads.sum() + tensor_wo_grads.sum() + for sync_grads in [True, False]: + tensor = torch.tensor(1.0, device=self.device, requires_grad=True) + result = self.all_gather(tensor, sync_grads=sync_grads) + summed = result.sum() + assert torch.equal(summed, torch.tensor(8.0)) + summed.backward() + if sync_grads: + assert torch.equal(tensor.grad, torch.tensor(1.0)) + else: + # As gradients are not synced, the original tensor will not have gradients. + assert tensor.grad is None - return loss + return torch.rand(1, requires_grad=True, device=self.device).sum() model = TestModel() trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=True, accelerator="tpu", - devices=nb_devices, + devices=8, enable_progress_bar=False, enable_model_summary=False, ) From 410a6bea6eff07c4a8768d996ac7daf2b7010e13 Mon Sep 17 00:00:00 2001 From: Steffen Kirres Date: Wed, 2 Nov 2022 09:58:00 +0100 Subject: [PATCH 7/9] Fix test --- tests/tests_pytorch/accelerators/test_tpu.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 96bd55af093c4..e0f2cc2cca846 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -328,6 +328,8 @@ def test_trainer_config_device_ids(devices, expected_device_ids): @RunIf(tpu=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_all_gather(tmpdir): + nb_devices = 8 + class TestModel(BoringModel): training_step_called = False @@ -338,7 +340,7 @@ def training_step(self, batch, batch_idx): tensor = torch.tensor(1.0, device=self.device, requires_grad=True) result = self.all_gather(tensor, sync_grads=sync_grads) summed = result.sum() - assert torch.equal(summed, torch.tensor(8.0)) + assert torch.equal(summed, torch.tensor(float(nb_devices))) summed.backward() if sync_grads: assert torch.equal(tensor.grad, torch.tensor(1.0)) @@ -346,16 +348,19 @@ def training_step(self, batch, batch_idx): # As gradients are not synced, the original tensor will not have gradients. 
assert tensor.grad is None - return torch.rand(1, requires_grad=True, device=self.device).sum() + dummy_loss = torch.rand(1, requires_grad=True, device=self.device).sum() + return dummy_loss + + def on_train_end(self) -> None: + assert self.training_step_called model = TestModel() trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=True, accelerator="tpu", - devices=8, + devices=nb_devices, enable_progress_bar=False, enable_model_summary=False, ) trainer.fit(model) - assert model.training_step_called From b22a288d71e525c1e15d67d30411fc4a31b2add8 Mon Sep 17 00:00:00 2001 From: Steffen Kirres Date: Mon, 14 Nov 2022 17:39:56 +0100 Subject: [PATCH 8/9] Drop test --- tests/tests_pytorch/accelerators/test_tpu.py | 43 +------------------- 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index e0f2cc2cca846..86f812e7c9308 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -322,45 +322,4 @@ def test_warning_if_tpus_not_used(tpu_available): def test_trainer_config_device_ids(devices, expected_device_ids): trainer = Trainer(accelerator="tpu", devices=devices) assert trainer.device_ids == expected_device_ids - assert trainer.num_devices == len(expected_device_ids) - - -@RunIf(tpu=True) -@mock.patch.dict(os.environ, os.environ.copy(), clear=True) -def test_all_gather(tmpdir): - nb_devices = 8 - - class TestModel(BoringModel): - training_step_called = False - - def training_step(self, batch, batch_idx): - self.training_step_called = True - - for sync_grads in [True, False]: - tensor = torch.tensor(1.0, device=self.device, requires_grad=True) - result = self.all_gather(tensor, sync_grads=sync_grads) - summed = result.sum() - assert torch.equal(summed, torch.tensor(float(nb_devices))) - summed.backward() - if sync_grads: - assert torch.equal(tensor.grad, torch.tensor(1.0)) - else: - # As gradients are not synced, the original tensor will not have gradients. 
- assert tensor.grad is None - - dummy_loss = torch.rand(1, requires_grad=True, device=self.device).sum() - return dummy_loss - - def on_train_end(self) -> None: - assert self.training_step_called - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - accelerator="tpu", - devices=nb_devices, - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) + assert trainer.num_devices == len(expected_device_ids) \ No newline at end of file From 6262ead23e0832394b7b976043d056c6e33c90f8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Nov 2022 16:41:37 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/accelerators/test_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 86f812e7c9308..85ce3cac3a31c 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -322,4 +322,4 @@ def test_warning_if_tpus_not_used(tpu_available): def test_trainer_config_device_ids(devices, expected_device_ids): trainer = Trainer(accelerator="tpu", devices=devices) assert trainer.device_ids == expected_device_ids - assert trainer.num_devices == len(expected_device_ids) \ No newline at end of file + assert trainer.num_devices == len(expected_device_ids)
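
For reference, below the series itself, a minimal sketch of the user-facing behaviour these patches enable: the same `self.all_gather(..., sync_grads=True)` path that the standalone TPU test exercised before it was dropped in PATCH 8/9. The module name, the toy dataloader, and the loss are illustrative only; the sketch assumes an 8-core TPU host with torch_xla available and is not part of the patches.

    # Illustrative only: assumes an 8-core TPU host with torch_xla installed.
    # Mirrors the dropped standalone test: gather a per-core value with
    # sync_grads=True so that gradients flow back through the gather.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class GatherModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 1)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            # Scalar computed on this TPU core; it depends on the layer's parameters.
            per_device = self.layer(x).sum()
            # With sync_grads=True the strategy dispatches to the autograd-aware
            # gather (torch_xla.core.functions.all_gather), so the result stays
            # differentiable; shape is (world_size,) — one value per core.
            gathered = self.all_gather(per_device, sync_grads=True)
            # Backpropagating through the gathered tensor reaches this core's
            # parameters, which is exactly what sync_grads=False would not allow.
            return gathered.sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)


    if __name__ == "__main__":
        trainer = pl.Trainer(
            accelerator="tpu",
            devices=8,
            fast_dev_run=True,
            enable_progress_bar=False,
        )
        trainer.fit(GatherModel())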