-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[model loading] don't init weights for pretrained models #11463
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work sadly, for two reasons.
- First, everything in the
__dict__
of the config gets serialized when one usesconfig.self_pretrained()
(which is called bymodel.from_pretrained
) so any other model downloaded from the hub with a checkpoint saved after this is merged will get this attribute in the config. Then if a user instantiates a randomly-initialized model using the config, with the following code:
config = AutoConfig.from_pretrained("new_checkpoint_after_this_is_merge")
model = AutoModel.from_config(config)
then the model won't be randomly initalized (at least not with _init_weights
) since the config will have this use_pretrained_weights
.
- Then come the problem that pretrained model instantiated with
from_pretrained
does not necessarily have all weights initialized (if you discard the head to put another task-specific head) and this PR will break the way those weights are randomly initialized.
I sadly don't see a way around passing around a list of not-initialized weights from pretrained to the _init_weights
function
So if I find another way to do it that doesn't taint the config then it's OK, right? (as far as config correctness goes) e.g. what if I unset this config value as soon as
I appreciate that you could think of the edge cases. Clearly, we don't have any tests that somehow verify that the init is done correctly. I was hoping that there would be some, but these would be hard to conjure. If you feel this is a worthwhile effort, perhaps let's start coming up with examples, write tests if possible and solve those? You can throw the edge-cases at me and I will try to overcome those. Or alternatively, we provide a very easy way for users to either force the init, or if it's safer to force no-init? e.g. the staple examples could all enforce no-init and explain how to change that if the user wants to modify the example to have the original behavior? So what I'm suggesting is that instead of |
@@ -769,7 +769,10 @@ def init_weights(self): | |||
Initializes and prunes weights if needed. | |||
""" | |||
# Initialize weights | |||
self.apply(self._init_weights) | |||
if getattr(self.config, "use_pretrained_weights", False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
I'm not a huge fan of "attaching" a new parameter to the
config
which is not really understandable by the user. -
Also, I think this could lead to problems -> lots of people initialize all weights except the final layer weights from a pre-trained BERT in, e.g. a
BertForSequenceClassification
. The logic would then not correctly initialize the final layer, but simply set everything to zero which would probably lead to a worse fine-tuning ofBertForSequenceClassification
.
=> I would propose the following:
- When using
from_pretrained(...)
, we pass a new parameter tomodel = cls(config, *model_args, **model_kwargs)
by settingmodel_kwargs["init_weights"] = False
. This then sadly means that we have to replace all__init__(self, config)
functions in the modeling files by__init__(self, config, init_weights=True)
, but I think we can use a regex for this. This is a huge change in terms of files that need to be changed, but I think it's cleaner then creating a new"use_pretrained_weights"
config parameter that the user shouldn't have to learn about. Then, we also need to changeself.init_weights()
with
if init_weights:
self.init_weights()
- Now, we also need to take care of cases where only parts of the model are initialized from pre-trained weights. The other part still needs to be initialized. Here we can't simply use
self.init_weights()
because it would necessarly run through all modules and initialized them. So I think we should leverage themissing_keys()
list here to extract allnn.Modules(...)
that still need to be initialized and then runself._init_weights(m) for m in uninitialized_modules
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think ? @stas00
Also keen to hear @LysandreJik's and @sgugger's opinion here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The init_weights
kwarg by itself will not work as it doesn't deal with 2. As I said in my comment, the only one to properly deal with this is to pass an uninitalized_weights
kwargs (as done by the missing_keys
) which would then be used:
if len(uninitalized_weights) > 0:
self.init_weights(uninitalized_weights)
and of course init_weights
then needs to use a function different than apply
that only applies _init_weights
to the unitialized_weights
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The init_weights
kwarg on its own won't work, but it's necessary to prevent each model from calling self.init_weights()
.
The order of operations when doing BertModel.from_pretrained(...)
is the following:
-
Instantiate a random model:
cls(config, *model_args, **model_kwargs)
=> this command already callsself.init_weights(...)
(since in every model class we have aself.init_weights(...)
in__init__(config):. So in order to prevent this we need to pass a flag to
cls(config, *model_args, **model_kwargs)which I would do with
model_kwargs["init_weights"] = False`. -
Only after the model is instantiated (and the weights already have values), we can know which weights were missing & thus need to be randomely initialized. Here we can retrieve
uninitialized_weights
, but it would be better to actually retrieve allnn.Modules
that are randomely initialized since then we can reuse each model's_init_weights(...)
function. -
Having retrieved
uninitialized_modules
we can runself._init_weights(...)
on each module.
That is not the edge case but the overwhelming majority ;-) You are mostly working with seq2seq models that don't throw away any weights when doing transfer learning, but all the basic examples fine-tuning BERT on a classification task encounter this :-) Testing the init is done properly is very difficult as those are all random weights. Testing those weights follow this distribution instead of that one is not something easily achievable. I don't think the |
OK, let's move the effort to Patrick's PR #11471
Guilty as charged. I'm glad you guys have a much wider view than I. Thank you! |
Skip
_init_weights
for pretrained models since they get immediately replaced by pretrained weights. This leads to a much faster startup for huge models.Fixes: #9205
@sgugger, @patrickvonplaten