diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 2d9cc4572f4d7..1a41f07b6a98a 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -189,7 +189,6 @@ precision FullyShardedNativeNativeMixedPrecisionPlugin HPUPrecisionPlugin IPUPrecisionPlugin - MixedPrecisionPlugin NativeMixedPrecisionPlugin PrecisionPlugin ShardedNativeMixedPrecisionPlugin diff --git a/docs/source-pytorch/extensions/plugins.rst b/docs/source-pytorch/extensions/plugins.rst index 27aff0c11fdcb..ac0fdfacbfe01 100644 --- a/docs/source-pytorch/extensions/plugins.rst +++ b/docs/source-pytorch/extensions/plugins.rst @@ -59,7 +59,6 @@ The full list of built-in precision plugins is listed below. FullyShardedNativeNativeMixedPrecisionPlugin HPUPrecisionPlugin IPUPrecisionPlugin - MixedPrecisionPlugin NativeMixedPrecisionPlugin PrecisionPlugin ShardedNativeMixedPrecisionPlugin diff --git a/src/lightning_lite/plugins/precision/__init__.py b/src/lightning_lite/plugins/precision/__init__.py index 95d2ccf1e9478..412ef9274822c 100644 --- a/src/lightning_lite/plugins/precision/__init__.py +++ b/src/lightning_lite/plugins/precision/__init__.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lightning_lite.plugins.precision.precision import Precision # isort:skip from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision +from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision diff --git a/src/lightning_lite/plugins/precision/double.py b/src/lightning_lite/plugins/precision/double.py index 13f5909deac9d..3de2b422f8fdd 100644 --- a/src/lightning_lite/plugins/precision/double.py +++ b/src/lightning_lite/plugins/precision/double.py @@ -16,7 +16,7 @@ import torch -from lightning_lite.plugins.precision import Precision +from lightning_lite.plugins.precision.precision import Precision class DoublePrecision(Precision): diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py index db01b1c476fd5..58cdffeb99cbc 100644 --- a/src/lightning_lite/plugins/precision/native_amp.py +++ b/src/lightning_lite/plugins/precision/native_amp.py @@ -19,7 +19,7 @@ from torch.nn import Module from torch.optim import LBFGS -from lightning_lite.plugins.precision import Precision +from lightning_lite.plugins.precision.precision import Precision from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10 from lightning_lite.utilities.types import Steppable diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py index 49397015ff389..0fd1a4c4e1c1d 100644 --- a/src/lightning_lite/plugins/precision/precision.py +++ b/src/lightning_lite/plugins/precision/precision.py @@ -34,7 +34,7 @@ def forward_context(self) -> Generator[None, None, None]: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" yield - def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None: + def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any: """Runs before 
precision plugin executes backward. Args: @@ -51,7 +51,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs """ tensor.backward(*args, **kwargs) - def post_backward(self, tensor: Tensor, module: Optional[Module]) -> None: + def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any: """Runs after precision plugin executes backward. Args: @@ -67,7 +67,7 @@ def optimizer_step( """Hook to run the optimizer step.""" return optimizer.step(**kwargs) - def get_main_params(self, optimizer: Optimizer) -> _PARAMETERS: + def main_params(self, optimizer: Optimizer) -> _PARAMETERS: """The main params of the model. Returns the plain model params here. Maybe different in other precision plugins. diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 9a32a9b366357..e82238f3bdda1 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -78,9 +78,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - In Lightning Lite, state-dict access to the module wrapper now gets passed through to the original module reference ([#14629](https://github.com/Lightning-AI/lightning/pull/14629)) + - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300)) +- Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798)) + * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor` + * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument + * The `PrecisionPlugin.optimizer_step` signature changed: The `model`, `optimizer_idx` and `closure` arguments need to be passed as keyword arguments now + + - Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631)) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index eb22bb8e13247..428ed58ea7eda 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -37,6 +37,7 @@ from lightning_lite.utilities.cloud_io import get_filesystem from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning_lite.utilities.distributed import distributed_available, sync_ddp +from lightning_lite.utilities.types import Steppable from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import HyperparametersMixin @@ -1398,7 +1399,7 @@ def training_step(...): self.trainer.strategy.backward(loss, None, None, *args, **kwargs) def backward( - self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any + self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any ) -> None: """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own implementation if you need to. 
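The signature changes listed in the CHANGELOG entry above also affect user code that subclasses the precision plugins. Below is a minimal sketch (not part of this patch) of an updated override, assuming a hypothetical `MyPrecisionPlugin` with pass-through bodies; the new argument order and the keyword-only call are taken from this patch, everything else is illustrative.

from typing import Any, Callable, Optional

import pytorch_lightning as pl
from torch import Tensor

from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin updated to the signatures introduced in this patch."""

    # previously: pre_backward(self, model, closure_loss); the tensor is now the first argument
    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
        return super().pre_backward(tensor, module)

    # previously: backward(self, model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
    def backward(
        self,
        tensor: Tensor,
        model: "pl.LightningModule",
        optimizer: Optional[Steppable],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().backward(tensor, model, optimizer, optimizer_idx, *args, **kwargs)

    # previously: optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs);
    # the optimizer now comes first and the remaining arguments must be passed as keywords
    def optimizer_step(
        self,
        optimizer: Steppable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        return super().optimizer_step(optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs)


# usage, e.g.: trainer = pl.Trainer(plugins=[MyPrecisionPlugin()])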
diff --git a/src/pytorch_lightning/plugins/precision/__init__.py b/src/pytorch_lightning/plugins/precision/__init__.py index 5206aed62c497..e74bbbabf2a82 100644 --- a/src/pytorch_lightning/plugins/precision/__init__.py +++ b/src/pytorch_lightning/plugins/precision/__init__.py @@ -18,7 +18,6 @@ from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin -from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin @@ -33,7 +32,6 @@ "FullyShardedNativeMixedPrecisionPlugin", "HPUPrecisionPlugin", "IPUPrecisionPlugin", - "MixedPrecisionPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", diff --git a/src/pytorch_lightning/plugins/precision/apex_amp.py b/src/pytorch_lightning/plugins/precision/apex_amp.py index 0416e216f6834..d5f1562551325 100644 --- a/src/pytorch_lightning/plugins/precision/apex_amp.py +++ b/src/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional from torch import Tensor -from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl -from lightning_lite.utilities.types import _PARAMETERS -from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from lightning_lite.utilities.types import _PARAMETERS, Steppable +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -27,7 +26,7 @@ from apex import amp -class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): +class ApexMixedPrecisionPlugin(PrecisionPlugin): """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" backend = AMPType.APEX @@ -55,31 +54,32 @@ def dispatch(self, trainer: "pl.Trainer") -> None: self._connected = True return super().dispatch(trainer) - def backward( + def backward( # type: ignore[override] self, + tensor: Tensor, model: "pl.LightningModule", - closure_loss: Tensor, - optimizer: Optional[Optimizer], - optimizer_idx: Optional[int], + optimizer: Optional[Steppable], *args: Any, **kwargs: Any, ) -> None: - """Run before precision plugin executes backward. + r"""Run before precision plugin executes backward. Args: + tensor: the loss value obtained from the closure model: the model to be optimized - closure_loss: the loss value obtained from the closure optimizer: current optimizer being used. ``None`` if using manual optimization - optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization + \*args: Positional arguments intended for the actual function that performs the backward, like + :meth:`~torch.Tensor.backward`. + \**kwargs: Keyword arguments for the same purpose as ``*args``. 
""" opt = optimizer or model.trainer.optimizers - with amp.scale_loss(closure_loss, opt) as closure_loss: - super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs) + with amp.scale_loss(tensor, opt) as tensor: + super().backward(tensor, model, optimizer, *args, **kwargs) - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, @@ -97,7 +97,7 @@ def optimizer_step( self._after_closure(model, optimizer, optimizer_idx) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value - if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: + if not model.automatic_optimization or not skipped_backward: return optimizer.step(**kwargs) return closure_result diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 658e66cd1b7ad..1b6cbb6ba84dd 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -16,11 +16,11 @@ from lightning_utilities.core.imports import RequirementCache from lightning_utilities.core.rank_zero import WarningCache from torch import Tensor -from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl from lightning_lite.utilities.enums import AMPType, PrecisionType +from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -73,11 +73,11 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona self.amp_type = amp_type self.amp_level = amp_level - def backward( + def backward( # type: ignore[override] self, + tensor: Tensor, model: "pl.LightningModule", - closure_loss: Tensor, - optimizer: Optional[Optimizer], + optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any, @@ -85,8 +85,8 @@ def backward( r"""Performs back-propagation using DeepSpeed's engine. Args: + tensor: the loss tensor model: the model to be optimized - closure_loss: the loss tensor optimizer: ignored for DeepSpeed optimizer_idx: ignored for DeepSpeed \*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call @@ -98,19 +98,12 @@ def backward( " the backward logic internally." 
) deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model - deepspeed_engine.backward(closure_loss, *args, **kwargs) + deepspeed_engine.backward(tensor, *args, **kwargs) - def _run_backward( - self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any - ) -> None: - if model is None: - raise ValueError("Please provide the model as input to `backward`.") - model.backward(tensor, *args, **kwargs) - - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, @@ -123,16 +116,12 @@ def optimizer_step( self._after_closure(model, optimizer, optimizer_idx) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value - if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + if model.automatic_optimization and skipped_backward: raise MisconfigurationException( "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" ) # DeepSpeed handles the optimizer step internally - deepspeed_engine: "deepspeed.DeepSpeedEngine" - if isinstance(model, pl.LightningModule): - deepspeed_engine = model.trainer.model - else: - deepspeed_engine = model + deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model return deepspeed_engine.step(**kwargs) def clip_gradients( diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py index 2b01dd010fc5f..fb3978417ecac 100644 --- a/src/pytorch_lightning/plugins/precision/ipu.py +++ b/src/pytorch_lightning/plugins/precision/ipu.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union from lightning_utilities.core.rank_zero import WarningCache -from torch.nn import Module +from torch import Tensor from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl from lightning_lite.utilities.enums import PrecisionType +from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -45,17 +46,23 @@ def __init__(self, precision: int) -> None: super().__init__() self.precision = precision - def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None: + def backward( # type: ignore[override] + self, + tensor: Tensor, + model: "pl.LightningModule", + *args: Any, + **kwargs: Any, + ) -> None: if is_overridden("backward", model): warning_cache.warn( "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle" " the backward logic internally." 
) - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, @@ -69,7 +76,7 @@ def optimizer_step( self._after_closure(model, optimizer, optimizer_idx) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value - if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + if model.automatic_optimization and skipped_backward: # we lack coverage here and IPUs are (currently) limited - something to explore if there's demand raise MisconfigurationException( "Skipping backward by returning `None` from your `training_step` is not implemented for IPUs." diff --git a/src/pytorch_lightning/plugins/precision/mixed.py b/src/pytorch_lightning/plugins/precision/mixed.py deleted file mode 100644 index 52c8b96d42882..0000000000000 --- a/src/pytorch_lightning/plugins/precision/mixed.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Union - -from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin - -if TYPE_CHECKING: - from pytorch_lightning.utilities import AMPType - - -class MixedPrecisionPlugin(PrecisionPlugin): - """Base Class for mixed precision.""" - - backend: "AMPType" - precision: Union[str, int] = "mixed" diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py index 4df1b166ca8dd..6127aaed9c7db 100644 --- a/src/pytorch_lightning/plugins/precision/native_amp.py +++ b/src/pytorch_lightning/plugins/precision/native_amp.py @@ -16,11 +16,11 @@ import torch from torch import Tensor -from torch.nn import Module -from torch.optim import LBFGS, Optimizer +from torch.optim import LBFGS import pytorch_lightning as pl -from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from lightning_lite.utilities.types import Steppable +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -30,7 +30,7 @@ from torch.cuda.amp import autocast as old_autocast -class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): +class NativeMixedPrecisionPlugin(PrecisionPlugin): """Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``. 
Args: @@ -57,27 +57,24 @@ def __init__( self.device = device self.scaler = scaler - def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - if self.scaler is not None: - closure_loss = self.scaler.scale(closure_loss) - return super().pre_backward(model, closure_loss) - - def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: + def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] if self.scaler is not None: tensor = self.scaler.scale(tensor) - super()._run_backward(tensor, model, *args, **kwargs) + return super().pre_backward(tensor, module) - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: if self.scaler is None: # skip scaler logic, as bfloat16 does not require scaler - return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs) + return super().optimizer_step( + optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs + ) if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." @@ -88,7 +85,7 @@ def optimizer_step( self._after_closure(model, optimizer, optimizer_idx) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value - if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: + if not model.automatic_optimization or not skipped_backward: # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found step_output = self.scaler.step(optimizer, **kwargs) self.scaler.update() diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index 790ab99707403..c53f091d3fa29 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch from torch import Tensor @@ -21,12 +21,13 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from lightning_lite.utilities.types import _PARAMETERS +from lightning_lite.plugins import Precision as LitePrecision +from lightning_lite.utilities.types import Steppable from pytorch_lightning.core.hooks import CheckpointHooks from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType -class PrecisionPlugin(CheckpointHooks): +class PrecisionPlugin(LitePrecision, CheckpointHooks): """Base class for all plugins handling the precision-specific parts of the training. The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. @@ -34,36 +35,22 @@ class PrecisionPlugin(CheckpointHooks): precision: Union[str, int] = 32 - def main_params(self, optimizer: Optimizer) -> _PARAMETERS: - """The main params of the model. - - Returns the plain model params here. Maybe different in other precision plugins. 
- """ - for group in optimizer.param_groups: - yield from group["params"] - def connect( self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] ) -> Tuple[Module, List[Optimizer], List[Any]]: """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers - def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - """Run before precision plugin executes backward. - - Args: - model: the model to be optimized - closure_loss: the loss value obtained from the closure - """ - model.trainer._call_callback_hooks("on_before_backward", closure_loss) - model.trainer._call_lightning_module_hook("on_before_backward", closure_loss) - return closure_loss + def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] + module.trainer._call_callback_hooks("on_before_backward", tensor) + module.trainer._call_lightning_module_hook("on_before_backward", tensor) + return tensor - def backward( + def backward( # type: ignore[override] self, + tensor: Tensor, model: "pl.LightningModule", - closure_loss: Tensor, - optimizer: Optional[Optimizer], + optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any, @@ -71,47 +58,25 @@ def backward( r"""Performs the actual backpropagation. Args: + tensor: the loss value obtained from the closure model: the model to be optimized - closure_loss: the loss value obtained from the closure optimizer: current optimizer being used. ``None`` if using manual optimization optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization \*args: Positional arguments intended for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`. \**kwargs: Keyword arguments for the same purpose as ``*args``. """ - # do backward pass - if model is not None and isinstance(model, pl.LightningModule): - model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs) - else: - self._run_backward(closure_loss, *args, **kwargs) + model.backward(tensor, optimizer, optimizer_idx, *args, **kwargs) - def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - """Run after precision plugin executes backward. - - Args: - model: the model to be optimized - closure_loss: the loss value obtained from the closure - """ + def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] # once backward has been applied, release graph - closure_loss = closure_loss.detach() - model.trainer._call_callback_hooks("on_after_backward") - model.trainer._call_lightning_module_hook("on_after_backward") + closure_loss = tensor.detach() + module.trainer._call_callback_hooks("on_after_backward") + module.trainer._call_lightning_module_hook("on_after_backward") return closure_loss - def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: - """Lightning-independent backward logic. - - Currently only used by Lightning Lite. Subject to further refactors. 
- """ - tensor.backward(*args, **kwargs) - - def _after_closure( - self, model: Optional[Union["pl.LightningModule", Module]], optimizer: Optimizer, optimizer_idx: int - ) -> None: + def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable, optimizer_idx: int) -> None: """Utility to share some code after the closure has been run.""" - if not isinstance(model, pl.LightningModule): - # none of this applies to Lite - return trainer = model.trainer trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx) trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx) @@ -143,17 +108,16 @@ def _wrap_closure( self._after_closure(model, optimizer, optimizer_idx) return closure_result - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: """Hook to run the optimizer step.""" - if isinstance(model, pl.LightningModule): - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) return optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: @@ -174,7 +138,7 @@ def _track_grad_norm(self, trainer: "pl.Trainer") -> None: def _clip_gradients( self, model: Union["pl.LightningModule", Module], - optimizer: Optimizer, + optimizer: Steppable, optimizer_idx: int, clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None, @@ -218,11 +182,6 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) - def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something when ``Strategy.dispatch()`` gets called.""" - @contextlib.contextmanager - def forward_context(self) -> Generator[None, None, None]: - """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" - yield - @contextlib.contextmanager def train_step_context(self) -> Generator[None, None, None]: """A contextmanager for the training step.""" @@ -246,26 +205,3 @@ def predict_step_context(self) -> Generator[None, None, None]: """A contextmanager for the predict step.""" with self.forward_context(): yield - - def teardown(self) -> None: - """This method is called to teardown the training process. - - It is the right place to release memory and free other resources. - """ - - def state_dict(self) -> Dict[str, Any]: - """Called when saving a checkpoint, implement to generate precision plugin state_dict. - - Returns: - A dictionary containing precision plugin state. - """ - return {} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin - state_dict. - - Args: - state_dict: the precision plugin state returned by ``state_dict``. - """ - pass diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py index b393492a168bb..3af98a7b26ce8 100644 --- a/src/pytorch_lightning/plugins/precision/tpu.py +++ b/src/pytorch_lightning/plugins/precision/tpu.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Any, Callable, Optional, Union - -from torch.nn import Module -from torch.optim import Optimizer +from typing import Any, Callable import pytorch_lightning as pl +from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,20 +27,19 @@ class TPUPrecisionPlugin(PrecisionPlugin): """Precision plugin for TPU integration.""" - def optimizer_step( + def optimizer_step( # type: ignore[override] self, - model: Optional[Union["pl.LightningModule", Module]], - optimizer: Optimizer, + optimizer: Steppable, + model: "pl.LightningModule", optimizer_idx: int, closure: Callable[[], Any], - **kwargs: Any + **kwargs: Any, ) -> Any: - if isinstance(model, pl.LightningModule): - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value - if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: + if model.automatic_optimization and skipped_backward: # we lack coverage here so disable this - something to explore if there's demand raise MisconfigurationException( "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs." diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 3a38134bb3365..8e0ca583ff682 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -202,11 +202,11 @@ def backward( """ self.pre_backward(closure_loss) assert self.lightning_module is not None - closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module) - self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs) + self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs) - closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module) self.post_backward(closure_loss) return closure_loss @@ -219,17 +219,21 @@ def optimizer_step( model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, ) -> Any: - """Performs the actual optimizer step. + r"""Performs the actual optimizer step. 
Args: optimizer: the optimizer performing the step opt_idx: index of the current optimizer closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks - **kwargs: Any extra arguments to ``optimizer.step`` + \**kwargs: Keyword arguments to ``optimizer.step`` """ model = model or self.lightning_module - return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + # TODO(lite): remove assertion once strategy's optimizer_step typing is fixed + assert isinstance(model, pl.LightningModule) + return self.precision_plugin.optimizer_step( + optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs + ) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together.
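As a usage note for the renamed hook: the Lite `get_main_params` shown earlier is now `main_params`, and the duplicate removed from `PrecisionPlugin` above is inherited from the Lite base instead. A minimal sketch of calling it, assuming the post-patch `pytorch_lightning` is installed; the toy model and optimizer are purely illustrative.

import torch

from pytorch_lightning.plugins.precision import PrecisionPlugin

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

plugin = PrecisionPlugin()
# yields the plain model parameters; other precision plugins may override this to expose different ones
params = list(plugin.main_params(optimizer))
assert len(params) == 2  # the Linear layer's weight and bias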