Integrate Lite Precision into PL (#14798)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Sep 22, 2022
1 parent 6df6dea · commit dd2a1c5
Showing 17 changed files with 112 additions and 204 deletions.
1 change: 0 additions & 1 deletion docs/source-pytorch/api_references.rst
@@ -189,7 +189,6 @@ precision
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
PrecisionPlugin
ShardedNativeMixedPrecisionPlugin
1 change: 0 additions & 1 deletion docs/source-pytorch/extensions/plugins.rst
@@ -59,7 +59,6 @@ The full list of built-in precision plugins is listed below.
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
PrecisionPlugin
ShardedNativeMixedPrecisionPlugin
2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/__init__.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.plugins.precision.precision import Precision # isort:skip
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning_lite.plugins.precision.double import DoublePrecision
from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.tpu import TPUPrecision
from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision

2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/double.py
@@ -16,7 +16,7 @@

import torch

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.precision import Precision


class DoublePrecision(Precision):
2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -19,7 +19,7 @@
from torch.nn import Module
from torch.optim import LBFGS

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from lightning_lite.utilities.types import Steppable

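The import edits above (in `__init__.py`, `double.py`, and `native_amp.py`) switch to importing `Precision` from the module that defines it rather than from the package `__init__`. The likely motivation, stated here as an assumption rather than something the diff spells out, is avoiding a circular import while the package `__init__` is still executing. A minimal sketch of the resulting pattern:

```python
# Sketch only, assuming a lightning_lite release from this era is installed.
# Import the base class from its defining module, not the package __init__,
# so this file can be imported before lightning_lite.plugins.precision has
# finished initializing.
from lightning_lite.plugins.precision.precision import Precision


class HalfPrecision(Precision):
    """Hypothetical subclass used only to illustrate the import pattern."""

    precision = 16
```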
6 changes: 3 additions & 3 deletions src/lightning_lite/plugins/precision/precision.py
@@ -34,7 +34,7 @@ def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield

def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
"""Runs before precision plugin executes backward.
Args:
@@ -51,7 +51,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs
"""
tensor.backward(*args, **kwargs)

def post_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
"""Runs after precision plugin executes backward.
Args:
@@ -67,7 +67,7 @@ def optimizer_step(
"""Hook to run the optimizer step."""
return optimizer.step(**kwargs)

def get_main_params(self, optimizer: Optimizer) -> _PARAMETERS:
def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The main params of the model.
Returns the plain model params here. Maybe different in other precision plugins.
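The hunk above widens `pre_backward`/`post_backward` to return `Any` and renames `get_main_params` to `main_params`. A rough sketch of a Lite precision plugin written against those updated signatures (the class name and the print statements are invented for illustration):

```python
from typing import Any, Optional

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.utilities.types import _PARAMETERS


class VerbosePrecision(Precision):
    """Hypothetical plugin that traces the backward hooks of the updated base class."""

    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        print(f"backward starting for a tensor of shape {tuple(tensor.shape)}")
        return tensor  # returning a value is permitted now that the hook returns Any

    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        print("backward finished")
        return tensor

    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
        # renamed from `get_main_params` in this commit
        return super().main_params(optimizer)
```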
7 changes: 7 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -78,9 +78,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- In Lightning Lite, state-dict access to the module wrapper now gets passed through to the original module reference ([#14629](https://github.com/Lightning-AI/lightning/pull/14629))


- Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))


- Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
* The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
  * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved to the first position
* The `PrecisionPlugin.optimizer_step` signature changed: The `model`, `optimizer_idx` and `closure` arguments need to be passed as keyword arguments now


- Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631))


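The CHANGELOG entry above is the behavioural summary of this commit. A hedged sketch of a custom PL precision plugin written against the new hook order, with `model`, `optimizer_idx` and `closure` passed as keyword arguments as the entry requires (the class name and print call are invented; the base-class behaviour is inferred from the diffs below):

```python
from typing import Any, Callable

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class TracingPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin showing the post-#14798 `optimizer_step` signature."""

    def optimizer_step(
        self,
        optimizer: Steppable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        print(f"stepping optimizer {optimizer_idx}")
        return super().optimizer_step(
            optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
        )
```

Such a plugin would typically be handed to the Trainer through its `plugins` argument, e.g. `Trainer(plugins=[TracingPrecisionPlugin()])`.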
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/module.py
@@ -37,6 +37,7 @@
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_lite.utilities.distributed import distributed_available, sync_ddp
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
@@ -1398,7 +1399,7 @@ def training_step(...):
self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
) -> None:
"""Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
own implementation if you need to.
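The `core/module.py` hunk above only retypes the `optimizer` argument of `LightningModule.backward` from `Optimizer` to the more general `Steppable` type. For context, a minimal sketch of overriding that hook with the retyped signature (the loss scaling is invented purely for illustration):

```python
from typing import Any, Optional

from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable


class ScaledBackwardModule(pl.LightningModule):
    """Hypothetical module overriding `backward` with the retyped signature."""

    def backward(
        self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
    ) -> None:
        # Halve the loss before backpropagating; purely illustrative.
        (loss * 0.5).backward(*args, **kwargs)
```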
2 changes: 0 additions & 2 deletions src/pytorch_lightning/plugins/precision/__init__.py
@@ -18,7 +18,6 @@
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -33,7 +32,6 @@
"FullyShardedNativeMixedPrecisionPlugin",
"HPUPrecisionPlugin",
"IPUPrecisionPlugin",
"MixedPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
36 changes: 18 additions & 18 deletions src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,23 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional

from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PARAMETERS
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from lightning_lite.utilities.types import _PARAMETERS, Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _APEX_AVAILABLE:
from apex import amp


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
class ApexMixedPrecisionPlugin(PrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

backend = AMPType.APEX
@@ -55,31 +54,32 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
self._connected = True
return super().dispatch(trainer)

def backward(
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
optimizer: Optional[Steppable],
*args: Any,
**kwargs: Any,
) -> None:
"""Run before precision plugin executes backward.
r"""Run before precision plugin executes backward.
Args:
tensor: the loss value obtained from the closure
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
\*args: Positional arguments intended for the actual function that performs the backward, like
:meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
"""
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
with amp.scale_loss(tensor, opt) as tensor:
super().backward(tensor, model, optimizer, *args, **kwargs)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -97,7 +97,7 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
if not model.automatic_optimization or not skipped_backward:
return optimizer.step(**kwargs)
return closure_result

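With the intermediate `MixedPrecisionPlugin` gone (its file is deleted further down), `ApexMixedPrecisionPlugin` now derives from `PrecisionPlugin` directly. For orientation, a hedged sketch of how this plugin was typically selected around the 1.7/1.8 releases, assuming NVIDIA Apex is installed (the Trainer arguments shown are assumptions of that era, not taken from this diff):

```python
import pytorch_lightning as pl

# Usage sketch: precision=16 combined with the "apex" AMP backend routes
# mixed-precision handling through ApexMixedPrecisionPlugin.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,
    amp_backend="apex",  # assumed 1.7/1.8-era argument selecting Apex
    amp_level="O2",      # Apex optimization level
)
```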
33 changes: 11 additions & 22 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -16,11 +16,11 @@
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -73,20 +73,20 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
self.amp_type = amp_type
self.amp_level = amp_level

def backward(
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer: Optional[Steppable],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
r"""Performs back-propagation using DeepSpeed's engine.
Args:
tensor: the loss tensor
model: the model to be optimized
closure_loss: the loss tensor
optimizer: ignored for DeepSpeed
optimizer_idx: ignored for DeepSpeed
\*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
@@ -98,19 +98,12 @@ def backward(
" the backward logic internally."
)
deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
deepspeed_engine.backward(tensor, *args, **kwargs)

def _run_backward(
self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any
) -> None:
if model is None:
raise ValueError("Please provide the model as input to `backward`.")
model.backward(tensor, *args, **kwargs)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -123,16 +116,12 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
if model.automatic_optimization and skipped_backward:
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine: "deepspeed.DeepSpeedEngine"
if isinstance(model, pl.LightningModule):
deepspeed_engine = model.trainer.model
else:
deepspeed_engine = model
deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
return deepspeed_engine.step(**kwargs)

def clip_gradients(
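The DeepSpeed plugin now assumes it always receives a `pl.LightningModule` and resolves the engine through `model.trainer.model`, dropping the standalone `_run_backward` helper and the `isinstance` branches. A usage sketch, assuming the `deepspeed` package and a multi-GPU machine are available:

```python
import pytorch_lightning as pl

# Sketch: under the DeepSpeed strategy the precision plugin forwards the loss
# tensor to deepspeed_engine.backward() and lets the engine drive the optimizer
# step, so a user-overridden LightningModule.backward only triggers the warning
# shown in the diff above.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="deepspeed_stage_2",  # assumed registered strategy name of this era
    precision=16,
)
```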
21 changes: 14 additions & 7 deletions src/pytorch_lightning/plugins/precision/ipu.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

from lightning_utilities.core.rank_zero import WarningCache
from torch.nn import Module
from torch import Tensor
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -45,17 +46,23 @@ def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
*args: Any,
**kwargs: Any,
) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
" the backward logic internally."
)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -69,7 +76,7 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
if model.automatic_optimization and skipped_backward:
# we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."
26 changes: 0 additions & 26 deletions src/pytorch_lightning/plugins/precision/mixed.py

This file was deleted.
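`mixed.py` held the `MixedPrecisionPlugin` base class that `ApexMixedPrecisionPlugin` (among others) used to inherit from; after this commit those plugins extend `PrecisionPlugin` directly. A hedged migration sketch for downstream code that imported the deleted class:

```python
# Before this commit (import no longer available):
#   from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
# After: build on the consolidated base class instead.
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class MyMixedPrecisionPlugin(PrecisionPlugin):
    """Hypothetical downstream plugin migrated to the consolidated base."""

    precision = 16
```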

