Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix saving hyperparameters in a composition where parent is not a LM or LDM #14151

Merged
merged 6 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved



## [1.7.1] - 2022-08-09

### Fixed
Expand Down
17 changes: 12 additions & 5 deletions src/pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,18 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:


def collect_init_args(
frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
frame: types.FrameType,
path_args: List[Dict[str, Any]],
inside: bool = False,
classes: Tuple[Type, ...] = (object,),
carmocca marked this conversation as resolved.
Show resolved Hide resolved
) -> List[Dict[str, Any]]:
"""Recursively collects the arguments passed to the child constructors in the inheritance tree.

Args:
frame: the current stack frame
path_args: a list of dictionaries containing the constructor args in all parent classes
inside: track if we are inside inheritance path, avoid terminating too soon
classes: the classes in which to inspect the frames

Return:
A list of dictionaries where each dictionary contains the arguments passed to the
Expand All @@ -181,13 +185,13 @@ def collect_init_args(
if not isinstance(frame.f_back, types.FrameType):
return path_args

if "__class__" in local_vars:
if "__class__" in local_vars and issubclass(local_vars["__class__"], classes):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
local_args = get_init_args(frame)
# recursive update
path_args.append(local_args)
return collect_init_args(frame.f_back, path_args, inside=True)
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
if not inside:
return collect_init_args(frame.f_back, path_args, inside)
return collect_init_args(frame.f_back, path_args, inside, classes=classes)
return path_args


Expand Down Expand Up @@ -225,7 +229,10 @@ def save_hyperparameters(
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
for local_args in collect_init_args(frame, []):

from pytorch_lightning.core.mixins import HyperparametersMixin

for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
init_args.update(local_args)

if ignore is None:
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
Expand Down Expand Up @@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str:
return raw_checkpoint_path


@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
def test_save_hyperparameters_under_composition(base_class):
    """Ensure that composing a Lightning-aware child inside a plain (non-Lightning) parent class only
    collects the child's own constructor arguments into ``hparams``."""

    class _ComposedChild(base_class):
        def __init__(self, same_arg):
            super().__init__()
            self.save_hyperparameters()

    class _PlainParent:  # deliberately NOT a LightningModule/LightningDataModule subclass
        def __init__(self, same_arg="parent_default", other_arg="other"):
            self.child = _ComposedChild(same_arg="cocofruit")

    composed = _PlainParent()
    assert composed.child.hparams == {"same_arg": "cocofruit"}


class LocalVariableModelSuperLast(BoringModel):
"""This model has the super().__init__() call at the end."""

Expand Down