refactor: simplify Tensor import #15959

Merged · 1 commit · Dec 8, 2022
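
The diff applies a single pattern throughout: bind `Tensor` once at module level with `from torch import Tensor` and use the bare name in annotations, instead of spelling out `torch.Tensor` at every call site. The imported name is the same class object, so runtime behaviour, `isinstance` checks, and static type checking are unchanged; only the annotations get shorter. A minimal before/after sketch of the pattern (the `scale` function is illustrative, not taken from this PR):

# Before: every annotation spells out the full attribute path
import torch

def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
    return t * factor

# After: import the class once, keep annotations short
from torch import Tensor

def scale(t: Tensor, factor: float) -> Tensor:
    return t * factor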
24 changes: 11 additions & 13 deletions src/lightning_lite/plugins/collectives/collective.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional
 
-import torch
+from torch import Tensor
 from typing_extensions import Self
 
 from lightning_lite.utilities.types import CollectibleGroup
@@ -38,45 +38,43 @@ def group(self) -> CollectibleGroup:
         return self._group
 
     @abstractmethod
-    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
         ...
 
     @abstractmethod
-    def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
+    def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+    def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
         ...
 
     @abstractmethod
-    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
+    def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
-    ) -> List[torch.Tensor]:
+    def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+    def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
         ...
 
     @abstractmethod
-    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
         ...
 
     @abstractmethod
26 changes: 13 additions & 13 deletions src/lightning_lite/plugins/collectives/single_device.py
@@ -1,6 +1,6 @@
 from typing import Any, List
 
-import torch
+from torch import Tensor
 
 from lightning_lite.plugins.collectives.collective import Collective
 from lightning_lite.utilities.types import CollectibleGroup
@@ -15,42 +15,42 @@ def rank(self) -> int:
     def world_size(self) -> int:
         return 1
 
-    def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]:
         return [tensor]
 
-    def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]:
         return [tensor]
 
     def scatter(
         self,
-        tensor: torch.Tensor,
-        scatter_list: List[torch.Tensor],
+        tensor: Tensor,
+        scatter_list: List[Tensor],
         *_: Any,
         **__: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         return scatter_list[0]
 
-    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
+    def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor:
         return input_list[0]
 
     def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
-    ) -> List[torch.Tensor]:
+        self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any
+    ) -> List[Tensor]:
         return input_tensor_list
 
     def send(self, *_: Any, **__: Any) -> None:
         pass
 
-    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def recv(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
     def barrier(self, *_: Any, **__: Any) -> None:
25 changes: 12 additions & 13 deletions src/lightning_lite/plugins/collectives/torch_collective.py
@@ -4,6 +4,7 @@
 
 import torch
 import torch.distributed as dist
+from torch import Tensor
 from typing_extensions import Self
 
 from lightning_lite.plugins.collectives.collective import Collective
@@ -33,49 +34,47 @@ def rank(self) -> int:
     def world_size(self) -> int:
         return dist.get_world_size(self.group)
 
-    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
         dist.broadcast(tensor, src, group=self.group)
         return tensor
 
-    def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.all_reduce(tensor, op=op, group=self.group)
         return tensor
 
-    def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
+    def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce(tensor, dst, op=op, group=self.group)
         return tensor
 
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
         dist.all_gather(tensor_list, tensor, group=self.group)
         return tensor_list
 
-    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
         dist.gather(tensor, gather_list, dst, group=self.group)
         return gather_list
 
-    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+    def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
         dist.scatter(tensor, scatter_list, src, group=self.group)
         return tensor
 
     def reduce_scatter(
-        self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
-    ) -> torch.Tensor:
+        self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
+    ) -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce_scatter(output, input_list, op=op, group=self.group)
         return output
 
-    def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
-    ) -> List[torch.Tensor]:
+    def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
         dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
         return output_tensor_list
 
-    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+    def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
         dist.send(tensor, dst, tag=tag, group=self.group)
 
-    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
         dist.recv(tensor, src, tag=tag, group=self.group)
         return tensor
 
3 changes: 2 additions & 1 deletion src/lightning_lite/plugins/precision/utils.py
@@ -14,7 +14,8 @@
 from typing import Union
 
 import torch
+from torch import Tensor
 
 
-def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
+def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor:
     return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
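
Modules that still call other `torch` APIs keep `import torch` alongside the new import, as `_convert_fp_tensor` above does for `torch.dtype` and `torch.is_floating_point`. A small sketch of that coexistence (the `halve_if_float` helper is a hypothetical example, not part of this PR):

import torch
from torch import Tensor

def halve_if_float(t: Tensor) -> Tensor:
    # annotations use the bare name; torch.* calls still need the module import
    return t.to(torch.half) if torch.is_floating_point(t) else t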
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -115,7 +115,7 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
 
-        self.n_averaged: Optional[torch.Tensor] = None
+        self.n_averaged: Optional[Tensor] = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
         self._annealing_epochs = annealing_epochs
6 changes: 3 additions & 3 deletions src/pytorch_lightning/core/module.py
@@ -403,7 +403,7 @@ def log(
                 " but it should not contain information about `dataloader_idx`"
             )
 
-        value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
+        value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
 
         if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
             # if we started a new epoch (running its first batch) the hook name has changed
@@ -545,10 +545,10 @@ def __check_not_nested(value: dict, name: str) -> None:
     def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
-    def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
+    def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
         value = (
             value.clone().detach().to(self.device)
-            if isinstance(value, torch.Tensor)
+            if isinstance(value, Tensor)
             else torch.tensor(value, device=self.device)
         )
         if not torch.numel(value) == 1:
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loggers/mlflow.py
@@ -24,9 +24,9 @@
 from time import time
 from typing import Any, Dict, Mapping, Optional, Union
 
-import torch
 import yaml
 from lightning_utilities.core.imports import module_available
+from torch import Tensor
 from typing_extensions import Literal
 
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
@@ -332,7 +332,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         for t, p, s, tag in checkpoints:
             metadata = {
                 # Ensure .item() is called to store Tensor contents
-                "score": s.item() if isinstance(s, torch.Tensor) else s,
+                "score": s.item() if isinstance(s, Tensor) else s,
                 "original_filename": Path(p).name,
                 "Checkpoint": {
                     k: getattr(checkpoint_callback, k)
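The `isinstance(s, Tensor)` check above behaves exactly as before, because the name imported from `torch` is the very same class as `torch.Tensor`. A quick sanity check, assuming a working PyTorch install:

import torch
from torch import Tensor

assert Tensor is torch.Tensor               # same class object, only the spelling changes
assert isinstance(torch.zeros(1), Tensor)   # isinstance behaviour is unaffected
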
3 changes: 2 additions & 1 deletion src/pytorch_lightning/serve/servable_module.py
@@ -1,6 +1,7 @@
 from typing import Any, Callable, Dict, Tuple
 
 import torch
+from torch import Tensor
 
 
 class ServableModule(torch.nn.Module):
@@ -70,7 +71,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callable]]:
         """
         ...
 
-    def serve_step(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]:
         r"""
         Returns the predictions of your model as a dictionary.
 
3 changes: 2 additions & 1 deletion src/pytorch_lightning/strategies/utils.py
@@ -16,6 +16,7 @@
 from inspect import getmembers, isclass
 
 import torch
+from torch import Tensor
 
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import _StrategyRegistry
@@ -40,7 +41,7 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) -> None:
             mod.register_strategies(registry)
 
 
-def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
+def _fp_to_half(tensor: Tensor, precision: PrecisionType) -> Tensor:
     if precision == PrecisionType.HALF:
         return _convert_fp_tensor(tensor, torch.half)
     if precision == PrecisionType.BFLOAT:
19 changes: 10 additions & 9 deletions src/pytorch_lightning/trainer/supporters.py
@@ -18,6 +18,7 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
+from torch import Tensor
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
@@ -59,18 +60,18 @@ def reset(self, window_length: Optional[int] = None) -> None:
         """Empty the accumulator."""
         if window_length is not None:
             self.window_length = window_length
-        self.memory: Optional[torch.Tensor] = None
+        self.memory: Optional[Tensor] = None
         self.current_idx: int = 0
         self.last_idx: Optional[int] = None
         self.rotated: bool = False
 
-    def last(self) -> Optional[torch.Tensor]:
+    def last(self) -> Optional[Tensor]:
         """Get the last added element."""
         if self.last_idx is not None:
-            assert isinstance(self.memory, torch.Tensor)
+            assert isinstance(self.memory, Tensor)
             return self.memory[self.last_idx].float()
 
-    def append(self, x: torch.Tensor) -> None:
+    def append(self, x: Tensor) -> None:
         """Add an element to the accumulator."""
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
@@ -89,21 +90,21 @@ def append(self, x: torch.Tensor) -> None:
         if self.current_idx == 0:
             self.rotated = True
 
-    def mean(self) -> Optional[torch.Tensor]:
+    def mean(self) -> Optional[Tensor]:
         """Get mean value from stored elements."""
         return self._agg_memory("mean")
 
-    def max(self) -> Optional[torch.Tensor]:
+    def max(self) -> Optional[Tensor]:
         """Get maximal value from stored elements."""
         return self._agg_memory("max")
 
-    def min(self) -> Optional[torch.Tensor]:
+    def min(self) -> Optional[Tensor]:
         """Get minimal value from stored elements."""
         return self._agg_memory("min")
 
-    def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
+    def _agg_memory(self, how: str) -> Optional[Tensor]:
         if self.last_idx is not None:
-            assert isinstance(self.memory, torch.Tensor)
+            assert isinstance(self.memory, Tensor)
             if self.rotated:
                 return getattr(self.memory.float(), how)()
             return getattr(self.memory[: self.current_idx].float(), how)()