Skip to content

Commit

Permalink
Fix saving hyperparameters in a composition where parent is not a LM …
Browse files Browse the repository at this point in the history
…or LDM (#14151)



Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
awaelchli and rohitgr7 authored Aug 11, 2022
1 parent 98ded45 commit 3b18da3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))



## [1.7.1] - 2022-08-09

### Fixed
Expand Down
17 changes: 12 additions & 5 deletions src/pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,18 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:


def collect_init_args(
frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
frame: types.FrameType,
path_args: List[Dict[str, Any]],
inside: bool = False,
classes: Tuple[Type, ...] = (),
) -> List[Dict[str, Any]]:
"""Recursively collects the arguments passed to the child constructors in the inheritance tree.
Args:
frame: the current stack frame
path_args: a list of dictionaries containing the constructor args in all parent classes
inside: track if we are inside inheritance path, avoid terminating too soon
classes: the classes in which to inspect the frames
Return:
A list of dictionaries where each dictionary contains the arguments passed to the
Expand All @@ -181,13 +185,13 @@ def collect_init_args(
if not isinstance(frame.f_back, types.FrameType):
return path_args

if "__class__" in local_vars:
if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
local_args = get_init_args(frame)
# recursive update
path_args.append(local_args)
return collect_init_args(frame.f_back, path_args, inside=True)
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
if not inside:
return collect_init_args(frame.f_back, path_args, inside)
return collect_init_args(frame.f_back, path_args, inside, classes=classes)
return path_args


Expand Down Expand Up @@ -225,7 +229,10 @@ def save_hyperparameters(
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = {}
for local_args in collect_init_args(frame, []):

from pytorch_lightning.core.mixins import HyperparametersMixin

for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
init_args.update(local_args)

if ignore is None:
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
Expand Down Expand Up @@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str:
return raw_checkpoint_path


@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
def test_save_hyperparameters_under_composition(base_class):
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
collected."""

class ChildInComposition(base_class):
def __init__(self, same_arg):
super().__init__()
self.save_hyperparameters()

class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
def __init__(self, same_arg="parent_default", other_arg="other"):
self.child = ChildInComposition(same_arg="cocofruit")

parent = NotPLSubclass()
assert parent.child.hparams == dict(same_arg="cocofruit")


class LocalVariableModelSuperLast(BoringModel):
"""This model has the super().__init__() call at the end."""

Expand Down

0 comments on commit 3b18da3

Please sign in to comment.