Skip to content

Commit

Permalink
revert removing _load_from_checkpoint to preserve doc
Browse files Browse the repository at this point in the history
  • Loading branch information
dyollb committed Nov 3, 2022
1 parent 9a1e2ee commit 5faae27
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 37 deletions.
7 changes: 4 additions & 3 deletions src/pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
"""LightningDataModule for loading DataLoaders with ease."""
import inspect
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Type, Union

from torch.utils.data import DataLoader, Dataset, IterableDataset
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PATH
Expand Down Expand Up @@ -214,11 +215,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

@classmethod
def load_from_checkpoint(
cls,
cls: Type[Self], # type: ignore [valid-type]
checkpoint_path: Union[_PATH, IO],
hparams_file: Optional[_PATH] = None,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
) -> Self: # type: ignore [valid-type]
r"""
Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
Expand Down
18 changes: 8 additions & 10 deletions src/pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, TypeVar, Union
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
from warnings import warn

import yaml
from lightning_utilities.core.apply_func import apply_to_collection
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import _load as pl_load
Expand All @@ -49,9 +50,6 @@
# the older shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments") # used in 0.7.6

LM = TypeVar("LM", bound="pl.LightningModule")
LDM = TypeVar("LDM", bound="pl.LightningDataModule")


class ModelIO:
CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
Expand All @@ -60,13 +58,13 @@ class ModelIO:

@classmethod
def load_from_checkpoint(
cls: Union[Type["ModelIO"], Type[LM], Type[LDM]],
cls: Type[Self], # type: ignore [valid-type]
checkpoint_path: Union[str, IO],
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs: Any,
) -> Union[LM, LDM]:
) -> Self: # type: ignore [valid-type]
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
Expand Down Expand Up @@ -149,13 +147,13 @@ def load_from_checkpoint(


def _load_from_checkpoint(
cls: Union[Type["ModelIO"], Type[LM], Type[LDM]],
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint_path: Union[_PATH, IO],
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Union[LM, LDM]:
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
if map_location is None:
map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
with pl_legacy_patch():
Expand Down Expand Up @@ -185,9 +183,9 @@ def _load_from_checkpoint(
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

if issubclass(cls, pl.LightningDataModule):
return cast(LDM, _load_state(cls, checkpoint, **kwargs))
return _load_state(cls, checkpoint, **kwargs)
if issubclass(cls, pl.LightningModule):
return cast(LM, _load_state(cls, checkpoint, strict=strict, **kwargs))
return _load_state(cls, checkpoint, strict=strict, **kwargs)
raise NotImplementedError(f"Unsupported {cls}")


Expand Down
24 changes: 0 additions & 24 deletions tests/tests_type_checking/test_lightning_module.py

This file was deleted.

0 comments on commit 5faae27

Please sign in to comment.