simplify torch.Tensor (#16190)
Borda authored Dec 24, 2022
1 parent 87818e3 commit e872029
Showing 21 changed files with 65 additions and 49 deletions.
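The change is mechanical: each file imports Tensor from torch once and uses it in place of the qualified torch.Tensor in annotations and isinstance checks. A minimal sketch of the pattern, assuming only the public torch API (the function below is illustrative, not taken from the diff):

import torch
from torch import Tensor  # the same class object, re-exported at the top level


def double(x: Tensor) -> Tensor:
    # isinstance semantics are unchanged: Tensor is torch.Tensor
    assert isinstance(x, Tensor)
    return x * 2


assert Tensor is torch.Tensor  # the two names are interchangeable
print(double(torch.ones(2)))

Because both names refer to the same class, the rewrite cannot change behavior; it only shortens annotations and avoids repeated attribute lookups on the torch module.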
tests/tests_lite/test_parity.py: 4 changes (2 additions & 2 deletions)

@@ -25,7 +25,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from tests_lite.helpers.models import RandomDataset
 from tests_lite.helpers.runif import RunIf
-from torch import nn
+from torch import nn, Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -131,7 +131,7 @@ def test_boring_lite_model_single_device(precision, strategy, devices, accelerat
     model.load_state_dict(state_dict)
     pure_state_dict = main(lite.to_device, model, train_dataloader, num_epochs=num_epochs)
 
-    state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
+    state_dict = apply_to_collection(state_dict, Tensor, lite.to_device)
     for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
         # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12)
         assert not torch.allclose(w_pure, w_lite)
tests/tests_lite/utilities/test_apply_func.py: 3 changes (2 additions & 1 deletion)

@@ -13,14 +13,15 @@
 # limitations under the License.
 import pytest
 import torch
+from torch import Tensor
 
 from lightning_lite.utilities.apply_func import move_data_to_device
 
 
 @pytest.mark.parametrize("should_return", [False, True])
 def test_wrongly_implemented_transferable_data_type(should_return):
     class TensorObject:
-        def __init__(self, tensor: torch.Tensor, should_return: bool = True):
+        def __init__(self, tensor: Tensor, should_return: bool = True):
             self.tensor = tensor
             self.should_return = should_return
tests/tests_lite/utilities/test_data.py: 7 changes (4 additions & 3 deletions)

@@ -4,6 +4,7 @@
 import pytest
 import torch
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
+from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 
 from lightning_lite.utilities.data import (
@@ -87,7 +88,7 @@ def __init__(self, attribute2, *args, **kwargs):
 
 
 class MyDataLoader(MyBaseDataLoader):
-    def __init__(self, data: torch.Tensor, *args, **kwargs):
+    def __init__(self, data: Tensor, *args, **kwargs):
         self.data = data
         super().__init__(range(data.size(0)), *args, **kwargs)
 
@@ -209,7 +210,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
 
     for key, value in checked_values.items():
         dataloader_value = getattr(dataloader, key)
-        if isinstance(dataloader_value, torch.Tensor):
+        if isinstance(dataloader_value, Tensor):
             assert dataloader_value is value
         else:
             assert dataloader_value == value
@@ -227,7 +228,7 @@ def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset
 
     for key, value in checked_values.items():
         dataloader_value = getattr(dataloader, key)
-        if isinstance(dataloader_value, torch.Tensor):
+        if isinstance(dataloader_value, Tensor):
             assert dataloader_value is value
         else:
             assert dataloader_value == value
tests/tests_lite/utilities/test_optimizer.py: 5 changes (3 additions & 2 deletions)

@@ -1,6 +1,7 @@
 import collections
 
 import torch
+from torch import Tensor
 
 from lightning_lite.utilities.optimizer import _optimizer_to_device
 
@@ -22,9 +23,9 @@ def __init__(self, *args, **kwargs):
 def assert_opt_parameters_on_device(opt, device: str):
     for param in opt.state.values():
         # Not sure there are any global tensors in the state dict
-        if isinstance(param, torch.Tensor):
+        if isinstance(param, Tensor):
             assert param.data.device.type == device
         elif isinstance(param, collections.Mapping):
             for subparam in param.values():
-                if isinstance(subparam, torch.Tensor):
+                if isinstance(subparam, Tensor):
                     assert param.data.device.type == device
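A side note on the helper above, untouched by this commit: collections.Mapping is a deprecated alias that was removed in Python 3.10 (collections.abc.Mapping is the supported import), and the innermost assertion checks param.data even though the loop variable is subparam. A hedged sketch of what a modernized helper might look like; this is our assumption, not part of the commit:

from collections.abc import Mapping

from torch import Tensor


def assert_opt_parameters_on_device(opt, device: str) -> None:
    for param in opt.state.values():
        if isinstance(param, Tensor):
            assert param.data.device.type == device
        elif isinstance(param, Mapping):
            for subparam in param.values():
                if isinstance(subparam, Tensor):
                    # check the nested tensor itself, not its parent mapping
                    assert subparam.data.device.type == device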
tests/tests_pytorch/core/test_metric_result_integration.py: 3 changes (2 additions & 1 deletion)

@@ -21,6 +21,7 @@
 import torch
 import torchmetrics
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 from torch.nn import ModuleDict, ModuleList
 from torchmetrics import Metric, MetricCollection
 
@@ -662,7 +663,7 @@ def test_logger_sync_dist(distributed_env, log_val):
     # self.log('bar', 0.5, ..., sync_dist=False)
     meta = _Metadata("foo", "bar")
     meta.sync = _Sync(_should=False)
-    is_tensor = isinstance(log_val, torch.Tensor)
+    is_tensor = isinstance(log_val, Tensor)
 
     if not is_tensor:
         log_val.update(torch.tensor([0, 1]), torch.tensor([0, 0], dtype=torch.long))
tests/tests_pytorch/helpers/datasets.py: 7 changes (4 additions & 3 deletions)

@@ -19,6 +19,7 @@
 from typing import Optional, Sequence, Tuple
 
 import torch
+from torch import Tensor
 from torch.utils.data import Dataset
 
 
@@ -69,7 +70,7 @@ def __init__(
         data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
         self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
 
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
+    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
         img = self.data[idx].float().unsqueeze(0)
         target = int(self.targets[idx])
 
@@ -125,7 +126,7 @@ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
         return res
 
     @staticmethod
-    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
+    def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
         mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
         std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
         return tensor.sub(mean).div(std)
@@ -160,7 +161,7 @@ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence]
         super().__init__(root, normalize=(0.5, 1.0), **kwargs)
 
     @staticmethod
-    def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
+    def _prepare_subset(full_data: Tensor, full_targets: Tensor, num_samples: int, digits: Sequence):
         classes = {d: 0 for d in digits}
         indexes = []
         for idx, target in enumerate(full_targets):
tests/tests_pytorch/helpers/deterministic_model.py: 4 changes (2 additions & 2 deletions)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning.core.module import LightningModule
@@ -56,7 +56,7 @@ def step(self, batch, batch_idx):
 
     def count_num_graphs(self, result, num_graphs=0):
        for k, v in result.items():
-            if isinstance(v, torch.Tensor) and v.grad_fn is not None:
+            if isinstance(v, Tensor) and v.grad_fn is not None:
                num_graphs += 1
            if isinstance(v, dict):
                num_graphs += self.count_num_graphs(v)
tests/tests_pytorch/loops/test_evaluation_loop_flow.py: 5 changes (3 additions & 2 deletions)

@@ -14,6 +14,7 @@
 """Tests the evaluation loop."""
 
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.module import LightningModule
@@ -68,7 +69,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
@@ -129,7 +130,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
tests/tests_pytorch/loops/test_training_loop_flow_scalar.py: 8 changes (4 additions & 4 deletions)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-import torch
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data._utils.collate import default_collate
 
@@ -151,7 +151,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
@@ -172,7 +172,7 @@ def training_step(self, batch, batch_idx):
         return acc
 
     def training_step_end(self, tr_step_output):
-        assert isinstance(tr_step_output, torch.Tensor)
+        assert isinstance(tr_step_output, Tensor)
         assert self.count_num_graphs({"loss": tr_step_output}) == 1
         self.training_step_end_called = True
         return tr_step_output
@@ -221,7 +221,7 @@ def backward(self, loss, optimizer, optimizer_idx):
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
-    assert isinstance(train_step_out["loss"], torch.Tensor)
+    assert isinstance(train_step_out["loss"], Tensor)
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
tests/tests_pytorch/models/test_gpu.py: 2 changes (1 addition & 1 deletion)

@@ -184,7 +184,7 @@ def to(self, *args, **kwargs):
 
 @RunIf(min_cuda_gpus=1)
 def test_non_blocking():
-    """Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects."""
+    """Tests that non_blocking=True only gets passed on Tensor.to, but not on other objects."""
     trainer = Trainer()
 
     batch = torch.zeros(2, 3)
tests/tests_pytorch/models/test_hooks.py: 3 changes (2 additions & 1 deletion)

@@ -18,6 +18,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer
@@ -827,7 +828,7 @@ def test_hooks_with_different_argument_names(tmpdir):
 
     class CustomBoringModel(BoringModel):
         def assert_args(self, x, batch_nb):
-            assert isinstance(x, torch.Tensor)
+            assert isinstance(x, Tensor)
             assert x.size() == (1, 32)
             assert isinstance(batch_nb, int)
tests/tests_pytorch/models/test_horovod.py: 4 changes (2 additions & 2 deletions)

@@ -21,7 +21,7 @@
 import numpy as np
 import pytest
 import torch
-from torch import optim
+from torch import optim, Tensor
 from torchmetrics.classification.accuracy import Accuracy
 
 import tests_pytorch.helpers.pipelines as tpipes
@@ -387,7 +387,7 @@ def _compute_batch():
 
     # check on all batches on all ranks
     result = metric.compute()
-    assert isinstance(result, torch.Tensor)
+    assert isinstance(result, Tensor)
 
     total_preds = torch.stack([preds[i] for i in range(num_batches)])
     total_target = torch.stack([target[i] for i in range(num_batches)])
tests/tests_pytorch/models/test_restore.py: 3 changes (2 additions & 1 deletion)

@@ -24,6 +24,7 @@
 import torch
 import torch.nn.functional as F
 from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
 
 import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
@@ -160,7 +161,7 @@ def configure_optimizers(self):
 
     class CustomClassifModel(CustomClassifModel):
         def _is_equal(self, a, b):
-            if isinstance(a, torch.Tensor):
+            if isinstance(a, Tensor):
                 return torch.all(torch.eq(a, b))
 
             if isinstance(a, Mapping):
tests/tests_pytorch/plugins/test_amp_plugins.py: 3 changes (2 additions & 1 deletion)

@@ -17,6 +17,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
@@ -176,7 +177,7 @@ def __init__(self):
         self.layer1 = torch.nn.Linear(32, 32)
         self.layer2 = torch.nn.Linear(32, 2)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.layer1(x)
         x = self.layer2(x)
         return x
tests/tests_pytorch/serve/test_servable_module_validator.py: 3 changes (2 additions & 1 deletion)

@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
@@ -21,7 +22,7 @@ def serialize(x):
 
         return {"x": deserialize}, {"output": serialize}
 
-    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def serve_step(self, x: Tensor) -> Dict[str, Tensor]:
         assert torch.equal(x, torch.arange(32, dtype=torch.float))
         return {"output": torch.tensor([0, 1])}
tests/tests_pytorch/strategies/test_hivemind.py: 9 changes (5 additions & 4 deletions)

@@ -7,6 +7,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -41,7 +42,7 @@ def test_strategy(mock_dht):
 @mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
 def test_optimizer_wrapped():
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             optimizer = self.trainer.optimizers[0]
             assert isinstance(optimizer, hivemind.Optimizer)
 
@@ -54,7 +55,7 @@ def on_before_backward(self, loss: torch.Tensor) -> None:
 @mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
 def test_scheduler_wrapped():
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             scheduler = self.trainer.lr_scheduler_configs[0].scheduler
             assert isinstance(scheduler, HiveMindScheduler)
 
@@ -92,7 +93,7 @@ def test_reuse_grad_buffers_warning():
     """Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True."""
 
     class TestModel(BoringModel):
-        def on_before_backward(self, loss: torch.Tensor) -> None:
+        def on_before_backward(self, loss: Tensor) -> None:
             optimizer = self.trainer.optimizers[0]
             assert isinstance(optimizer, hivemind.Optimizer)
 
@@ -169,7 +170,7 @@ def test_args_passed_to_optimizer(mock_peers):
     with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer:
 
         class TestModel(BoringModel):
-            def on_before_backward(self, loss: torch.Tensor) -> None:
+            def on_before_backward(self, loss: Tensor) -> None:
                 args, kwargs = mock_optimizer.call_args
                 mock_optimizer.assert_called()
                 arguments = dict(
tests/tests_pytorch/strategies/test_sharded_strategy.py: 3 changes (2 additions & 1 deletion)

@@ -6,6 +6,7 @@
 
 import pytest
 import torch
+from torch import Tensor
 
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
@@ -42,7 +43,7 @@ def on_train_start(self):
         self._is_equal(optimizer_state, state)
 
     def _is_equal(self, a, b):
-        if isinstance(a, torch.Tensor):
+        if isinstance(a, Tensor):
             return torch.allclose(a, b)
 
         if isinstance(a, Mapping):
(4 more changed files not shown)
