Skip to content

Commit

Permalink
Improve warning for unpickable hyperparameter (#19581)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Mar 6, 2024
1 parent b871f7a commit f23b3b1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/lightning/pytorch/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def clean_namespace(hparams: MutableMapping) -> None:
del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]

for k in del_attrs:
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
rank_zero_warn(
f"Attribute '{k}' removed from hparams because it cannot be pickled. You can suppress this warning by"
f" setting `self.save_hyperparameters(ignore=['{k}'])`.",
)
del hparams[k]


Expand Down
13 changes: 10 additions & 3 deletions tests/tests_pytorch/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,19 +528,26 @@ def test_hparams_pickle(tmpdir):
class UnpickleableArgsBoringModel(BoringModel):
"""A model that has an attribute that cannot be pickled."""

def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), **kwargs):
def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), ignore=False, **kwargs):
super().__init__(**kwargs)
assert not is_picklable(pickle_me)
self.save_hyperparameters()
if ignore:
self.save_hyperparameters(ignore=["pickle_me"])
else:
self.save_hyperparameters()


def test_hparams_pickle_warning(tmpdir):
model = UnpickleableArgsBoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
with pytest.warns(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"):
trainer.fit(model)
assert "pickle_me" not in model.hparams

model = UnpickleableArgsBoringModel(ignore=True)
with no_warning_call(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"):
trainer.fit(model)


def test_hparams_save_yaml(tmpdir):
class Options(str, Enum):
Expand Down

0 comments on commit f23b3b1

Please sign in to comment.