From e8720290a473d80d08e6d30b1751ca13762d23dd Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Sat, 24 Dec 2022 15:44:27 +0900
Subject: [PATCH] simplify torch.Tensor (#16190)

---
 tests/tests_lite/test_parity.py                  |  4 ++--
 tests/tests_lite/utilities/test_apply_func.py    |  3 ++-
 tests/tests_lite/utilities/test_data.py          |  7 ++++---
 tests/tests_lite/utilities/test_optimizer.py     |  5 +++--
 .../core/test_metric_result_integration.py       |  3 ++-
 tests/tests_pytorch/helpers/datasets.py          |  7 ++++---
 .../helpers/deterministic_model.py               |  4 ++--
 .../loops/test_evaluation_loop_flow.py           |  5 +++--
 .../loops/test_training_loop_flow_scalar.py      |  8 ++++----
 tests/tests_pytorch/models/test_gpu.py           |  2 +-
 tests/tests_pytorch/models/test_hooks.py         |  3 ++-
 tests/tests_pytorch/models/test_horovod.py       |  4 ++--
 tests/tests_pytorch/models/test_restore.py       |  3 ++-
 tests/tests_pytorch/plugins/test_amp_plugins.py  |  3 ++-
 .../serve/test_servable_module_validator.py      |  3 ++-
 tests/tests_pytorch/strategies/test_hivemind.py  |  9 +++++----
 .../strategies/test_sharded_strategy.py          |  3 ++-
 .../trainer/logging_/test_eval_loop_logging.py   | 13 +++++++------
 .../trainer/logging_/test_train_loop_logging.py  | 17 +++++++++--------
 .../utilities/test_auto_restart.py               |  3 ++-
 tests/tests_pytorch/utilities/test_fetching.py   |  5 +++--
 21 files changed, 65 insertions(+), 49 deletions(-)

diff --git a/tests/tests_lite/test_parity.py b/tests/tests_lite/test_parity.py
index da18dfced70d7..fbb2c7bc4969c 100644
--- a/tests/tests_lite/test_parity.py
+++ b/tests/tests_lite/test_parity.py
@@ -25,7 +25,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from tests_lite.helpers.models import RandomDataset
 from tests_lite.helpers.runif import RunIf
-from torch import nn
+from torch import nn, Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -131,7 +131,7 @@ def test_boring_lite_model_single_device(precision, strategy, devices, accelerat
     model.load_state_dict(state_dict)
     pure_state_dict = main(lite.to_device, model, train_dataloader, num_epochs=num_epochs)
 
-    state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
+    state_dict = apply_to_collection(state_dict, Tensor, lite.to_device)
     for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
         # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12)
         assert not torch.allclose(w_pure, w_lite)
diff --git a/tests/tests_lite/utilities/test_apply_func.py b/tests/tests_lite/utilities/test_apply_func.py
index b299783ae85d7..f4ec59adfb257 100644
--- a/tests/tests_lite/utilities/test_apply_func.py
+++ b/tests/tests_lite/utilities/test_apply_func.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import pytest
 import torch
+from torch import Tensor
 
 from lightning_lite.utilities.apply_func import move_data_to_device
 
@@ -20,7 +21,7 @@
 @pytest.mark.parametrize("should_return", [False, True])
 def test_wrongly_implemented_transferable_data_type(should_return):
     class TensorObject:
-        def __init__(self, tensor: torch.Tensor, should_return: bool = True):
+        def __init__(self, tensor: Tensor, should_return: bool = True):
             self.tensor = tensor
             self.should_return = should_return
 
diff --git a/tests/tests_lite/utilities/test_data.py b/tests/tests_lite/utilities/test_data.py
index 23a84901e40e3..5facaba88c6fe 100644
--- a/tests/tests_lite/utilities/test_data.py
+++ b/tests/tests_lite/utilities/test_data.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
+from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 
 from lightning_lite.utilities.data import (
@@ -87,7 +88,7 @@ def __init__(self, attribute2, *args, **kwargs):
 
 
 class MyDataLoader(MyBaseDataLoader):
-    def __init__(self, data: torch.Tensor, *args, **kwargs):
+    def __init__(self, data: Tensor, *args, **kwargs):
         self.data = data
         super().__init__(range(data.size(0)), *args, **kwargs)
 
@@ -209,7 +210,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
 
     for key, value in checked_values.items():
         dataloader_value = getattr(dataloader, key)
-        if isinstance(dataloader_value, torch.Tensor):
+        if isinstance(dataloader_value, Tensor):
             assert dataloader_value is value
         else:
             assert dataloader_value == value
@@ -227,7 +228,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
 
     for key, value in checked_values.items():
         dataloader_value = getattr(dataloader, key)
-        if isinstance(dataloader_value, torch.Tensor):
+        if isinstance(dataloader_value, Tensor):
             assert dataloader_value is value
         else:
             assert dataloader_value == value
diff --git a/tests/tests_lite/utilities/test_optimizer.py b/tests/tests_lite/utilities/test_optimizer.py
index e14b8bc8cbe29..995b814ee2a24 100644
--- a/tests/tests_lite/utilities/test_optimizer.py
+++ b/tests/tests_lite/utilities/test_optimizer.py
@@ -1,6 +1,7 @@
 import collections
 
 import torch
+from torch import Tensor
 
 from lightning_lite.utilities.optimizer import _optimizer_to_device
 
@@ -22,9 +23,9 @@ def __init__(self, *args, **kwargs):
 def assert_opt_parameters_on_device(opt, device: str):
     for param in opt.state.values():
         # Not sure there are any global tensors in the state dict
-        if isinstance(param, torch.Tensor):
+        if isinstance(param, Tensor):
             assert param.data.device.type == device
         elif isinstance(param, collections.Mapping):
             for subparam in param.values():
-                if isinstance(subparam, torch.Tensor):
+                if isinstance(subparam, Tensor):
                     assert param.data.device.type == device
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index 57dab166a313c..e39d7909cf843 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -21,6 +21,7 @@
 import torch
 import torchmetrics
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 from torch.nn import ModuleDict, ModuleList
 from torchmetrics import Metric, MetricCollection
 
@@ -662,7 +663,7 @@ def test_logger_sync_dist(distributed_env, log_val):
     # self.log('bar', 0.5, ..., sync_dist=False)
     meta = _Metadata("foo", "bar")
     meta.sync = _Sync(_should=False)
-    is_tensor = isinstance(log_val, torch.Tensor)
+    is_tensor = isinstance(log_val, Tensor)
 
     if not is_tensor:
         log_val.update(torch.tensor([0, 1]), torch.tensor([0, 0], dtype=torch.long))
diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py
index c9d185313e85e..8ef629421ed44 100644
--- a/tests/tests_pytorch/helpers/datasets.py
+++ b/tests/tests_pytorch/helpers/datasets.py
@@ -19,6 +19,7 @@
 from typing import Optional, Sequence, Tuple
 
 import torch
+from torch import Tensor
 from torch.utils.data import Dataset
 
 
@@ -69,7 +70,7 @@ def __init__(
         data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
         self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
 
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
+    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
         img = self.data[idx].float().unsqueeze(0)
         target = int(self.targets[idx])
 
@@ -125,7 +126,7 @@ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
         return res
 
     @staticmethod
-    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
+    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
         mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
         std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
         return tensor.sub(mean).div(std)
@@ -160,7 +161,7 @@ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence]
         super().__init__(root, normalize=(0.5, 1.0), **kwargs)
 
     @staticmethod
-    def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
+    def _prepare_subset(full_data: Tensor, full_targets: Tensor, num_samples: int, digits: Sequence):
         classes = {d: 0 for d in digits}
         indexes = []
         for idx, target in enumerate(full_targets):
diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py
index 2fa165d7a55fa..25c6a3aa9afd2 100644
--- a/tests/tests_pytorch/helpers/deterministic_model.py
+++ b/tests/tests_pytorch/helpers/deterministic_model.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning.core.module import LightningModule
@@ -56,7 +56,7 @@ def step(self, batch, batch_idx):
 
     def count_num_graphs(self, result, num_graphs=0):
         for k, v in result.items():
-            if isinstance(v, torch.Tensor) and v.grad_fn is not None:
+            if isinstance(v, Tensor) and v.grad_fn is not None:
                 num_graphs += 1
             if isinstance(v, dict):
                 num_graphs += self.count_num_graphs(v)
diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py
index 59a511933b32b..fb3c73fd5601c 100644
--- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py
+++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py
@@ -14,6 +14,7 @@
 """Tests the evaluation loop."""
 
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.module import LightningModule
@@ -68,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
    train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
@@ -129,7 +130,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
index 01ff4ea9a70ee..087b9f953d809 100644
--- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
+++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-import torch
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data._utils.collate import default_collate
 
@@ -151,7 +151,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
@@ -172,7 +172,7 @@ def training_step(self, batch, batch_idx):
             return acc
 
         def training_step_end(self, tr_step_output):
-            assert isinstance(tr_step_output, torch.Tensor)
+            assert isinstance(tr_step_output, Tensor)
             assert self.count_num_graphs({"loss": tr_step_output}) == 1
             self.training_step_end_called = True
             return tr_step_output
@@ -221,7 +221,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py
index 8c4dd78ef28e3..4459ea5d7b75b 100644
--- a/tests/tests_pytorch/models/test_gpu.py
+++ b/tests/tests_pytorch/models/test_gpu.py
@@ -184,7 +184,7 @@ def to(self, *args, **kwargs):
 
 @RunIf(min_cuda_gpus=1)
 def test_non_blocking():
-    """Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects."""
+    """Tests that non_blocking=True only gets passed on Tensor.to, but not on other objects."""
     trainer = Trainer()
 
     batch = torch.zeros(2, 3)
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 38a42f8b3d1fc..3702c8a922546 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -18,6 +18,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
@@ -827,7 +828,7 @@ def test_hooks_with_different_argument_names(tmpdir):
 
     class CustomBoringModel(BoringModel):
         def assert_args(self, x, batch_nb):
-            assert isinstance(x, torch.Tensor)
+            assert isinstance(x, Tensor)
             assert x.size() == (1, 32)
             assert isinstance(batch_nb, int)
 
diff --git a/tests/tests_pytorch/models/test_horovod.py b/tests/tests_pytorch/models/test_horovod.py
index fc71dc42cba40..60e45925d8254 100644
--- a/tests/tests_pytorch/models/test_horovod.py
+++ b/tests/tests_pytorch/models/test_horovod.py
@@ -21,7 +21,7 @@
 import numpy as np
 import pytest
 import torch
-from torch import optim
+from torch import optim, Tensor
 from torchmetrics.classification.accuracy import Accuracy
 
 import tests_pytorch.helpers.pipelines as tpipes
@@ -387,7 +387,7 @@ def _compute_batch():
 
     # check on all batches on all ranks
     result = metric.compute()
-    assert isinstance(result, torch.Tensor)
+    assert isinstance(result, Tensor)
 
     total_preds = torch.stack([preds[i] for i in range(num_batches)])
     total_target = torch.stack([target[i] for i in range(num_batches)])
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 7eb475985ea22..54cadf80a495b 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -24,6 +24,7 @@
 import torch
 import torch.nn.functional as F
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 
 import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
@@ -160,7 +161,7 @@ def configure_optimizers(self):
 
     class CustomClassifModel(CustomClassifModel):
         def _is_equal(self, a, b):
-            if isinstance(a, torch.Tensor):
+            if isinstance(a, Tensor):
                 return torch.all(torch.eq(a, b))
 
             if isinstance(a, Mapping):
diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py
index 1c9dd53f2da10..8d91376da1283 100644
--- a/tests/tests_pytorch/plugins/test_amp_plugins.py
+++ b/tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -17,6 +17,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
@@ -176,7 +177,7 @@ def __init__(self):
             self.layer1 = torch.nn.Linear(32, 32)
             self.layer2 = torch.nn.Linear(32, 2)
 
-        def forward(self, x: torch.Tensor):
+        def forward(self, x: Tensor):
             x = self.layer1(x)
             x = self.layer2(x)
             return x
diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py
index dd578c5907556..127dd97cfd5af 100644
--- a/tests/tests_pytorch/serve/test_servable_module_validator.py
+++ b/tests/tests_pytorch/serve/test_servable_module_validator.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
@@ -21,7 +22,7 @@ def serialize(x):
 
         return {"x": deserialize}, {"output": serialize}
 
-    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def serve_step(self, x: Tensor) -> Dict[str, Tensor]:
         assert torch.equal(x, torch.arange(32, dtype=torch.float))
         return {"output": torch.tensor([0, 1])}
 
diff --git a/tests/tests_pytorch/strategies/test_hivemind.py b/tests/tests_pytorch/strategies/test_hivemind.py
index 11c514fa841f9..a75d13676f3c4 100644
--- a/tests/tests_pytorch/strategies/test_hivemind.py
+++ b/tests/tests_pytorch/strategies/test_hivemind.py
@@ -7,6 +7,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -41,7 +42,7 @@ def test_strategy(mock_dht):
 @mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
 def test_optimizer_wrapped():
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             optimizer = self.trainer.optimizers[0]
             assert isinstance(optimizer, hivemind.Optimizer)
 
@@ -54,7 +55,7 @@ def on_before_backward(self, loss: torch.Tensor) -> None:
 @mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
 def test_scheduler_wrapped():
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             scheduler = self.trainer.lr_scheduler_configs[0].scheduler
             assert isinstance(scheduler, HiveMindScheduler)
 
@@ -92,7 +93,7 @@ def test_reuse_grad_buffers_warning():
     """Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True."""
 
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             optimizer = self.trainer.optimizers[0]
             assert isinstance(optimizer, hivemind.Optimizer)
 
@@ -169,7 +170,7 @@ def test_args_passed_to_optimizer(mock_peers):
     with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer:
 
         class TestModel(BoringModel):
-            def on_before_backward(self, loss: torch.Tensor) -> None:
+            def on_before_backward(self, loss: Tensor) -> None:
                 args, kwargs = mock_optimizer.call_args
                 mock_optimizer.assert_called()
                 arguments = dict(
diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py
index 7200d4a866397..dd8e8127c4c92 100644
--- a/tests/tests_pytorch/strategies/test_sharded_strategy.py
+++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -6,6 +6,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
@@ -42,7 +43,7 @@ def on_train_start(self):
             self._is_equal(optimizer_state, state)
 
     def _is_equal(self, a, b):
-        if isinstance(a, torch.Tensor):
+        if isinstance(a, Tensor):
             return torch.allclose(a, b)
 
         if isinstance(a, Mapping):
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index 1f234166be5aa..8b3069635b553 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -23,6 +23,7 @@
 import numpy as np
 import pytest
 import torch
+from torch import Tensor
 
 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
@@ -72,8 +73,8 @@ def validation_step(self, batch, batch_idx):
     # we don't want to enable val metrics during steps because it is not something that users should do
     # on purpose DO NOT allow b_step... it's silly to monitor val step metrics
     assert set(trainer.callback_metrics) == {"a", "a2", "b", "a_epoch", "b_epoch", "a_step"}
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
@@ -115,8 +116,8 @@ def validation_epoch_end(self, outputs):
 
     # we don't want to enable val metrics during steps because it is not something that users should do
     assert set(trainer.callback_metrics) == {"a", "b", "b_epoch", "c", "d", "d_epoch", "g", "b_step"}
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
@@ -149,8 +150,8 @@ def validation_epoch_end(self, outputs):
     # make sure all the metrics are available for callbacks
     callback_metrics = set(trainer.callback_metrics)
     assert callback_metrics == (logged_metrics | pbar_metrics)
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index 0352c9eda8d32..a2be3d0dcc3c4 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -23,6 +23,7 @@
 import pytest
 import torch
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 from torch.utils.data import DataLoader
 from torchmetrics import Accuracy
 
@@ -96,8 +97,8 @@ def training_step(self, batch, batch_idx):
     assert pbar_metrics == {"p_e", "p_s", "p_se_step", "p_se_epoch"}
 
     assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {"p_se", "l_se"})
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
@@ -136,8 +137,8 @@ def training_epoch_end(self, outputs):
     assert pbar_metrics == {"b"}
 
     assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {"a"})
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
@@ -180,8 +181,8 @@ def training_epoch_end(self, outputs):
     assert pbar_metrics == {"c", "b_epoch", "b_step"}
 
     assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics | {"a", "b"})
-    assert all(isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
-    assert all(isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.callback_metrics.values())
+    assert all(isinstance(v, Tensor) for v in trainer.logged_metrics.values())
     assert all(isinstance(v, float) for v in trainer.progress_bar_metrics.values())
 
 
@@ -724,7 +725,7 @@ def training_step(self, batch, batch_idx):
 @RunIf(min_cuda_gpus=1)
 def test_move_metrics_to_cpu(tmpdir):
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             assert loss.device.type == "cuda"
 
     trainer = Trainer(
@@ -846,5 +847,5 @@ def test_unsqueezed_tensor_logging():
     trainer.state.stage = RunningStage.TRAINING
     model._current_fx_name = "training_step"
     model.trainer = trainer
-    model.log("foo", torch.Tensor([1.2]))
+    model.log("foo", Tensor([1.2]))
     assert trainer.callback_metrics["foo"].ndim == 0
diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py
index c51391bfcb6dc..9b8d201cd620a 100644
--- a/tests/tests_pytorch/utilities/test_auto_restart.py
+++ b/tests/tests_pytorch/utilities/test_auto_restart.py
@@ -27,6 +27,7 @@
 import numpy as np
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler
 from torch.utils.data._utils.worker import _generate_state, get_worker_info
 from torch.utils.data.dataloader import DataLoader, default_collate
@@ -620,7 +621,7 @@ def training_step(self, batch, batch_idx):
         }
 
     def validation_step(self, batch, batch_idx):
-        assert isinstance(batch, torch.Tensor)
+        assert isinstance(batch, Tensor)
 
     validation_epoch_end = None
 
diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py
index f1ecabdbdd55a..0f6c70cfa8289 100644
--- a/tests/tests_pytorch/utilities/test_fetching.py
+++ b/tests/tests_pytorch/utilities/test_fetching.py
@@ -18,6 +18,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from pytorch_lightning import Callback, LightningDataModule, Trainer
@@ -189,7 +190,7 @@ def __init__(self):
         self.local_embedding = torch.nn.Embedding(EMB_SZ, EMB_DIM)
         self.CYCLES_PER_MS = int(get_cycles_per_ms())
 
-    def forward(self, indices: torch.Tensor):
+    def forward(self, indices: Tensor):
         result = self.local_embedding(indices)
         return result
 
@@ -280,7 +281,7 @@ def training_step(self, dataloader_iter, batch_idx):
         self.batches.append(next(dataloader_iter))
 
         batch = self.batches.pop(0)
-        assert isinstance(batch, torch.Tensor) or batch is None
+        assert isinstance(batch, Tensor) or batch is None
         self.count += 2
         if self.automatic_optimization:
             loss = super().training_step(batch, 0)
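
Note (illustration only, not part of the patch): the change is purely mechanical — every `torch.Tensor` reference in type annotations and `isinstance` checks is replaced by `Tensor` imported directly from `torch`. The two names refer to the same class, so behaviour is unchanged. A minimal standalone sketch of the pattern (file and function names are hypothetical, not from the repository):

    import torch
    from torch import Tensor


    def to_float(value: Tensor) -> Tensor:
        # `Tensor` is the very same class object as `torch.Tensor`,
        # so annotations and isinstance checks are interchangeable.
        assert Tensor is torch.Tensor
        assert isinstance(value, Tensor) and isinstance(value, torch.Tensor)
        return value.float()


    if __name__ == "__main__":
        print(to_float(torch.zeros(2, 3)).dtype)  # torch.float32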