From a502b5913c19f13cbcc05e81f8a0271de4234868 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 8 Dec 2022 13:27:38 +0100
Subject: [PATCH] refactor: simplify Tensor import

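Replace the fully qualified `torch.Tensor` annotation with a directly
imported `Tensor` across the listed modules. Both names bind to the same
class object, so this is a pure readability refactor with no runtime
effect. A minimal sketch of why the swap is safe (illustrative only, not
part of the patch; the `scale` helper is a made-up example):

    import torch
    from torch import Tensor

    # Both spellings resolve to the exact same class object.
    assert Tensor is torch.Tensor

    def scale(t: Tensor, factor: float) -> Tensor:
        # Annotations written against the short name resolve identically.
        return t * factor

    # The evaluated annotation is still the torch.Tensor class ...
    assert scale.__annotations__["t"] is torch.Tensor
    # ... and runtime isinstance checks behave the same with either name.
    assert isinstance(torch.ones(2), Tensor)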
---
 .../plugins/collectives/collective.py         | 24 ++++++++---------
 .../plugins/collectives/single_device.py      | 26 +++++++++----------
 .../plugins/collectives/torch_collective.py   | 25 +++++++++---------
 src/lightning_lite/plugins/precision/utils.py |  3 ++-
 .../callbacks/stochastic_weight_avg.py        |  2 +-
 src/pytorch_lightning/core/module.py          |  6 ++---
 src/pytorch_lightning/loggers/mlflow.py       |  4 +--
 .../serve/servable_module.py                  |  3 ++-
 src/pytorch_lightning/strategies/utils.py     |  3 ++-
 src/pytorch_lightning/trainer/supporters.py   | 19 +++++++-------
 10 files changed, 58 insertions(+), 57 deletions(-)

diff --git a/src/lightning_lite/plugins/collectives/collective.py b/src/lightning_lite/plugins/collectives/collective.py
index f2e7f896b3547..617b13700ea0b 100644
--- a/src/lightning_lite/plugins/collectives/collective.py
+++ b/src/lightning_lite/plugins/collectives/collective.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional
 
-import torch
+from torch import Tensor
 from typing_extensions import Self
 
 from lightning_lite.utilities.types import CollectibleGroup
@@ -38,45 +38,43 @@ def group(self) -> CollectibleGroup:
         return self._group
 
     @abstractmethod
-    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
         ...
 
     @abstractmethod
-    def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
+    def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+    def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
         ...
 
     @abstractmethod
-    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
+    def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor:
         ...
 
     @abstractmethod
-    def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
-    ) -> List[torch.Tensor]:
+    def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
         ...
 
     @abstractmethod
-    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+    def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
         ...
 
     @abstractmethod
-    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
         ...
 
     @abstractmethod
diff --git a/src/lightning_lite/plugins/collectives/single_device.py b/src/lightning_lite/plugins/collectives/single_device.py
index 1bc524192cf34..e24127528c6d7 100644
--- a/src/lightning_lite/plugins/collectives/single_device.py
+++ b/src/lightning_lite/plugins/collectives/single_device.py
@@ -1,6 +1,6 @@
 from typing import Any, List
 
-import torch
+from torch import Tensor
 
 from lightning_lite.plugins.collectives.collective import Collective
 from lightning_lite.utilities.types import CollectibleGroup
@@ -15,42 +15,42 @@ def rank(self) -> int:
     def world_size(self) -> int:
         return 1
 
-    def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]:
         return [tensor]
 
-    def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]:
         return [tensor]
 
     def scatter(
         self,
-        tensor: torch.Tensor,
-        scatter_list: List[torch.Tensor],
+        tensor: Tensor,
+        scatter_list: List[Tensor],
         *_: Any,
         **__: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         return scatter_list[0]
 
-    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
+    def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor:
         return input_list[0]
 
     def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
-    ) -> List[torch.Tensor]:
+        self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any
+    ) -> List[Tensor]:
         return input_tensor_list
 
     def send(self, *_: Any, **__: Any) -> None:
         pass
 
-    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+    def recv(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
         return tensor
 
     def barrier(self, *_: Any, **__: Any) -> None:
diff --git a/src/lightning_lite/plugins/collectives/torch_collective.py b/src/lightning_lite/plugins/collectives/torch_collective.py
index fc4282a28245b..c9ade52e1bac7 100644
--- a/src/lightning_lite/plugins/collectives/torch_collective.py
+++ b/src/lightning_lite/plugins/collectives/torch_collective.py
@@ -4,6 +4,7 @@
 
 import torch
 import torch.distributed as dist
+from torch import Tensor
 from typing_extensions import Self
 
 from lightning_lite.plugins.collectives.collective import Collective
@@ -33,49 +34,47 @@ def rank(self) -> int:
     def world_size(self) -> int:
         return dist.get_world_size(self.group)
 
-    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
         dist.broadcast(tensor, src, group=self.group)
         return tensor
 
-    def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
+    def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.all_reduce(tensor, op=op, group=self.group)
         return tensor
 
-    def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
+    def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce(tensor, dst, op=op, group=self.group)
         return tensor
 
-    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
         dist.all_gather(tensor_list, tensor, group=self.group)
         return tensor_list
 
-    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+    def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
         dist.gather(tensor, gather_list, dst, group=self.group)
         return gather_list
 
-    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+    def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
         dist.scatter(tensor, scatter_list, src, group=self.group)
         return tensor
 
     def reduce_scatter(
-        self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
-    ) -> torch.Tensor:
+        self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
+    ) -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce_scatter(output, input_list, op=op, group=self.group)
         return output
 
-    def all_to_all(
-        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
-    ) -> List[torch.Tensor]:
+    def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
         dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
         return output_tensor_list
 
-    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+    def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
         dist.send(tensor, dst, tag=tag, group=self.group)
 
-    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
         dist.recv(tensor, src, tag=tag, group=self.group)
         return tensor
 
diff --git a/src/lightning_lite/plugins/precision/utils.py b/src/lightning_lite/plugins/precision/utils.py
index dc41a5202d817..5ef6d5f858ea8 100644
--- a/src/lightning_lite/plugins/precision/utils.py
+++ b/src/lightning_lite/plugins/precision/utils.py
@@ -14,7 +14,8 @@
 from typing import Union
 
 import torch
+from torch import Tensor
 
 
-def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
+def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor:
     return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index ccf5051d5fd39..53111868f3224 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -115,7 +115,7 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
 
-        self.n_averaged: Optional[torch.Tensor] = None
+        self.n_averaged: Optional[Tensor] = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
         self._annealing_epochs = annealing_epochs
Found {device}") - self.n_averaged: Optional[torch.Tensor] = None + self.n_averaged: Optional[Tensor] = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 9515abe8cfe6e..4f5ce15ea3d74 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -403,7 +403,7 @@ def log( " but it should not contain information about `dataloader_idx`" ) - value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name) + value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name) if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name): # if we started a new epoch (running its first batch) the hook name has changed @@ -545,10 +545,10 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor: + def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor: value = ( value.clone().detach().to(self.device) - if isinstance(value, torch.Tensor) + if isinstance(value, Tensor) else torch.tensor(value, device=self.device) ) if not torch.numel(value) == 1: diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 35f9b8396dd0f..a9a00130d88b0 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -24,9 +24,9 @@ from time import time from typing import Any, Dict, Mapping, Optional, Union -import torch import yaml from lightning_utilities.core.imports import module_available +from torch import Tensor from typing_extensions import Literal from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -332,7 +332,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non for t, p, s, tag in checkpoints: metadata = { # Ensure .item() is called to store Tensor contents - "score": s.item() if isinstance(s, torch.Tensor) else s, + "score": s.item() if isinstance(s, Tensor) else s, "original_filename": Path(p).name, "Checkpoint": { k: getattr(checkpoint_callback, k) diff --git a/src/pytorch_lightning/serve/servable_module.py b/src/pytorch_lightning/serve/servable_module.py index ef95187c63245..1ceb42777eb1d 100644 --- a/src/pytorch_lightning/serve/servable_module.py +++ b/src/pytorch_lightning/serve/servable_module.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Tuple import torch +from torch import Tensor class ServableModule(torch.nn.Module): @@ -70,7 +71,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab """ ... - def serve_step(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> Dict[str, torch.Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: r""" Returns the predictions of your model as a dictionary. 
diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py
index be602a3929665..a7e2bcb6468dd 100644
--- a/src/pytorch_lightning/strategies/utils.py
+++ b/src/pytorch_lightning/strategies/utils.py
@@ -16,6 +16,7 @@
 from inspect import getmembers, isclass
 
 import torch
+from torch import Tensor
 
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import _StrategyRegistry
@@ -40,7 +41,7 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
             mod.register_strategies(registry)
 
 
-def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
+def _fp_to_half(tensor: Tensor, precision: PrecisionType) -> Tensor:
     if precision == PrecisionType.HALF:
         return _convert_fp_tensor(tensor, torch.half)
     if precision == PrecisionType.BFLOAT:
diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index 59856f12ef304..2a548726dfd63 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -18,6 +18,7 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
+from torch import Tensor
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
@@ -59,18 +60,18 @@ def reset(self, window_length: Optional[int] = None) -> None:
         """Empty the accumulator."""
         if window_length is not None:
             self.window_length = window_length
-        self.memory: Optional[torch.Tensor] = None
+        self.memory: Optional[Tensor] = None
         self.current_idx: int = 0
         self.last_idx: Optional[int] = None
         self.rotated: bool = False
 
-    def last(self) -> Optional[torch.Tensor]:
+    def last(self) -> Optional[Tensor]:
         """Get the last added element."""
         if self.last_idx is not None:
-            assert isinstance(self.memory, torch.Tensor)
+            assert isinstance(self.memory, Tensor)
             return self.memory[self.last_idx].float()
 
-    def append(self, x: torch.Tensor) -> None:
+    def append(self, x: Tensor) -> None:
         """Add an element to the accumulator."""
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
@@ -89,21 +90,21 @@ def append(self, x: torch.Tensor) -> None:
         if self.current_idx == 0:
             self.rotated = True
 
-    def mean(self) -> Optional[torch.Tensor]:
+    def mean(self) -> Optional[Tensor]:
         """Get mean value from stored elements."""
         return self._agg_memory("mean")
 
-    def max(self) -> Optional[torch.Tensor]:
+    def max(self) -> Optional[Tensor]:
         """Get maximal value from stored elements."""
         return self._agg_memory("max")
 
-    def min(self) -> Optional[torch.Tensor]:
+    def min(self) -> Optional[Tensor]:
         """Get minimal value from stored elements."""
         return self._agg_memory("min")
 
-    def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
+    def _agg_memory(self, how: str) -> Optional[Tensor]:
         if self.last_idx is not None:
-            assert isinstance(self.memory, torch.Tensor)
+            assert isinstance(self.memory, Tensor)
             if self.rotated:
                 return getattr(self.memory.float(), how)()
             return getattr(self.memory[: self.current_idx].float(), how)()