Warning during save_hyperparameter() gives misleading advice? #13615
-
I am trying to understand / rectify a warning about saving my hyperparameters and would need some assistance, please. I build a model this way:

```python
from pytorch_lightning.core.mixins import HyperparametersMixin

class MyModel(nn.Module, HyperparametersMixin):
    def __init__(...):
        super().__init__()
        self.save_hyperparameters()  # Logs to self.hparams only, not to the logger (since there isn't any yet)

class MyModule(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.save_hyperparameters("model", logger=False)
        self.save_hyperparameters()
        self.save_hyperparameters(model.hparams)

model = MyModel(...)
module = MyModule(model)
```

This works, and I can load a checkpoint with `load_from_checkpoint()`. But I also get a warning during the initialization of `MyModule` that recommends ignoring the `model` attribute. So I change the corresponding code in `MyModule` to:

```python
self.save_hyperparameters(ignore=["model"])
self.save_hyperparameters(model.hparams)
```

The created checkpoint is marginally reduced by ~3 KB (the checkpoint size is ~1 MB). But when I want to load the checkpoint I get this error:

```
File "/Users/stephan/Library/Caches/pypoetry/virtualenvs/molgen-6oMP0hTK-py3.9/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 161, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/Users/stephan/Library/Caches/pypoetry/virtualenvs/molgen-6oMP0hTK-py3.9/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 203, in _load_model_state
    model = cls(**_cls_kwargs)
TypeError: __init__() missing 1 required keyword-only argument: 'model'
```

This seems to indicate that I need to save the model after all. What am I missing? What is the correct way to save the hyperparameters / the model in the LightningModule?
-
The attributes that are not saved as hparams need to be passed explicitly. Considering you are using the `load_from_checkpoint` API, you can use `model = MyModule.load_from_checkpoint(ckpt_path, model=model)`.

If you include the model in the hparams, your checkpoints will be unnecessarily big, and that can create issues if you have large models.

As for the warning: it means the model weights are already saved in the checkpoint and are loaded using the PyTorch API, not as hparams.
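A minimal sketch of this pattern, following the shape of `MyModule` from the question above (the `lr` argument and the `nn.Linear` stand-in are placeholders, not from the original code):

```python
import torch.nn as nn
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def __init__(self, model: nn.Module, lr: float = 1e-3):
        super().__init__()
        # Keep the nn.Module out of the hparams; only `lr` lands in the
        # checkpoint's "hyper_parameters", the weights live in "state_dict".
        self.save_hyperparameters(ignore=["model"])
        self.model = model


# At load time the (untrained) architecture is rebuilt and passed back in:
# restored = MyModule.load_from_checkpoint(
#     "path/to.ckpt", model=nn.Linear(16, 4)
# )
```

`load_from_checkpoint` instantiates the class with the passed `model` and then loads the stored `state_dict` into it, so the module handed in only needs the right architecture, not trained weights.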
-
Hi @rohitgr7,

Thank you very much for your swift response. Understood, I think. Would you mind checking my understanding? I have two concerns.

(1) I changed the code to load the checkpoint as follows; please refer to the assumption in the comment near the end of the code snippet.

```python
# Load checkpoint and move it to CPU
checkpoint = torch.load(model_path_to_load_from, map_location=torch.device("cpu"))

# Leverage the logged hyper params to build an (untrained) model
# The function build_model returns an instance of nn.Module
hparams = checkpoint["hyper_parameters"]
model_config = {
    "vocab_size": hparams["vocab_size"],
    "embed_dim": hparams["embed_dim"],
    ...
}
model = build_model(
    hparams["architecture"], hparams["variant"], **model_config
)

# ***** I think I do not need to manually load the state_dict, True? *****
# If I need to, the issue is that the state_dict contains keys such as "model.linear.bias",
# but load_state_dict() expects "linear.bias"
# model.load_state_dict(checkpoint["state_dict"])

# Finally, load the pl module
pl_model = MoleculeGenerator.load_from_checkpoint(model_path_to_load_from, model=model)
```

(2) Besides the model, I pass a loss function to the pl Module, which is also an nn.Module. Do I need to / do you recommend treating it like the model, i.e., loading / configuring it manually from the checkpoint? I guess the question is: should I exclude an nn.Module just as a precaution in case it gets too large (the loss function parameters probably don't), or is it sort of "forbidden" to log nn.Modules via save_hyperparameters()?
-
I'd not recommend this way. Since you are loading the actual model using the hparams, you should load it from within the LightningModule's `__init__`:

```python
class MyModule(LightningModule):
    def __init__(...):
        super().__init__()
        self.save_hyperparameters(...)
        model = build_model(self.hparams...)
```

Just wondering: how did you load the model without the checkpoint, if you need the hparams to load them?

It's actually hard for us to determine which one is a loss module and which one is a model, so we check whether an object is an instance of `nn.Module`.
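For illustration only (this is not the actual Lightning source, just a sketch of the check being described), the decision boils down to an `isinstance` test, which cannot distinguish a model from a parametrized loss:

```python
import torch.nn as nn


def module_like_attributes(init_kwargs: dict) -> list[str]:
    # A backbone, a loss such as nn.MSELoss, or any other nn.Module all look
    # identical to this check, so the warning treats them the same way.
    return [name for name, value in init_kwargs.items() if isinstance(value, nn.Module)]


print(module_like_attributes({"model": nn.Linear(4, 2), "loss": nn.MSELoss(), "lr": 1e-3}))
# -> ['model', 'loss']
```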
-
I hope you still have some time/energy to stay with me; I really would like to do this properly. It's getting trickier than I thought. For context: I am doing experiments and need to combine different datasets, models, encodings, … to assess the overall/combined performance. So I am trying to build a generic wrapper that I can plug those components into. That is the background for my comment.

I think I understand your reasoning, but since I have different model architectures, I have a different set of hparams for each architecture. So I would need to pass the superset of those hparams to the LightningModule.
Maybe not correctly until now? Until now, I saved the model and its hparams in the checkpoint, as shown in my original post.

As a summary of my understanding, I have the following options:

1. Pass the (superset of) model hparams to the LightningModule and build the model inside its `__init__`.
2. Keep passing the model object and saving it via `save_hyperparameters()`, accepting the larger checkpoints.
3. Ignore the model in the hparams and rebuild / reload it manually before calling `load_from_checkpoint()`.

If you can think of a fourth "best of all worlds" option, I am happy to hear it, of course. Thanks for your insights.
-
Personally, I'd recommend building the model from the hparams inside the LightningModule's `__init__`, because in all other cases you might have to reload the checkpoint manually to initialize the model. If you are concerned about the huge number of arguments, you can use namespaces.
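A minimal sketch of that recommendation, grouping the per-architecture hparams into a single namespace argument. The `build_model` placeholder stands in for the factory mentioned earlier in the thread; the `lr` default and hparam names are illustrative, not from the original posts:

```python
from argparse import Namespace

import torch.nn as nn
import pytorch_lightning as pl


def build_model(architecture: str, variant: str, **config) -> nn.Module:
    # Placeholder for the build_model factory referenced earlier in the thread.
    ...


class MoleculeGenerator(pl.LightningModule):
    def __init__(self, model_hparams: Namespace, lr: float = 1e-3):
        super().__init__()
        # Only plain config values are stored as hparams, so the checkpoint stays small.
        self.save_hyperparameters()
        # Rebuild the architecture from the stored hparams instead of taking an
        # already-built nn.Module as a constructor argument.
        self.model = build_model(
            model_hparams.architecture,
            model_hparams.variant,
            vocab_size=model_hparams.vocab_size,
            embed_dim=model_hparams.embed_dim,
        )


# Loading then needs no extra arguments, because everything required to
# rebuild the model is stored in the checkpoint's hyper_parameters:
# restored = MoleculeGenerator.load_from_checkpoint("path/to.ckpt")
```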
-
Hello,

I pass each model's hyperparameters to the LightningModule instead of the model itself. In this way, each checkpoint saves the hyperparameters needed for each model, and I only instantiate the model inside the LightningModule.
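A hypothetical sketch of how such an approach can look (this is not @BrunoBelucci's actual code, which isn't reproduced here): the checkpoint stores the model's import path plus its kwargs, and the LightningModule instantiates it.

```python
import importlib

import pytorch_lightning as pl


class GenericWrapper(pl.LightningModule):
    def __init__(self, model_class: str, model_kwargs: dict, lr: float = 1e-3):
        super().__init__()
        # Only strings / dicts / scalars are stored, so every checkpoint carries
        # exactly the hyperparameters its own architecture needs.
        self.save_hyperparameters()
        module_name, _, class_name = model_class.rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        self.model = cls(**model_kwargs)


# wrapper = GenericWrapper("my_models.Transformer", {"embed_dim": 128, "vocab_size": 40})
# later: GenericWrapper.load_from_checkpoint("path/to.ckpt")  # no model argument needed
```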
-
I also encountered the same issue. The argument that is ignored in `save_hyperparameters()` then has to be passed explicitly to `load_from_checkpoint()`.

This is not consistent with the recommendation from the warning message and causes confusion.
-
I also use Hydra for configuration. My model configurations may contain nested configs. In my opinion the solution by @BrunoBelucci is the cleanest, but unfortunately it doesn't work with nested configs. I tried lots of different solutions. A solution similar to the one described by @hogru was promising, but finally I found a really simple solution. Let's consider an example where the model configuration includes a backbone:

```python
from dataclasses import dataclass


@dataclass
class MyBackboneConfig:
    _target_: str = "backbones.MyBackbone"
    depth: int = 8


@dataclass
class MyModelConfig:
    _target_: str = "models.MyModel"
    backbone: MyBackboneConfig = MyBackboneConfig()
```

When I instantiate the model, I don't instantiate the arguments recursively (see the Hydra `_recursive_` option):

```python
from hydra.utils import instantiate

model = instantiate(model_cfg, _recursive_=False)
```

The model constructor instantiates any OmegaConf arguments recursively:

```python
class MyModel(LightningModule):
    def __init__(self, backbone):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = instantiate(backbone)
```

When loading a checkpoint, it's possible to use a Hydra configuration and not load the hyperparameters from the checkpoint like this:

```python
from hydra.utils import call

model_cfg._target_ += ".load_from_checkpoint"
model = call(model_cfg, checkpoint_path, _recursive_=False)
```

It's also possible to use the hyperparameters from the checkpoint and not override anything like this:

```python
from hydra._internal.utils import _locate

model_class = _locate(model_cfg._target_)
model = model_class.load_from_checkpoint(checkpoint_path)
```

One should just be aware that it's not possible to load the hyperparameters from a checkpoint and override only some of the nested hyperparameters (e.g. override only part of the backbone config).