Add typing to LightningModule.trainer (#12345)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Mar 29, 2022
1 parent 2de6a9b commit 42169a2
Showing 14 changed files with 71 additions and 76 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -95,7 +95,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")

# pointer to the trainer object
self.trainer = None
self.trainer: Optional["pl.Trainer"] = None

self._use_amp: bool = False

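
The annotation above is what drives most of the other hunks in this commit: once `trainer` is declared `Optional["pl.Trainer"]`, mypy refuses attribute access until the value has been narrowed to a non-None type, which is why `assert model.trainer is not None` appears before each use below. A minimal, self-contained sketch of that narrowing pattern; the `Engine`/`Owner` names are illustrative, not Lightning classes:

from typing import Optional


class Engine:
    def run(self) -> str:
        return "running"


class Owner:
    def __init__(self) -> None:
        # starts unattached, like LightningModule.trainer
        self.engine: Optional[Engine] = None

    def do_work(self) -> str:
        # without this assert, mypy reports: Item "None" of "Optional[Engine]"
        # has no attribute "run"
        assert self.engine is not None
        return self.engine.run()


owner = Owner()
owner.engine = Engine()
print(owner.do_work())  # -> running
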
1 change: 1 addition & 0 deletions pytorch_lightning/core/optimizer.py
@@ -176,6 +176,7 @@ def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
assert model.trainer is not None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

if optim_conf is None:
53 changes: 24 additions & 29 deletions pytorch_lightning/overrides/base.py
@@ -57,13 +57,10 @@ def on_post_move_to_device(self) -> None:


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ``test_step``.
If the LightningModule is in none of the states `training`, `testing` or `validation`,
the inputs will be redirected to the
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
Inheriting classes may also modify the inputs or outputs of forward.
Args:
@@ -77,28 +74,26 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod
self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
lightning_module = unwrap_lightning_module(self.module)
trainer = lightning_module.trainer

if trainer and trainer.training:
output = self.module.training_step(*inputs, **kwargs)

# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in `LightningModule.manual_backward`
# `require_backward_grad_sync` will be reset in the
# ddp_strategy `post_training_step` hook
if not lightning_module.automatic_optimization:
trainer.model.require_backward_grad_sync = False
elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
elif trainer and trainer.predicting:
output = self.module.predict_step(*inputs, **kwargs)
else:
output = self.module(*inputs, **kwargs)

return output
pl_module = unwrap_lightning_module(self.module)
trainer = pl_module.trainer

if trainer is not None:
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in `LightningModule.manual_backward`
# `require_backward_grad_sync` will be reset in the
# ddp_strategy `post_training_step` hook
if not pl_module.automatic_optimization:
trainer.model.require_backward_grad_sync = False # type: ignore[assignment]
return output
if trainer.testing:
return self.module.test_step(*inputs, **kwargs)
if trainer.sanity_checking or trainer.validating:
return self.module.validation_step(*inputs, **kwargs)
if trainer.predicting:
return self.module.predict_step(*inputs, **kwargs)
return self.module(*inputs, **kwargs)

def on_post_move_to_device(self) -> None:
pass
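
The rewritten `forward` above keeps the same routing behaviour but replaces the `elif` chain with early returns under a single `if trainer is not None:` block, so mypy sees `trainer` as non-None inside it. A stripped-down sketch of that dispatch shape, with stand-in classes rather than the real Lightning wrapper:

from typing import Any, Optional


class StageFlags:
    # stand-in for the handful of trainer state flags queried by the wrapper
    def __init__(self, stage: Optional[str]) -> None:
        self.training = stage == "train"
        self.testing = stage == "test"
        self.validating = stage == "validate"
        self.sanity_checking = stage == "sanity"
        self.predicting = stage == "predict"


class StepDispatcher:
    def __init__(self, module: Any, trainer: Optional[StageFlags]) -> None:
        self.module = module
        self.trainer = trainer

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        trainer = self.trainer
        if trainer is not None:
            if trainer.training:
                return self.module.training_step(*inputs, **kwargs)
            if trainer.testing:
                return self.module.test_step(*inputs, **kwargs)
            if trainer.sanity_checking or trainer.validating:
                return self.module.validation_step(*inputs, **kwargs)
            if trainer.predicting:
                return self.module.predict_step(*inputs, **kwargs)
        # no trainer attached or no active stage: plain call into the module
        return self.module(*inputs, **kwargs)
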
18 changes: 10 additions & 8 deletions pytorch_lightning/overrides/data_parallel.py
@@ -13,12 +13,12 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, Union
from typing import Any, cast, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

@@ -36,10 +36,11 @@ def _ignore_scalar_return_in_dp() -> None:

class LightningParallelModule(_LightningModuleWrapperBase):
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination with
:class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python
scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by
:class:`~torch.nn.parallel.DataParallel`.
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as shown in the example.
It also takes care of converting Python scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required
by :class:`~torch.nn.parallel.DataParallel`.
Example:
@@ -53,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
pl_module: the model to wrap
"""

def __init__(self, pl_module: "pl.LightningModule") -> None:
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
_ignore_scalar_return_in_dp()

@@ -63,7 +64,8 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
output = super().forward(*inputs, **kwargs)

def output_transform(data: Any) -> Any:
data = python_scalar_to_tensor(data, self.module.device)
device = cast(torch.device, self.module.device)
data = python_scalar_to_tensor(data, device)
data = unsqueeze_scalar_tensor(data)
return data

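
The `cast(torch.device, self.module.device)` above only narrows the type for mypy; the behaviour lives in the two helpers the transform chains, which exist because `torch.nn.parallel.DataParallel` can only gather tensors that have at least one dimension. A rough sketch of what helpers of that shape do (the real implementations live in `pytorch_lightning/overrides/data_parallel.py` and may differ in detail):

import numbers
from typing import Any

import torch


def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
    # turn plain Python numbers into 1-element tensors so DataParallel can gather them
    if isinstance(data, numbers.Number):
        return torch.tensor([data], device=device)
    return data


def unsqueeze_scalar_tensor(data: Any) -> Any:
    # 0-dim tensors cannot be concatenated across replicas; give them a batch dim
    if isinstance(data, torch.Tensor) and data.dim() == 0:
        return data.unsqueeze(0)
    return data


print(python_scalar_to_tensor(3.5))                       # tensor([3.5000])
print(unsqueeze_scalar_tensor(torch.tensor(1.0)).shape)   # torch.Size([1])
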
23 changes: 3 additions & 20 deletions pytorch_lightning/overrides/distributed.py
@@ -12,36 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, cast, Iterator, List, Sized, Union
from typing import Any, cast, Iterable, Iterator, List, Sized, Union

import torch
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_deprecation


class LightningDistributedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule") -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination
with :class:`~torch.nn.parallel.DistributedDataParallel` as shown in the example.
Example:
ddp_model = torch.nn.parallel.DistributedDataParallel(
module=LightningDistributedModule(lightning_module),
device_ids=[local_rank],
...
)
Args:
pl_module: the model to wrap
"""
super().__init__(pl_module)
...


def _find_tensors(
@@ -164,5 +147,5 @@ def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Sampler:
def sampler(self) -> Union[Sampler, Iterable]:
return self._sampler.sampler
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -69,6 +69,7 @@ def backward(
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
"""
assert model.trainer is not None
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/precision/deepspeed.py
@@ -46,6 +46,7 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
" the backward logic internally."
)
assert model.trainer is not None
deepspeed_engine: DeepSpeedEngine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)

@@ -75,7 +76,12 @@ def optimizer_step(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
deepspeed_engine: DeepSpeedEngine
if isinstance(model, pl.LightningModule):
assert model.trainer is not None
deepspeed_engine = model.trainer.model
else:
deepspeed_engine = model
return deepspeed_engine.step(**kwargs)

def clip_gradients(
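
The `optimizer_step` hunk above shows the other recurring pattern of this commit: declare the local (`deepspeed_engine: DeepSpeedEngine`) once, then assign it in every branch, asserting `model.trainer is not None` before reading `trainer.model`. A toy version of that annotate-then-branch shape, using made-up `Engine`/`Holder` classes rather than DeepSpeed or Lightning types:

from typing import Union


class Engine:
    def step(self) -> str:
        return "stepped"


class Holder:
    # owns an engine, loosely analogous to the trainer owning the wrapped model
    def __init__(self, engine: Engine) -> None:
        self.engine = engine


def resolve_engine(model: Union[Holder, Engine]) -> Engine:
    engine: Engine  # annotated up front so both branches must assign the same type
    if isinstance(model, Holder):
        engine = model.engine
    else:
        engine = model
    return engine


print(resolve_engine(Holder(Engine())).step())  # -> stepped
print(resolve_engine(Engine()).step())          # -> stepped
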
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,6 +55,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
model: the model to be optimized
closure_loss: the loss value obtained from the closure
"""
assert model.trainer is not None
model.trainer._call_callback_hooks("on_before_backward", closure_loss)
model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
return closure_loss
@@ -89,6 +90,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
"""
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
assert model.trainer is not None
model.trainer._call_callback_hooks("on_after_backward")
model.trainer._call_lightning_module_hook("on_after_backward")
return closure_loss
10 changes: 8 additions & 2 deletions pytorch_lightning/strategies/bagua.py
@@ -6,7 +6,11 @@
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -32,7 +36,7 @@


class LightningBaguaModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule") -> None:
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
# Bagua use `bagua_module_name` to distinguish different modules
self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
@@ -161,6 +165,7 @@ def configure_ddp(self) -> None:
self._model = self._setup_model(model)

# start the background communication for async algorithm
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.resume(self.model) # type: ignore

@@ -188,6 +193,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:

def teardown(self) -> None:
# abort the background communication for async algorithm
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.abort(self.model) # type: ignore

6 changes: 4 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -27,7 +27,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
@@ -67,7 +67,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None:


class LightningDeepSpeedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule", precision: int) -> None:
def __init__(
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: int
) -> None:
super().__init__(pl_module)
self.precision = precision

6 changes: 4 additions & 2 deletions pytorch_lightning/strategies/ipu.py
@@ -19,7 +19,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -40,7 +40,9 @@


class LightningIPUModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule", precision: Union[str, int]):
def __init__(
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
) -> None:
super().__init__(pl_module)
self.precision = precision

6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -598,10 +598,8 @@ def __init__(

self._terminate_on_nan = terminate_on_nan
self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_algorithm = (
GradClipAlgorithmType(gradient_clip_algorithm.lower())
if gradient_clip_algorithm is not None
else gradient_clip_algorithm
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
)
self.track_grad_norm: float = float(track_grad_norm)

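
The trainer hunk above normalizes the user-supplied string into a `GradClipAlgorithmType` member and annotates the attribute as `Optional[...]` so that `None` remains a valid "no clipping algorithm set" value. A small sketch of the same string-to-enum normalization, using a stand-in `ClipAlgo` enum rather than the real Lightning type:

from enum import Enum
from typing import Optional


class ClipAlgo(Enum):
    NORM = "norm"
    VALUE = "value"


def normalize_clip_algorithm(algorithm: Optional[str]) -> Optional[ClipAlgo]:
    # lower-case the user string and map it onto the enum; keep None as "unset"
    return ClipAlgo(algorithm.lower()) if algorithm is not None else None


print(normalize_clip_algorithm("NORM"))  # ClipAlgo.NORM
print(normalize_clip_algorithm(None))    # None
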
4 changes: 3 additions & 1 deletion pytorch_lightning/utilities/data.py
@@ -170,7 +170,9 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
return float("inf")


def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
try:
7 changes: 1 addition & 6 deletions pytorch_lightning/utilities/model_summary.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to model weights summary."""

import contextlib
import logging
from collections import OrderedDict
@@ -264,11 +263,7 @@ def _forward_example_input(self) -> None:
mode = model.training
model.eval()

if trainer is not None:
forward_context = trainer.precision_plugin.forward_context()
else:
forward_context = contextlib.nullcontext()

forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
with torch.no_grad(), forward_context:
# let the model hooks collect the input- and output shapes
if isinstance(input_, (list, tuple)):
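
The model-summary hunk above folds the `if trainer is not None` branch into one conditional expression, falling back to `contextlib.nullcontext()` when no trainer (and hence no precision plugin) is attached. A self-contained sketch of that fallback pattern; the `fake_precision_context` manager below is purely illustrative:

import contextlib
from typing import Iterator


@contextlib.contextmanager
def fake_precision_context() -> Iterator[None]:
    print("entering precision context")
    yield
    print("leaving precision context")


def run_forward(trainer_attached: bool) -> None:
    # use the plugin's context when available, otherwise a do-nothing context
    ctx = fake_precision_context() if trainer_attached else contextlib.nullcontext()
    with ctx:
        print("forward pass")


run_forward(True)   # wrapped in the fake precision context
run_forward(False)  # plain forward pass
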
