Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load_from_checkpoint returns the expected type #15496

Merged
merged 9 commits into from
Nov 4, 2022
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union

from torch.utils.data import DataLoader, Dataset, IterableDataset
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PATH
Expand Down Expand Up @@ -218,7 +219,7 @@ def load_from_checkpoint(
checkpoint_path: Union[_PATH, IO],
hparams_file: Optional[_PATH] = None,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
) -> Self: # type: ignore [valid-type]
dyollb marked this conversation as resolved.
Show resolved Hide resolved
r"""
Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import yaml
from lightning_utilities.core.apply_func import apply_to_collection
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import _load as pl_load
Expand Down Expand Up @@ -63,7 +64,7 @@ def load_from_checkpoint(
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
) -> Self: # type: ignore [valid-type]
dyollb marked this conversation as resolved.
Show resolved Hide resolved
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
Expand Down