From 42169a23a06197b68460d2629b650ed8228fcd18 Mon Sep 17 00:00:00 2001
From: John Henning
Date: Tue, 29 Mar 2022 00:14:14 -0400
Subject: [PATCH] Add typing to `LightningModule.trainer` (#12345)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rohit Gupta
Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/core/lightning.py | 2 +-
 pytorch_lightning/core/optimizer.py | 1 +
 pytorch_lightning/overrides/base.py | 53 +++++++++----------
 pytorch_lightning/overrides/data_parallel.py | 18 ++++---
 pytorch_lightning/overrides/distributed.py | 23 ++------
 .../plugins/precision/apex_amp.py | 1 +
 .../plugins/precision/deepspeed.py | 8 ++-
 .../plugins/precision/precision_plugin.py | 2 +
 pytorch_lightning/strategies/bagua.py | 10 +++-
 pytorch_lightning/strategies/deepspeed.py | 6 ++-
 pytorch_lightning/strategies/ipu.py | 6 ++-
 pytorch_lightning/trainer/trainer.py | 6 +--
 pytorch_lightning/utilities/data.py | 4 +-
 pytorch_lightning/utilities/model_summary.py | 7 +--
 14 files changed, 71 insertions(+), 76 deletions(-)
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 1ab0dc7f3880f..b5a748295a7ea 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -95,7 +95,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")
 
         # pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None
 
         self._use_amp: bool = False
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 7aa9baf794c5f..51b156510c1b1 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -176,6 +176,7 @@ def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",
 ) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
     """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
+    assert model.trainer is not None
     optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
 
     if optim_conf is None:
diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py
index ff4f6dd7fc096..727da4737107a 100644
--- a/pytorch_lightning/overrides/base.py
+++ b/pytorch_lightning/overrides/base.py
@@ -57,13 +57,10 @@ def on_post_move_to_device(self) -> None:
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]):
-        """
-        Wraps the user's LightningModule and redirects the forward call to the appropriate
-        method, either ``training_step``, ``validation_step`` or ``test_step``.
-        If the LightningModule is in none of the states `training`, `testing` or `validation`,
-        the inputs will be redirected to the
-        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
+        """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
+        ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
+        Inheriting classes may also modify the inputs or outputs of forward.
 
         Args:
@@ -77,28 +74,26 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod
         self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        lightning_module = unwrap_lightning_module(self.module)
-        trainer = lightning_module.trainer
-
-        if trainer and trainer.training:
-            output = self.module.training_step(*inputs, **kwargs)
-
-            # In manual_optimization, we need to prevent DDP reducer as
-            # it is done manually in `LightningModule.manual_backward`
-            # `require_backward_grad_sync` will be reset in the
-            # ddp_strategy `post_training_step` hook
-            if not lightning_module.automatic_optimization:
-                trainer.model.require_backward_grad_sync = False
-        elif trainer and trainer.testing:
-            output = self.module.test_step(*inputs, **kwargs)
-        elif trainer and (trainer.sanity_checking or trainer.validating):
-            output = self.module.validation_step(*inputs, **kwargs)
-        elif trainer and trainer.predicting:
-            output = self.module.predict_step(*inputs, **kwargs)
-        else:
-            output = self.module(*inputs, **kwargs)
-
-        return output
+        pl_module = unwrap_lightning_module(self.module)
+        trainer = pl_module.trainer
+
+        if trainer is not None:
+            if trainer.training:
+                output = self.module.training_step(*inputs, **kwargs)
+                # In manual_optimization, we need to prevent DDP reducer as
+                # it is done manually in `LightningModule.manual_backward`
+                # `require_backward_grad_sync` will be reset in the
+                # ddp_strategy `post_training_step` hook
+                if not pl_module.automatic_optimization:
+                    trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]
+                return output
+            if trainer.testing:
+                return self.module.test_step(*inputs, **kwargs)
+            if trainer.sanity_checking or trainer.validating:
+                return self.module.validation_step(*inputs, **kwargs)
+            if trainer.predicting:
+                return self.module.predict_step(*inputs, **kwargs)
+        return self.module(*inputs, **kwargs)
 
     def on_post_move_to_device(self) -> None:
         pass
diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index ea44bc0683648..2d9a7d9a3acca 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, Union
+from typing import Any, cast, Union
 
 import torch
 
 import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -36,10 +36,11 @@ def _ignore_scalar_return_in_dp() -> None:
 class LightningParallelModule(_LightningModuleWrapperBase):
     """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
-    ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination with
-    :class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python
-    scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by
-    :class:`~torch.nn.parallel.DataParallel`.
+    ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
+
+    This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as shown in the example.
+    It also takes care of converting Python scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required
+    by :class:`~torch.nn.parallel.DataParallel`.
 
     Example:
@@ -53,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         pl_module: the model to wrap
     """
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         super().__init__(pl_module)
         _ignore_scalar_return_in_dp()
@@ -63,7 +64,8 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         output = super().forward(*inputs, **kwargs)
 
         def output_transform(data: Any) -> Any:
-            data = python_scalar_to_tensor(data, self.module.device)
+            device = cast(torch.device, self.module.device)
+            data = python_scalar_to_tensor(data, device)
             data = unsqueeze_scalar_tensor(data)
             return data
diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py
index 519f30a2e8ebe..68d8113fdd18c 100644
--- a/pytorch_lightning/overrides/distributed.py
+++ b/pytorch_lightning/overrides/distributed.py
@@ -12,36 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterator, List, Sized, Union
+from typing import Any, cast, Iterable, Iterator, List, Sized, Union
 
 import torch
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler
 
-import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.utilities import rank_zero_deprecation
 
 
 class LightningDistributedModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
-        """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
-        ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination
-        with :class:`~torch.nn.parallel.DistributedDataParallel` as shown in the example.
-
-        Example:
-
-            ddp_model = torch.nn.parallel.DistributedDataParallel(
-                module=LightningDistributedModule(lightning_module),
-                device_ids=[local_rank],
-                ...
-            )
-
-        Args:
-            pl_module: the model to wrap
-        """
-        super().__init__(pl_module)
+    ...
 
 
 def _find_tensors(
@@ -164,5 +147,5 @@ def batch_size(self) -> int:
         return self._sampler.batch_size
 
     @property
-    def sampler(self) -> Sampler:
+    def sampler(self) -> Union[Sampler, Iterable]:
         return self._sampler.sampler
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index c329aedcf6f00..fd29efeb9f4fb 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -69,6 +69,7 @@ def backward(
             closure_loss: the loss value obtained from the closure
             optimizer: current optimizer being used. ``None`` if using manual optimization
         """
+        assert model.trainer is not None
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
             super().backward(model, closure_loss, optimizer, *args, **kwargs)
diff --git a/pytorch_lightning/plugins/precision/deepspeed.py b/pytorch_lightning/plugins/precision/deepspeed.py
index b629c931d2fe9..3b70096dd5058 100644
--- a/pytorch_lightning/plugins/precision/deepspeed.py
+++ b/pytorch_lightning/plugins/precision/deepspeed.py
@@ -46,6 +46,7 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
                 " the backward logic internally."
             )
+        assert model.trainer is not None
         deepspeed_engine: DeepSpeedEngine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
@@ -75,7 +76,12 @@ def optimizer_step(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
+        deepspeed_engine: DeepSpeedEngine
+        if isinstance(model, pl.LightningModule):
+            assert model.trainer is not None
+            deepspeed_engine = model.trainer.model
+        else:
+            deepspeed_engine = model
         return deepspeed_engine.step(**kwargs)
 
     def clip_gradients(
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 73c50b40250ed..bdd63bba17854 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,6 +55,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
         """
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_before_backward", closure_loss)
         model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
         return closure_loss
@@ -89,6 +90,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         """
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_after_backward")
         model.trainer._call_lightning_module_hook("on_after_backward")
         return closure_loss
diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py
index 17318331b840d..61485395f0aea 100644
--- a/pytorch_lightning/strategies/bagua.py
+++ b/pytorch_lightning/strategies/bagua.py
@@ -6,7 +6,11 @@
 from torch.nn import Module
 
 import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
+from pytorch_lightning.overrides.base import (
+    _LightningModuleWrapperBase,
+    _LightningPrecisionModuleWrapperBase,
+    unwrap_lightning_module,
+)
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -32,7 +36,7 @@
 class LightningBaguaModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         super().__init__(pl_module)
         # Bagua use `bagua_module_name` to distinguish different modules
         self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
@@ -161,6 +165,7 @@ def configure_ddp(self) -> None:
         self._model = self._setup_model(model)
 
         # start the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model)  # type: ignore
@@ -188,6 +193,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def teardown(self) -> None:
         # abort the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.abort(self.model)  # type: ignore
diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index bdec69c43b2f4..f3e3951e4ffd7 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -27,7 +27,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
@@ -67,7 +67,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None:
 class LightningDeepSpeedModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule", precision: int) -> None:
+    def __init__(
+        self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: int
+    ) -> None:
         super().__init__(pl_module)
         self.precision = precision
diff --git a/pytorch_lightning/strategies/ipu.py b/pytorch_lightning/strategies/ipu.py
index cc72313a86e39..4603110c01536 100644
--- a/pytorch_lightning/strategies/ipu.py
+++ b/pytorch_lightning/strategies/ipu.py
@@ -19,7 +19,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -40,7 +40,9 @@
 class LightningIPUModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule", precision: Union[str, int]):
+    def __init__(
+        self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
+    ) -> None:
         super().__init__(pl_module)
         self.precision = precision
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9cd94f4f5ad6f..53b16af117e34 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -598,10 +598,8 @@ def __init__(
         self._terminate_on_nan = terminate_on_nan
         self.gradient_clip_val: Union[int, float] = gradient_clip_val
-        self.gradient_clip_algorithm = (
-            GradClipAlgorithmType(gradient_clip_algorithm.lower())
-            if gradient_clip_algorithm is not None
-            else gradient_clip_algorithm
+        self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
+            GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
         )
         self.track_grad_norm: float = float(track_grad_norm)
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 3f59a8f017cc7..5577ed654ea8f 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -170,7 +170,9 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
     return float("inf")
 
 
-def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
+def _update_dataloader(
+    dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
+) -> DataLoader:
     dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
     dl_cls = type(dataloader)
     try:
diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py
index af7735da4f757..6c1d11781cfba 100644
--- a/pytorch_lightning/utilities/model_summary.py
+++ b/pytorch_lightning/utilities/model_summary.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Utilities related to model weights summary."""
-
 import contextlib
 import logging
 from collections import OrderedDict
@@ -264,11 +263,7 @@ def _forward_example_input(self) -> None:
         mode = model.training
         model.eval()
 
-        if trainer is not None:
-            forward_context = trainer.precision_plugin.forward_context()
-        else:
-            forward_context = contextlib.nullcontext()
-
+        forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
         with torch.no_grad(), forward_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
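
A minimal usage sketch of the pattern this change standardizes (not taken from the patched code base; `MyModel`, `current_epoch_or_none`, and the toy dataset are made up, assuming a 1.6-era release): `LightningModule.trainer` is now annotated as `Optional["pl.Trainer"]`, so code that may run before a `Trainer` is attached has to handle `None`, and the internal call sites touched above narrow the type with `assert model.trainer is not None` before dereferencing it.

    from typing import Optional

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        # hypothetical user module, only for illustration
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    def current_epoch_or_none(model: pl.LightningModule) -> Optional[int]:
        # `model.trainer` is Optional["pl.Trainer"]: it stays None until a Trainer attaches itself
        if model.trainer is None:
            return None
        # after the None check (or an `assert model.trainer is not None`), type checkers narrow
        # the attribute to `pl.Trainer`, so the access below is safe
        return model.trainer.current_epoch


    model = MyModel()
    assert current_epoch_or_none(model) is None  # no Trainer attached yet

    dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, DataLoader(dataset, batch_size=4))
    assert current_epoch_or_none(model) is not None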
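
As the `LightningParallelModule` docstring above notes, its `output_transform` converts Python scalars to tensors and un-squeezes 0-dimensional tensors because :class:`~torch.nn.parallel.DataParallel` gathers per-replica outputs along a batch dimension. A self-contained sketch of those two conversions (the `to_tensor`/`ensure_dim` helper names are made up; the real module uses `python_scalar_to_tensor` and `unsqueeze_scalar_tensor` together with `apply_to_collection`):

    import numbers

    import torch


    def to_tensor(value, device):
        # Python scalars returned from a *_step become tensors on the replica's device
        if isinstance(value, numbers.Number):
            return torch.tensor(value, device=device)
        return value


    def ensure_dim(value):
        # 0-dimensional tensors cannot be gathered along a batch dimension, so add one
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            return value.unsqueeze(0)
        return value


    out = ensure_dim(to_tensor(0.5, device=torch.device("cpu")))
    print(out.shape)  # torch.Size([1])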
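
Relatedly, `Trainer.gradient_clip_algorithm` is now typed `Optional[GradClipAlgorithmType]`: the string argument is lower-cased and converted to the enum, and `None` stays `None` instead of the raw argument being passed through. A tiny sketch of that normalization (the `normalize` helper is made up for illustration; it mirrors the expression in `Trainer.__init__` above):

    from typing import Optional

    from pytorch_lightning.utilities.enums import GradClipAlgorithmType


    def normalize(gradient_clip_algorithm: Optional[str]) -> Optional[GradClipAlgorithmType]:
        return (
            GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
        )


    assert normalize("NORM") is GradClipAlgorithmType.NORM
    assert normalize("value") is GradClipAlgorithmType.VALUE
    assert normalize(None) is None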