Use torch.testing.assert_close everywhere (#15031)

remove unnecessary version check
otaj authored Oct 7, 2022
1 parent 8008055 commit 7e518ca
Showing 7 changed files with 13 additions and 53 deletions.
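The diff below drops the per-file compatibility alias (torch_test_assert_close, which fell back to torch.testing.assert_allclose on older PyTorch) and calls torch.testing.assert_close directly; the commit message calls the version check unnecessary, presumably because every supported PyTorch version already ships assert_close. For reference, a minimal sketch of the adopted call, with illustrative values that are not taken from the diff:

import torch

# assert_close raises an AssertionError when the inputs differ beyond
# dtype-dependent default tolerances; tensors and plain numbers both work.
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0 + 1e-8])
torch.testing.assert_close(actual, expected)  # passes: difference is within default rtol/atol
torch.testing.assert_close(0.123, 0.123)      # scalars are accepted as well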
@@ -30,14 +30,8 @@
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests_pytorch.helpers.runif import RunIf
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 class MockTqdm(Tqdm):
     def __init__(self, *args, **kwargs):
@@ -421,7 +415,7 @@ def training_step(self, batch, batch_idx):
     )
     trainer.fit(TestModel())
 
-    torch_test_assert_close(trainer.progress_bar_metrics["a"], 0.123)
+    torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
     assert trainer.progress_bar_metrics["b"] == {"b1": 1.0}
     assert trainer.progress_bar_metrics["c"] == {"c1": 2.0}
     pbar = trainer.progress_bar_callback.main_progress_bar
10 changes: 2 additions & 8 deletions tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -22,15 +22,9 @@
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 class MyNativeAMP(NativeMixedPrecisionPlugin):
     pass
@@ -104,13 +98,13 @@ def check_grads_unscaled(self, optimizer=None):
         grads = [p.grad for p in self.parameters()]
         assert len(grads) == len(self.original_grads)
         for actual, expected in zip(grads, self.original_grads):
-            torch_test_assert_close(actual, expected, equal_nan=True)
+            torch.testing.assert_close(actual, expected, equal_nan=True)
 
     def check_grads_clipped(self):
         parameters = list(self.parameters())
         assert len(parameters) == len(self.clipped_parameters)
         for actual, expected in zip(parameters, self.clipped_parameters):
-            torch_test_assert_close(actual.grad, expected.grad, equal_nan=True)
+            torch.testing.assert_close(actual.grad, expected.grad, equal_nan=True)
 
     def on_before_optimizer_step(self, optimizer, *_):
         self.check_grads_unscaled(optimizer)
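For context, the equal_nan=True flag kept in these calls treats NaNs at matching positions as equal (assert_close defaults to equal_nan=False). A minimal illustration with made-up tensors:

import torch

# NaNs at matching positions compare as equal only when equal_nan=True;
# with the default equal_nan=False this check would fail.
torch.testing.assert_close(torch.tensor([float("nan"), 1.0]), torch.tensor([float("nan"), 1.0]), equal_nan=True)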
8 changes: 1 addition & 7 deletions tests/tests_pytorch/strategies/test_common.py
@@ -19,16 +19,10 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.strategies import DDPStrategy
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.strategies.test_dp import CustomClassificationModelDP
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 @pytest.mark.parametrize(
     "trainer_kwargs",
@@ -58,7 +52,7 @@ def test_evaluate(tmpdir, trainer_kwargs):
 
     # make sure weights didn't change
     new_weights = model.layer_0.weight.clone().detach().cpu()
-    torch_test_assert_close(old_weights, new_weights)
+    torch.testing.assert_close(old_weights, new_weights)
 
 
 def test_model_parallel_setup_called(tmpdir):
@@ -25,14 +25,8 @@
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
 from pytorch_lightning.strategies import Strategy
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests_pytorch.helpers.runif import RunIf
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 class ManualOptModel(BoringModel):
     def __init__(self):
@@ -461,7 +455,7 @@ def check_grads_unscaled(self, optimizer=None):
         grads = [p.grad for p in self.parameters()]
         assert len(grads) == len(self.original_grads)
         for actual, expected in zip(grads, self.original_grads):
-            torch_test_assert_close(actual, expected)
+            torch.testing.assert_close(actual, expected)
 
     def on_before_optimizer_step(self, optimizer, *_):
         self.check_grads_unscaled(optimizer)
@@ -17,12 +17,6 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 class MultiOptModel(BoringModel):
@@ -58,7 +52,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
     for k, v in model.actual.items():
         assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1])
         # test loss is properly reduced
-        torch_test_assert_close(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())
+        torch.testing.assert_close(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())
 
 
 def test_multiple_optimizers(tmpdir):
11 changes: 3 additions & 8 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -61,7 +61,7 @@
 )
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -70,11 +70,6 @@
 if _OMEGACONF_AVAILABLE:
     from omegaconf import OmegaConf
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 def test_trainer_error_when_input_not_lightning_module():
     """Test that a useful error gets raised when the Trainer methods receive something other than a
@@ -1125,7 +1120,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
             # test that gradient is clipped correctly
             parameters = self.parameters()
             grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
-            torch_test_assert_close(grad_norm, torch.tensor(0.05, device=self.device))
+            torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))
             self.assertion_called = True
 
     model = TestModel()
@@ -1156,7 +1151,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
             parameters = self.parameters()
             grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
             grad_max = torch.max(torch.stack(grad_max_list))
-            torch_test_assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device))
+            torch.testing.assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device))
             self.assertion_called = True
 
     model = TestModel()
13 changes: 4 additions & 9 deletions tests/tests_pytorch/utilities/test_auto_restart.py
@@ -55,15 +55,10 @@
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher
-from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from tests_pytorch.core.test_results import spawn_launch
 from tests_pytorch.helpers.runif import RunIf
 
-if _TORCH_GREATER_EQUAL_1_12:
-    torch_test_assert_close = torch.testing.assert_close
-else:
-    torch_test_assert_close = torch.testing.assert_allclose
-
 
 def test_fast_forward_getattr():
     dataset = range(15)
@@ -946,9 +941,9 @@ def run(should_fail, resume):
     pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
     post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)
 
-    torch_test_assert_close(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
+    torch.testing.assert_close(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
     for k in total_val_batches:
-        torch_test_assert_close(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])
+        torch.testing.assert_close(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])
 
 
 class TestAutoRestartModelUnderSignal(BoringModel):
@@ -1482,6 +1477,6 @@ def configure_optimizers(self):
     trainer.train_dataloader = None
     restart_batches = model.batches
 
-    torch_test_assert_close(total_batches, failed_batches + restart_batches)
+    torch.testing.assert_close(total_batches, failed_batches + restart_batches)
     assert not torch.equal(total_weight, failed_weight)
     assert torch.equal(total_weight, model.layer.weight)
