PyTorch: save/load models #1137
Conversation
This is important. We should avoid the GPU -> CPU and CPU -> GPU transfer when it is not needed. Maybe the interface for …
With the current config options (…). Regarding the model: I don't think it is practical to change an existing …
Ok then on the filename postfix. Regarding the model, I was not saying at all that you should change an existing … Although, I think I would also leave that up to the user. E.g. there could be some …
There are other ways to implement it; this is not strictly necessary. This is why I questioned whether our …
Yes, this is one example. In any case, we should keep it efficient, and avoid introducing inefficiencies at all cost.
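As a sketch of how the extra transfer can be avoided on the loading side (the device handling and filename here are illustrative, not RETURNN's actual code): `torch.load` accepts a `map_location` argument that deserializes storages directly onto the target device, so a checkpoint does not have to take a detour through the device recorded in the file.

```python
import torch

# Illustrative sketch (not RETURNN's actual code): pick whatever device is
# available, save the state_dict, and load it back with map_location so the
# storages are deserialized straight onto the target device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2).to(device)
torch.save(model.state_dict(), "model.pt")

# Without map_location, torch.load restores tensors onto the device recorded
# in the checkpoint; with it, they land on the desired device in one step,
# avoiding an unneeded CPU <-> GPU round-trip.
state = torch.load("model.pt", map_location=device)
model.load_state_dict(state)
```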
Regarding the proof of concept, maybe completely ignore the pretraining aspect for now, to keep it simple. So …
This would be the version without the …
```diff
@@ -52,6 +57,8 @@ def train(self):
         start_epoch, _ = self.get_train_start_epoch_batch(self.config)
         final_epoch = self.config_get_final_epoch(self.config)

+        self._load_model(epoch=start_epoch)
```
This should not be here but in init_train_from_config or init_network_from_config.
Just like start_epoch etc. Maybe I just merge this for now.
".pt" file extension seems standard.
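On the filename postfix discussed above, one option is a zero-padded epoch number before the ".pt" extension, so plain string sorting orders checkpoints by epoch. The naming scheme below is purely illustrative, not RETURNN's actual convention:

```python
# Hypothetical epoch-indexed naming scheme (illustrative only):
# "<prefix>.<epoch>.pt", zero-padded so lexicographic order == epoch order.
def model_filename(prefix: str, epoch: int) -> str:
    return f"{prefix}.{epoch:03d}.pt"

print(model_filename("model", 7))  # -> model.007.pt
```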
I think calling model.load_state_dict() in each epoch is the right way to do it if we really want get_model_func to be specific to the epoch. Although this probably has to be extended if the set of model parameters really changes between epochs.

The demo now carries over the model state between model epochs, saves model files, and can be interrupted and continued.
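A minimal sketch of that interrupt-and-continue loop. All names here (get_model_func, the checkpoint layout, the "model.&lt;epoch&gt;.pt" filename scheme) are illustrative stand-ins, not the actual demo code:

```python
import os
import torch

# Illustrative stand-in for an epoch-specific model constructor; the real
# get_model_func could return a different architecture per epoch, in which
# case load_state_dict would need extra handling for changed parameters.
def get_model_func(epoch: int) -> torch.nn.Module:
    return torch.nn.Linear(4, 2)

def train(num_epochs: int, ckpt_dir: str) -> int:
    """Train up to num_epochs, resuming from the newest checkpoint if any.

    Returns the epoch it (re)started from.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    state = None
    start_epoch = 1
    ckpts = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(".pt"))
    if ckpts:
        # Filenames follow the hypothetical "model.<epoch>.pt" scheme,
        # zero-padded so the sort above orders them by epoch.
        start_epoch = int(ckpts[-1].split(".")[1]) + 1
        state = torch.load(os.path.join(ckpt_dir, ckpts[-1]))
    for epoch in range(start_epoch, num_epochs + 1):
        model = get_model_func(epoch)
        if state is not None:
            model.load_state_dict(state)  # carry state over between epochs
        # ... one epoch of training on `model` would go here ...
        state = model.state_dict()
        torch.save(state, os.path.join(ckpt_dir, f"model.{epoch:03d}.pt"))
    return start_epoch
```

Interrupting and re-running `train` picks up after the last saved epoch, which is the behavior the demo describes.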