From 5faae27ce744f2ebde92c637563280b70435ff10 Mon Sep 17 00:00:00 2001
From: Bryn Lloyd
Date: Thu, 3 Nov 2022 13:42:10 +0100
Subject: [PATCH] revert removing _load_from_checkpoint to preserve doc

---
 src/pytorch_lightning/core/datamodule.py |  7 +++---
 src/pytorch_lightning/core/saving.py     | 18 +++++++-------
 .../test_lightning_module.py             | 24 -------------------
 3 files changed, 12 insertions(+), 37 deletions(-)
 delete mode 100644 tests/tests_type_checking/test_lightning_module.py

diff --git a/src/pytorch_lightning/core/datamodule.py b/src/pytorch_lightning/core/datamodule.py
index 5cd1713e2ffba..21ee283d31d7d 100644
--- a/src/pytorch_lightning/core/datamodule.py
+++ b/src/pytorch_lightning/core/datamodule.py
@@ -14,9 +14,10 @@
 """LightningDataModule for loading DataLoaders with ease."""
 import inspect
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
+from typing_extensions import Self
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.types import _PATH
@@ -214,11 +215,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
     @classmethod
     def load_from_checkpoint(
-        cls,
+        cls: Type[Self],  # type: ignore [valid-type]
         checkpoint_path: Union[_PATH, IO],
         hparams_file: Optional[_PATH] = None,
         **kwargs: Any,
-    ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
+    ) -> Self:  # type: ignore [valid-type]
         r"""
         Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index 4070f154c4d6d..2bd6312eff546 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -21,11 +21,12 @@
 from copy import deepcopy
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, TypeVar, Union
+from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
 from warnings import warn
 
 import yaml
 from lightning_utilities.core.apply_func import apply_to_collection
+from typing_extensions import Self
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import _load as pl_load
@@ -49,9 +50,6 @@
 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6
 
-LM = TypeVar("LM", bound="pl.LightningModule")
-LDM = TypeVar("LDM", bound="pl.LightningDataModule")
-
 
 class ModelIO:
     CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
@@ -60,13 +58,13 @@ class ModelIO:
 
     @classmethod
     def load_from_checkpoint(
-        cls: Union[Type["ModelIO"], Type[LM], Type[LDM]],
+        cls: Type[Self],  # type: ignore [valid-type]
         checkpoint_path: Union[str, IO],
         map_location: _MAP_LOCATION_TYPE = None,
         hparams_file: Optional[str] = None,
         strict: bool = True,
         **kwargs: Any,
-    ) -> Union[LM, LDM]:
+    ) -> Self:  # type: ignore [valid-type]
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
         passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -149,13 +147,13 @@ def load_from_checkpoint(
 
 
 def _load_from_checkpoint(
-    cls: Union[Type["ModelIO"], Type[LM], Type[LDM]],
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint_path: Union[_PATH, IO],
     map_location: _MAP_LOCATION_TYPE = None,
     hparams_file: Optional[_PATH] = None,
     strict: Optional[bool] = None,
     **kwargs: Any,
-) -> Union[LM, LDM]:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     if map_location is None:
         map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
     with pl_legacy_patch():
@@ -185,9 +183,9 @@ def _load_from_checkpoint(
         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
 
     if issubclass(cls, pl.LightningDataModule):
-        return cast(LDM, _load_state(cls, checkpoint, **kwargs))
+        return _load_state(cls, checkpoint, **kwargs)
     if issubclass(cls, pl.LightningModule):
-        return cast(LM, _load_state(cls, checkpoint, strict=strict, **kwargs))
+        return _load_state(cls, checkpoint, strict=strict, **kwargs)
 
     raise NotImplementedError(f"Unsupported {cls}")
 
diff --git a/tests/tests_type_checking/test_lightning_module.py b/tests/tests_type_checking/test_lightning_module.py
deleted file mode 100644
index 1591e64889db2..0000000000000
--- a/tests/tests_type_checking/test_lightning_module.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from pathlib import Path
-
-from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
-
-
-def test_load_from_checkpoint_type(tmp_path: Path) -> None:
-    class MyModule(BoringModel):
-        def __init__(self, some_parameter: int):
-            super().__init__()
-            self.save_hyperparameters()
-
-        @property
-        def parameter(self) -> int:
-            return self.hparams.some_parameter
-
-    net = MyModule(some_parameter=42)
-    trainer = Trainer(default_root_dir=str(tmp_path), fast_dev_run=True)
-    trainer.fit(net)
-    checkpoint_path = str(tmp_path / "model.pt")
-    trainer.save_checkpoint(checkpoint_path)
-
-    net_loaded = MyModule.load_from_checkpoint(checkpoint_path)  # type: ignore
-    assert net_loaded.parameter == 42
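
To illustrate what the ``cls: Type[Self]`` / ``-> Self`` annotations above are intended to give callers, here is a minimal sketch adapted from the deleted test. It is not part of the patch, and ``MyModule``, ``some_parameter`` and the checkpoint location are illustrative names only: the point is that the calling subclass is meant to be the inferred return type of ``load_from_checkpoint``, rather than a broad LightningModule/LightningDataModule union, so subclass-specific attributes are visible to a type checker.

import tempfile

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class MyModule(BoringModel):
    def __init__(self, some_parameter: int):
        super().__init__()
        # Store ``some_parameter`` in ``self.hparams`` so it round-trips through the checkpoint.
        self.save_hyperparameters()

    @property
    def parameter(self) -> int:
        return self.hparams.some_parameter


with tempfile.TemporaryDirectory() as tmp_dir:
    net = MyModule(some_parameter=42)
    trainer = Trainer(default_root_dir=tmp_dir, fast_dev_run=True)
    trainer.fit(net)
    checkpoint_path = f"{tmp_dir}/model.pt"
    trainer.save_checkpoint(checkpoint_path)

    # With ``-> Self``, the loaded object is meant to be typed as ``MyModule`` (the class the
    # classmethod was called on), so the subclass-only ``parameter`` property is visible to a
    # type checker as well as at runtime.
    net_loaded = MyModule.load_from_checkpoint(checkpoint_path)
    assert net_loaded.parameter == 42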