
PyTorch: save/load models #1137

Merged: 2 commits into master from pwilken_pytorch on Oct 18, 2022

Conversation

patrick-wilken
Contributor

".pt" file extension seems standard.
I think calling model.load_state_dict() in each epoch is the right way to do it if we really want get_model_func to be specific to the epoch. Although this probably has to be extended if the set of model parameters really changes between epochs.

The demo now carries over the model state between epochs, saves model files, and can be interrupted and resumed.
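For reference, the save/load scheme described above can be sketched roughly like this (a minimal standalone sketch, not RETURNN's actual engine code; the helper names `save_checkpoint`/`load_checkpoint` are made up for illustration):

```python
import torch

def save_checkpoint(model, filename):
    # Save only the parameters/buffers (the state dict), not the Module code.
    torch.save(model.state_dict(), filename)

def load_checkpoint(model, filename):
    # load_state_dict() copies the stored tensors into an existing Module
    # instance, so it also works when a fresh instance was created for the
    # new epoch, as long as the parameter set matches.
    model.load_state_dict(torch.load(filename))
```

With this scheme the checkpoint file contains only tensors, so it is independent of how the Module class itself is defined or changed between runs.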

@patrick-wilken patrick-wilken requested review from a team and albertz as code owners September 29, 2022 17:05
Two review threads on returnn/torch/engine.py (outdated, resolved).
@albertz
Member

albertz commented Sep 29, 2022

> Although this probably has to be extended if the set of model parameters really changes between epochs.

This is important. We should avoid GPU -> CPU and CPU -> GPU transfers when they are not needed.

Maybe the interface of get_model_func is bad? If this is difficult to do in practice due to the interface of get_model_func, we should change that interface.

@patrick-wilken
Contributor Author

patrick-wilken commented Sep 30, 2022

> ".pt" file extension seems standard.

> Why does that matter? There is no point in hardcoding this. Just leave it up to the user.

With the current config options (model), the user cannot specify a file extension, because the epoch (and possibly "pretrain") is appended to the path (i.e. model.pt.005 instead of model.005.pt). If we want to reuse the existing code to create and find the paths, something like what I implemented is needed. Or we could say we don't need the file extension...
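The naming issue can be illustrated with a small helper (hypothetical, not the PR's actual code) that puts the zero-padded epoch before the extension rather than after the full path:

```python
def epoch_model_filename(model_prefix, epoch, ext=".pt"):
    # Produce "model.005.pt" (extension last) rather than the
    # "model.pt.005" that appending the epoch to the full path would give.
    return "%s.%03d%s" % (model_prefix, epoch, ext)
```

The epoch number then stays machine-parseable while the file still ends in a standard ".pt" extension.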

Regarding the model: I don't think it is practical to change an existing Module instance (add parameters, replace the forward function) between epochs. So if we want to support a new network each epoch, this means we get a new Module instance, as it is currently implemented. (Maybe get_model_func could return None to indicate "same as previous epoch"? Then we could keep the instance.)
Of course, if we don't need that functionality, we would create the Module instance once and load the parameters directly into it.
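The "return None to keep the instance" idea could look roughly like this on the engine side (a hypothetical sketch; `init_model_for_epoch` is a made-up name, not part of the PR):

```python
def init_model_for_epoch(get_model_func, epoch, prev_model):
    # Hypothetical engine logic: get_model_func(epoch=...) may return None
    # to signal "same network as the previous epoch", in which case the
    # existing Module instance (with its parameters already on device,
    # avoiding any GPU <-> CPU transfer) is kept.
    model = get_model_func(epoch=epoch)
    if model is None:
        assert prev_model is not None, "no previous model to keep"
        return prev_model
    return model
```

This keeps the per-epoch flexibility of get_model_func while making the common case (same network every epoch) free of reload overhead.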

@albertz
Member

albertz commented Oct 7, 2022

Ok then, agreed on the filename postfix.

Regarding the model, I was not saying at all that you should change an existing Module instance. Where did you read that?

Although, I think I would also leave that up to the user. E.g. there could be some on_new_epoch(model) callback, and the user could just change some hyperparameters, like dropout, or do whatever else they need.
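Such a callback might look like this (purely illustrative; the `on_new_epoch` signature and the annealing schedule are assumptions, not part of the PR):

```python
import torch

def on_new_epoch(model, epoch):
    # User-defined hook: mutate hyperparameters on the existing Module
    # instead of rebuilding it. Here: a made-up dropout annealing schedule
    # that ramps dropout up to 0.1 over the first 10 epochs.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.1 * min(epoch, 10) / 10
```

Because only attributes on existing submodules change, no parameters are created or destroyed, so no state-dict reload or device transfer is involved.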

> So if we want to support to have a new network each epoch this means we get a new Module instance, as it is currently implemented.

There are other ways to implement it; this is not strictly necessary. This is why I questioned whether our get_model_func interface should maybe be changed.

> (Maybe get_model_func could return None to indicate same as previous epoch? Then we could keep the instance.)

Yes, this is one example.

In any case, we should keep it efficient and avoid introducing inefficiencies at all costs.

@albertz
Member

albertz commented Oct 7, 2022

Regarding the proof of concept, maybe completely ignore such pretraining aspect for now, to keep it simple. So get_model_func would not get an epoch parameter.

@patrick-wilken
Contributor Author

This would be the version without the epoch parameter, so calling get_model once and then keep the Module as attribute.

```diff
@@ -52,6 +57,8 @@ def train(self):
         start_epoch, _ = self.get_train_start_epoch_batch(self.config)
         final_epoch = self.config_get_final_epoch(self.config)

+        self._load_model(epoch=start_epoch)
```

This should not be here but in init_train_from_config or init_network_from_config.


Just like start_epoch etc. Maybe I just merge this for now.

@albertz albertz merged commit e691019 into master Oct 18, 2022
@albertz albertz deleted the pwilken_pytorch branch October 18, 2022 11:38