
Commit

Merge branch 'master' into refactor/promote-cli
carmocca committed Jul 22, 2022
2 parents c16d7cf + 9f51c07 commit c8a13de
Showing 30 changed files with 125 additions and 107 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -96,6 +96,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))


- Added support for recursively setting the `Trainer` reference for ensembles of `LightningModule`s ([#13638](https://github.com/PyTorchLightning/pytorch-lightning/pull/13638))


- Added Apple Silicon Support via `MPSAccelerator` ([#13123](https://github.com/PyTorchLightning/pytorch-lightning/pull/13123))


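The #13638 entry covers ensembles: a `LightningModule` composed of other `LightningModule`s now gets the `Trainer` reference pushed down to its children automatically. A minimal sketch of that behavior, assuming the property/setter added to `core/module.py` below; `Encoder`, `Decoder`, and `Ensemble` are placeholder modules, not part of this diff:

```python
import pytorch_lightning as pl
from torch import nn


class Encoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 4)


class Decoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 8)


class Ensemble(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


ensemble = Ensemble()
trainer = pl.Trainer(max_epochs=1)

# Assigning the trainer on the parent walks `self.children()` and attaches any
# child LightningModule as well, so submodules can use `self.log`,
# `self.current_epoch`, etc. without being wired up by hand.
ensemble.trainer = trainer
assert ensemble.encoder.trainer is not None
assert ensemble.decoder.current_epoch == 0
```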
3 changes: 1 addition & 2 deletions src/pytorch_lightning/callbacks/progress/base.py
@@ -169,8 +169,7 @@ def total_val_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader
is of infinite size.
"""
assert self._trainer is not None
return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0
return sum(self.trainer.num_val_batches) if self.trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

@property
def total_batches_current_epoch(self) -> Union[int, float]:
36 changes: 26 additions & 10 deletions src/pytorch_lightning/core/module.py
@@ -80,6 +80,7 @@ class LightningModule(
"automatic_optimization",
"truncated_bptt_steps",
"use_amp",
"trainer",
]
+ DeviceDtypeModuleMixin.__jit_unused_properties__
+ HyperparametersMixin.__jit_unused_properties__
@@ -93,7 +94,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")

# pointer to the trainer object
self.trainer: Optional["pl.Trainer"] = None
self._trainer: Optional["pl.Trainer"] = None

self._use_amp: bool = False

@@ -172,6 +173,21 @@ def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRScheduler
# multiple schedulers
return lr_schedulers

@property
def trainer(self) -> "pl.Trainer":
if not self._running_torchscript and self._trainer is None:
raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
return self._trainer

@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
for v in self.children():
if isinstance(v, LightningModule):
v.trainer = trainer
if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
trainer = weakref.proxy(trainer)
self._trainer = trainer

@property
def example_input_array(self) -> Any:
"""The example input array is a specification of what the module can consume in the :meth:`forward` method.
@@ -193,25 +209,25 @@ def example_input_array(self, example: Any) -> None:
@property
def current_epoch(self) -> int:
"""The current epoch in the ``Trainer``, or 0 if not attached."""
return self.trainer.current_epoch if self.trainer else 0
return self.trainer.current_epoch if self._trainer else 0

@property
def global_step(self) -> int:
"""Total training batches seen across all epochs.
If no Trainer is attached, this property is 0.
"""
return self.trainer.global_step if self.trainer else 0
return self.trainer.global_step if self._trainer else 0

@property
def global_rank(self) -> int:
"""The index of the current process across all nodes and devices."""
return self.trainer.global_rank if self.trainer else 0
return self.trainer.global_rank if self._trainer else 0

@property
def local_rank(self) -> int:
"""The index of the current process within a single node."""
return self.trainer.local_rank if self.trainer else 0
return self.trainer.local_rank if self._trainer else 0

@property
def on_gpu(self):
@@ -249,7 +265,7 @@ def logger(self) -> Optional[Logger]:
"""Reference to the logger object in the Trainer."""
# this should match the implementation of `trainer.logger`
# we don't reuse it so we can properly set the deprecation stacklevel
if self.trainer is None:
if self._trainer is None:
return
loggers = self.trainer.loggers
if len(loggers) == 0:
@@ -271,14 +287,14 @@
@property
def loggers(self) -> List[Logger]:
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self.trainer else []
return self.trainer.loggers if self._trainer else []

def _apply_batch_transfer_handler(
self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
) -> Any:
device = device or self.device
datahook_selector = (
_DataHookSelector(self, None) if self.trainer is None else self.trainer._data_connector._datahook_selector
_DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector
)

hook = datahook_selector.get_hook("on_before_batch_transfer")
@@ -365,7 +381,7 @@ def log(
value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
)

if self.trainer is None:
if self._trainer is None:
# not an error to support testing the `*_step` methods without a `Trainer` reference
rank_zero_warn(
"You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
@@ -1964,7 +1980,7 @@ def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
def __getstate__(self) -> Dict[str, Any]:
state = dict(self.__dict__)
if self._should_prevent_trainer_and_dataloaders_deepcopy:
state["trainer"] = None
state["_trainer"] = None
state.pop("train_dataloader", None)
state.pop("val_dataloader", None)
state.pop("test_dataloader", None)
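With `trainer` promoted to a property, reading it on a detached module now raises instead of silently returning `None`, while the convenience properties (`current_epoch`, `global_step`, `logger`, ...) keep their fallbacks by checking the private `_trainer` attribute. A short usage sketch of that behavior, using a placeholder `BoringModel` that is not part of this diff:

```python
import pytorch_lightning as pl
from torch import nn


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


model = BoringModel()

# Fallbacks still work: these properties check `self._trainer` directly.
assert model.current_epoch == 0
assert model.global_step == 0

# The public property now raises when no Trainer is attached.
try:
    _ = model.trainer
except RuntimeError as err:
    print(err)  # BoringModel is not attached to a `Trainer`.

# The setter stores a weakref proxy rather than a strong reference to the Trainer.
trainer = pl.Trainer(max_epochs=1)
model.trainer = trainer
assert model.trainer is not None
```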
1 change: 0 additions & 1 deletion src/pytorch_lightning/core/optimizer.py
@@ -176,7 +176,6 @@ def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
assert model.trainer is not None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

if optim_conf is None:
4 changes: 0 additions & 4 deletions src/pytorch_lightning/loggers/__init__.py
@@ -26,7 +26,6 @@
from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE, NeptuneLogger # noqa: F401
from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F401
from pytorch_lightning.utilities.imports import _WANDB_AVAILABLE

if _COMET_AVAILABLE:
__all__.append("CometLogger")
@@ -38,6 +37,3 @@

if _NEPTUNE_AVAILABLE:
__all__.append("NeptuneLogger")

if _WANDB_AVAILABLE:
__all__.append("WandbLogger")
6 changes: 5 additions & 1 deletion src/pytorch_lightning/loggers/wandb.py
@@ -26,7 +26,7 @@
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

@@ -38,6 +38,10 @@
# needed for test mocks, these tests shall be updated
wandb, Run, RunDisabled = None, None, None # type: ignore

_WANDB_AVAILABLE = _RequirementAvailable("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _RequirementAvailable("wandb>=0.10.22")
_WANDB_GREATER_EQUAL_0_12_10 = _RequirementAvailable("wandb>=0.12.10")


class WandbLogger(Logger):
r"""
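The wandb version flags move out of `utilities/imports.py` and are built locally from `_RequirementAvailable` checks. For illustration only, a self-contained helper with the same shape as those checks — this is not Lightning's implementation, and it assumes the third-party `packaging` library is installed:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement
from packaging.version import Version


class RequirementAvailable:
    """Truthy when the requirement (e.g. "wandb>=0.12.10") is installed and satisfies the specifier."""

    def __init__(self, requirement: str) -> None:
        self.requirement = Requirement(requirement)

    def __bool__(self) -> bool:
        try:
            # Look up the installed distribution's version without importing it.
            installed = Version(version(self.requirement.name))
        except PackageNotFoundError:
            return False
        return installed in self.requirement.specifier


_WANDB_GREATER_EQUAL_0_12_10 = RequirementAvailable("wandb>=0.12.10")

if _WANDB_GREATER_EQUAL_0_12_10:
    # wandb is installed and new enough for the features gated behind 0.12.10.
    ...
```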
5 changes: 3 additions & 2 deletions src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -182,9 +182,10 @@ def teardown(self) -> None:
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()

trainer = self._trainer
if (
self.trainer is not None
and self.trainer.state._fault_tolerant_mode.is_enabled
trainer is not None
and trainer.state._fault_tolerant_mode.is_enabled
and self._data_fetcher is not None
and not self._num_completed_batches_reached() # did not finish
and self.batch_progress.current.ready # did start
9 changes: 5 additions & 4 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -269,15 +269,16 @@ def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
state_dict["_batches_that_stepped"] = self._batches_that_stepped

trainer = self._trainer
if (
self.trainer is not None
and self.trainer.state._fault_tolerant_mode.is_enabled
and self.trainer.train_dataloader is not None
trainer is not None
and trainer.state._fault_tolerant_mode.is_enabled
and trainer.train_dataloader is not None
and not self._num_completed_batches_reached() # did not finish
# TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
and self.batch_progress.current.ready # did start
):
loader: CombinedLoader = self.trainer.train_dataloader
loader: CombinedLoader = trainer.train_dataloader
state = loader.state_dict(has_completed=self._has_completed())
if state:
state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)
11 changes: 2 additions & 9 deletions src/pytorch_lightning/loops/loop.py
@@ -61,10 +61,6 @@ def trainer(self) -> "pl.Trainer":
@trainer.setter
def trainer(self, trainer: "pl.Trainer") -> None:
"""Connects this loop's trainer and its children."""
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
self._trainer = trainer
for v in self.__dict__.values():
if isinstance(v, Loop):
@@ -318,6 +314,7 @@ def load_state_dict(
self.restarting = True

def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
trainer = self._trainer
for k, v in self.__dict__.items():
key = prefix + k
if key not in state_dict:
@@ -326,11 +323,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional

if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
elif (
isinstance(v, _ResultCollection)
and self.trainer is not None
and self.trainer.lightning_module is not None
):
elif isinstance(v, _ResultCollection) and trainer is not None and trainer.lightning_module is not None:
metric_attributes = {
name: module
for name, module in self.trainer.lightning_module.named_modules()
2 changes: 0 additions & 2 deletions src/pytorch_lightning/loops/optimization/manual_loop.py
@@ -104,8 +104,6 @@ def advance(self, kwargs: OrderedDict) -> None: # type: ignore[override]
Args:
kwargs: The kwargs passed down to the hooks.
"""
assert self.trainer is not None

kwargs = self._build_kwargs(kwargs, self._hiddens)

# manually capture logged metrics
2 changes: 1 addition & 1 deletion src/pytorch_lightning/overrides/base.py
@@ -72,7 +72,7 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
pl_module = unwrap_lightning_module(self.module)
trainer = pl_module.trainer
trainer = pl_module._trainer

if trainer is not None:
if trainer.training:
1 change: 0 additions & 1 deletion src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -69,7 +69,6 @@ def backward(
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
"""
assert model.trainer is not None
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)
2 changes: 0 additions & 2 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -76,7 +76,6 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
" the backward logic internally."
)
assert model.trainer is not None
deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)

@@ -110,7 +109,6 @@ def optimizer_step(
# DeepSpeed handles the optimizer step internally
deepspeed_engine: "deepspeed.DeepSpeedEngine"
if isinstance(model, pl.LightningModule):
assert model.trainer is not None
deepspeed_engine = model.trainer.model
else:
deepspeed_engine = model
3 changes: 0 additions & 3 deletions src/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -55,7 +55,6 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
model: the model to be optimized
closure_loss: the loss value obtained from the closure
"""
assert model.trainer is not None
model.trainer._call_callback_hooks("on_before_backward", closure_loss)
model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
return closure_loss
@@ -90,7 +89,6 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
"""
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
assert model.trainer is not None
model.trainer._call_callback_hooks("on_after_backward")
model.trainer._call_lightning_module_hook("on_after_backward")
return closure_loss
@@ -110,7 +108,6 @@ def _after_closure(
# none of this applies to Lite
return
trainer = model.trainer
assert trainer is not None
trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
# TODO: this is done for the entire model but should be changed to per-optimizer
2 changes: 1 addition & 1 deletion src/pytorch_lightning/serve/servable_module_validator.py
@@ -93,7 +93,7 @@ def on_train_start(self, trainer: "pl.Trainer", servable_module: "pl.LightningMo

# Note: The Trainer needs to be detached from the pl_module before starting the process.
# This would fail during the deepcopy with DDP.
servable_module.trainer = None
servable_module.trainer = None # type: ignore[assignment]

process = Process(target=self._start_server, args=(servable_module, self.host, self.port, self.optimization))
process.start()
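The `# type: ignore[assignment]` is needed because the new `trainer` property is annotated as non-optional, but the detach itself is unchanged: the Trainer must be dropped before the module is handed to the spawned process. A condensed sketch of that pattern, with a hypothetical `_start_server` stand-in rather than the validator's real entry point:

```python
from multiprocessing import Process

import pytorch_lightning as pl


def _start_server(module: pl.LightningModule) -> None:
    # Placeholder for the real server entry point used by the validator.
    ...


def launch(servable_module: pl.LightningModule) -> None:
    # Detach the Trainer first: the child process deep-copies the module, and a live
    # Trainer (process groups, dataloaders, loggers, ...) would fail to copy, e.g. with DDP.
    servable_module.trainer = None  # type: ignore[assignment]
    process = Process(target=_start_server, args=(servable_module,))
    process.start()
    process.join()
```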
1 change: 0 additions & 1 deletion src/pytorch_lightning/strategies/bagua.py
@@ -220,7 +220,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
def teardown(self) -> None:
# abort the background communication for async algorithm
assert self.lightning_module is not None
assert self.lightning_module.trainer is not None
if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
self.model.bagua_algorithm.abort(self.model) # type: ignore

12 changes: 7 additions & 5 deletions src/pytorch_lightning/strategies/ddp.py
@@ -453,6 +453,7 @@ def reconciliate_processes(self, trace: str) -> None:
def teardown(self) -> None:
log.detail(f"{self.__class__.__name__}: tearing down strategy")

pl_module = self.lightning_module
if isinstance(self.model, DistributedDataParallel):
if (
_TORCH_GREATER_EQUAL_1_11
@@ -464,15 +465,16 @@ def teardown(self) -> None:
f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
)
# unwrap model
self.model = self.lightning_module
self.model = pl_module

if (
self.lightning_module.trainer is not None
and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
pl_module is not None
# `self.lightning_module._trainer` can be None if teardown gets called on an exception before
# the trainer gets set on the LightningModule
and pl_module._trainer is not None
and pl_module._trainer.state.fn == TrainerFn.FITTING
and self._layer_sync
):
# `self.lightning_module.trainer` can be None if teardown gets called on an exception before
# the trainer gets set on the LightningModule
self.model = self._layer_sync.revert(self.model)

super().teardown()
12 changes: 7 additions & 5 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -301,6 +301,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
def teardown(self) -> None:
log.detail(f"{self.__class__.__name__}: tearing down strategy")

pl_module = self.lightning_module
if isinstance(self.model, DistributedDataParallel):
if (
_TORCH_GREATER_EQUAL_1_11
@@ -312,14 +313,15 @@ def teardown(self) -> None:
f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
)
# unwrap model
self.model = self.lightning_module
self.model = pl_module

if (
self.lightning_module.trainer is not None
and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
pl_module is not None
# `self.lightning_module._trainer` can be None if teardown gets called on an exception before
# the trainer gets set on the LightningModule
and pl_module._trainer is not None
and pl_module._trainer.state.fn == TrainerFn.FITTING
and self._layer_sync
):
# `self.lightning_module.trainer` can be None if teardown gets called on an exception before
# the trainer gets set on the LightningModule
self.model = self._layer_sync.revert(self.model)
super().teardown()
