
preload_from_files for PT engine #1292

Merged: 10 commits from torch_preload into master, Apr 3, 2023

Conversation

@vieting (Contributor) commented Mar 28, 2023

As discussed in #1120, preload_from_files or something equivalent should be added for the PT engine as well. This is certainly not complete, but could be helpful as a starting point. It works as a proof-of-concept to load a wav2vec 2.0 checkpoint.

What do you think in general?
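
For reference, preload_from_files in a RETURNN config is a dict of entries, each pointing to an external checkpoint and, optionally, a parameter-name prefix. A minimal sketch of such an entry, with placeholder paths and assuming the option names from the TF side (filename, prefix, init_for_train) carry over unchanged:

preload_from_files = {
    "wav2vec": {
        "filename": "/path/to/wav2vec2_checkpoint",  # placeholder path
        "prefix": "encoder.",  # model params under this name prefix are loaded from the checkpoint
        "init_for_train": True,  # only used to initialize training
    },
}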

(Several earlier comments between @albertz and @vieting were marked as resolved and are hidden.)

@vieting vieting marked this pull request as ready for review March 29, 2023 16:08
@vieting vieting requested review from a team and albertz as code owners March 29, 2023 16:08
returnn/torch/engine.py: review comment (outdated, resolved)
@albertz (Member) commented Mar 30, 2023

We should also make the order, or actually the preference, for variable loading consistent with TF. In TF, we first go through preload_from_files, with its entries sorted by key name, and only afterwards do the normal loading. However, every variable we load is marked (see set_as_custom_init), which makes sure that any subsequent loading will not load it again. So the values from wherever a variable is loaded first are the ones that stay.

Now, you don't have any such logic in PT. This effectively means that some variables will get loaded multiple times, and the values from wherever a variable is loaded last are the ones that stay, i.e. the opposite order. I think you then need to do the normal loading first, and afterwards iterate over reversed(sorted(preload_from_files.items())).

I'm not sure if I'm missing something else here.
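
To illustrate the ordering described above, here is a rough sketch of how the PT side could do it, assuming a plain torch.nn.Module and a checkpoint dict with a "model" entry (the function name and checkpoint layout are assumptions, not the actual engine code):

import torch


def load_model_params(model, checkpoint_filename, preload_from_files):
    """Normal checkpoint loading first, then preload_from_files (sketch only)."""
    if checkpoint_filename:
        checkpoint = torch.load(checkpoint_filename, map_location="cpu")
        model.load_state_dict(checkpoint["model"])  # "model" key is an assumption
    # A later load_state_dict call overwrites earlier values, so iterating the
    # preload entries in reversed sorted order reproduces the TF preference
    # where the first matching source (sorted by key) wins.
    model_param_names = {name for name, _ in model.named_parameters()}
    for _key, opts in reversed(sorted(preload_from_files.items())):
        preload = torch.load(opts["filename"], map_location="cpu")
        prefix = opts.get("prefix", "")
        selected = {
            prefix + name: value
            for name, value in preload["model"].items()
            if prefix + name in model_param_names
        }
        model.load_state_dict(selected, strict=False)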

@vieting (Contributor, Author) commented Mar 30, 2023

I reversed the order for the keys in preload_from_files. The normal loading from existing checkpoints for epoch > 1 should stay first, so it will be overwritten by preload_from_files in case we are not in training, right?

@albertz (Member) left a review comment:

Looks OK to me now, apart from my last comment. I hope I did not miss anything. As said, I think consistency with the TF logic is important.

@patrick-wilken or @JackTemaki or someone should also review.

returnn/torch/engine.py: review comment (outdated, resolved)
@patrick-wilken (Contributor) commented Mar 30, 2023

By the way, what defines the parameter names in the frontend? That's what rf.Module.named_parameters() seems to do, right? And it uses the attribute names of the modules. So to add the prefix we are talking about here, you would rename the Module attribute? 😕

@albertz (Member) commented Mar 30, 2023

So to add the prefix we are talking about here you would rename the Module attribute?

No, you would simply put it into a submodule, where you probably have the model anyway. E.g. if you have trained an LM using the module TransformerLm, you would now create TransformerLm as a submodule of your main model, like:

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    ...
    self.ext_lm = TransformerLm(...)  # externally trained LM as a submodule

In that example, the prefix is simply "ext_lm.".
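
A matching preload_from_files entry for that submodule would then look roughly like this (the filename is a placeholder):

preload_from_files = {
    "lm": {
        "filename": "/path/to/trained_lm_checkpoint",
        "prefix": "ext_lm.",  # maps the LM checkpoint params onto self.ext_lm
    },
}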

returnn/torch/engine.py: review comment (outdated, resolved)
@albertz albertz merged commit 1732a0b into master Apr 3, 2023
@albertz albertz deleted the torch_preload branch April 3, 2023 13:41