
[model loading] don't init weights for pretrained models #11463

Closed
stas00 wants to merge 2 commits

Conversation

@stas00 (Contributor) commented Apr 27, 2021

Skip _init_weights for pretrained models, since the randomly initialized weights get immediately replaced by the pretrained weights anyway. This leads to a much faster startup for huge models.
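
In essence, the change amounts to something like this condensed sketch of PreTrainedModel.init_weights (the actual diff hunk is excerpted further down; the surrounding pruning/weight-tying logic is omitted here and details may differ):

def init_weights(self):
    # Pruning / weight-tying done by the real method is omitted in this sketch.
    if getattr(self.config, "use_pretrained_weights", False):
        # Flag presumably set by from_pretrained(): skip the expensive random
        # init, since the loaded state dict overwrites these tensors right after.
        return
    self.apply(self._init_weights)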

Fixes: #9205

@sgugger, @patrickvonplaten

@stas00 changed the title from "[WIP] [model loading] don't init weights for pretrained models" to "[model loading] don't init weights for pretrained models" on Apr 27, 2021
@sgugger (Collaborator) left a comment

Sadly, this won't work, for two reasons.

  1. First, everything in the __dict__ of the config gets serialized when one uses config.save_pretrained() (which is called by model.save_pretrained), so any model downloaded from the hub with a checkpoint saved after this is merged will get this attribute in its config. Then if a user instantiates a randomly-initialized model from the config with the following code:

config = AutoConfig.from_pretrained("new_checkpoint_after_this_is_merged")
model = AutoModel.from_config(config)

then the model won't be randomly initialized (at least not with _init_weights), since the config will have this use_pretrained_weights attribute set.

  2. Then comes the problem that a pretrained model instantiated with from_pretrained does not necessarily have all of its weights loaded from the checkpoint (e.g. if you discard the head to put another task-specific head on top), and this PR would break the way those remaining weights are randomly initialized.

Sadly, I don't see a way around passing a list of the not-yet-initialized weights from from_pretrained to the _init_weights function.

@stas00 (PR author) commented Apr 27, 2021

> Sadly, this won't work, for two reasons.
>
> 1. First, everything in the __dict__ of the config gets serialized when one uses config.save_pretrained() (which is called by model.save_pretrained), so any model downloaded from the hub with a checkpoint saved after this is merged will get this attribute in its config. Then if a user instantiates a randomly-initialized model from the config with the following code:
>
> config = AutoConfig.from_pretrained("new_checkpoint_after_this_is_merged")
> model = AutoModel.from_config(config)
>
> then the model won't be randomly initialized (at least not with _init_weights), since the config will have this use_pretrained_weights attribute set.

So if I find another way to do it that doesn't taint the config, then it's OK, right? (as far as config correctness goes)

E.g. what if I unset this config value as soon as model = cls() is done? So this is sort of a "context" operation then.
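
For instance, something along these lines inside from_pretrained (a rough sketch of the idea only, not code from this PR):

config.use_pretrained_weights = True                # tell init_weights() to skip _init_weights
model = cls(config, *model_args, **model_kwargs)    # fast: no random init pass
del config.use_pretrained_weights                   # drop the flag again so it is never serialized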

> 2. Then comes the problem that a pretrained model instantiated with from_pretrained does not necessarily have all of its weights loaded from the checkpoint (e.g. if you discard the head to put another task-specific head on top), and this PR would break the way those remaining weights are randomly initialized.
>
> Sadly, I don't see a way around passing a list of the not-yet-initialized weights from from_pretrained to the _init_weights function.

I appreciate that you could think of the edge cases.

Clearly, we don't have any tests that verify the init is done correctly. I was hoping there would be some, but these would be hard to conjure.

If you feel this is a worthwhile effort, perhaps we could start coming up with examples, write tests where possible and solve those? You can throw the edge cases at me and I will try to overcome them.

Or alternatively, we could provide a very easy way for users to either force the init or, if it's safer, to force no-init? E.g. the staple examples could all enforce no-init and explain how to change that if the user wants to modify the example to get the original behavior.

So what I'm suggesting is that instead of from_pretrained automatically forcing no-init, as I proposed in this PR, we give users an explicit way to choose whether they want init_weights or not.
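
For illustration only, something like this at the call site (the no_init kwarg is made up here and does not exist in this PR):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", no_init=True)   # hypothetical: skip _init_weights
model = AutoModel.from_config(config)                                  # unchanged: full random init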

Review comments on the following diff hunk:

@@ -769,7 +769,10 @@ def init_weights(self):
        Initializes and prunes weights if needed.
        """
        # Initialize weights
        self.apply(self._init_weights)
        if getattr(self.config, "use_pretrained_weights", False):
@patrickvonplaten (Contributor) commented on the diff:
  • I'm not a huge fan of "attaching" a new parameter to the config which is not really understandable to the user.

  • Also, I think this could lead to problems -> lots of people initialize all weights except the final layer weights from a pre-trained BERT in, e.g., a BertForSequenceClassification. The logic would then not correctly initialize the final layer, but simply set everything to zero, which would probably lead to worse fine-tuning of BertForSequenceClassification.

=> I would propose the following:

  1. When using from_pretrained(...), we pass a new parameter to model = cls(config, *model_args, **model_kwargs) by setting model_kwargs["init_weights"] = False. This sadly means that we have to replace all __init__(self, config) functions in the modeling files by __init__(self, config, init_weights=True), but I think we can use a regex for this. This is a huge change in terms of files that need to be touched, but I think it's cleaner than creating a new "use_pretrained_weights" config parameter that the user shouldn't have to learn about. Then we also need to replace self.init_weights() with
if init_weights:
    self.init_weights()
  2. Now we also need to take care of cases where only part of the model is initialized from pre-trained weights; the rest still needs to be initialized. Here we can't simply use self.init_weights(), because it would necessarily run through all modules and initialize them. So I think we should leverage the missing_keys list here to extract all the nn.Modules that still need to be initialized and then run self._init_weights(m) for m in uninitialized_modules (a sketch follows this list).
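
A rough sketch of what this proposal could look like (the class and kwarg handling here are illustrative assumptions, not code from this PR or from a merged implementation):

from torch import nn
from transformers import BertModel, BertPreTrainedModel

class MyBertForSequenceClassification(BertPreTrainedModel):
    # Illustrative: in the proposal every modeling class gains this kwarg,
    # including BertModel itself, so the inner call below would forward it too.
    def __init__(self, config, init_weights=True):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if init_weights:
            self.init_weights()  # run the expensive random init only when asked to

# Inside from_pretrained(...), construction would then skip the init:
#   model_kwargs["init_weights"] = False
#   model = cls(config, *model_args, **model_kwargs)
# and whatever the checkpoint does not provide gets initialized afterwards
# (see the missing_keys discussion in the following comments).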

@patrickvonplaten (Contributor) added:
What do you think, @stas00?

I'm also keen to hear @LysandreJik's and @sgugger's opinions here.

@sgugger (Collaborator) replied:

The init_weights kwarg by itself will not work, as it doesn't deal with point 2. As I said in my comment, the only way to properly deal with this is to pass an uninitialized_weights kwarg (as produced by the missing_keys), which would then be used:

if len(uninitialized_weights) > 0:
    self.init_weights(uninitialized_weights)

and of course init_weights then needs to use something other than apply, so that _init_weights is only applied to the uninitialized_weights.
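
A minimal sketch of such a filtered init (the mapping from parameter names to modules is an assumption, not code from this PR):

def init_weights(self, uninitialized_weights=None):
    if uninitialized_weights is None:
        self.apply(self._init_weights)        # current behaviour: initialize everything
        return
    # Map missing parameter names (e.g. "classifier.weight") to the modules
    # that own them and run the model-specific _init_weights only on those.
    owners = {name.rsplit(".", 1)[0] for name in uninitialized_weights}
    for module_name, module in self.named_modules():
        if module_name in owners:
            self._init_weights(module)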

@patrickvonplaten (Contributor) replied:

The init_weights kwarg on its own won't work, but it's necessary to prevent each model from calling self.init_weights().

The order of operations when doing BertModel.from_pretrained(...) is the following:

  1. Instantiate a random model: cls(config, *model_args, **model_kwargs) => this call already runs self.init_weights(...), since every model class calls self.init_weights(...) in its __init__(config). So in order to prevent this we need to pass a flag to cls(config, *model_args, **model_kwargs), which I would do with model_kwargs["init_weights"] = False.

  2. Only after the model is instantiated (and the weights already have values) can we know which weights were missing from the checkpoint and thus need to be randomly initialized. Here we can retrieve uninitialized_weights, but it would be better to actually retrieve all the nn.Modules that need random initialization, since then we can reuse each model's _init_weights(...) function.

  3. Having retrieved uninitialized_modules, we can run self._init_weights(...) on each module (a sketch of the whole flow follows this list).
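
Putting the three steps together, the loading path could look roughly like this (a sketch under the assumptions above, not the actual from_pretrained code):

# Step 1: construct the model without the random init pass
# (assumes the hypothetical init_weights kwarg discussed above).
model_kwargs["init_weights"] = False
model = cls(config, *model_args, **model_kwargs)

# Step 2: load the checkpoint; the returned named tuple records which
# parameter names the state dict did not provide.
load_result = model.load_state_dict(state_dict, strict=False)

# Step 3: randomly initialize only the modules that own the missing parameters,
# e.g. via a filtered init_weights like the sketch in the previous comment.
model.init_weights(uninitialized_weights=load_result.missing_keys)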

@sgugger (Collaborator) commented Apr 27, 2021

> I appreciate that you could think of the edge cases.

That is not an edge case but the overwhelming majority ;-) You are mostly working with seq2seq models that don't throw away any weights when doing transfer learning, but all the basic examples fine-tuning BERT on a classification task encounter this :-)

Testing that the init is done properly is very difficult, as those are all random weights. Testing that those weights follow this distribution instead of that one is not easily achievable.

I don't think the no_init option is the right one: it will only work for a certain class of problems and not others, so it's not general enough. We shouldn't go for it just because it's easier to implement than the other solutions on the table.

@patrickvonplaten (Contributor) commented:

@stas00 @sgugger that's how I would approach the problem: #11471

@stas00 (PR author) commented Apr 28, 2021

OK, let's move the effort to Patrick's PR #11471

> [...] You are mostly working with seq2seq models [...]

Guilty as charged. I'm glad you guys have a much wider view than I. Thank you!

@stas00 stas00 closed this Apr 28, 2021
Linked issue: [model_utils] very slow model instantiation (#9205)