[model loading] don't init weights for pretrained models #11463

Closed · wants to merge 2 commits
src/transformers/modeling_utils.py (10 changes: 8 additions & 2 deletions)

@@ -769,7 +769,10 @@ def init_weights(self):
         Initializes and prunes weights if needed.
         """
         # Initialize weights
-        self.apply(self._init_weights)
+        if getattr(self.config, "use_pretrained_weights", False):
Contributor:

  • I'm not a huge fan of "attaching" a new parameter to the config which is not really understandable for the user.

  • Also, I think this could lead to problems: lots of people initialize all weights except the final layer from a pre-trained BERT in, e.g., a BertForSequenceClassification. The logic would then not correctly initialize the final layer, but simply set everything to zero, which would probably lead to worse fine-tuning of BertForSequenceClassification.

=> I would propose the following:

  1. When using from_pretrained(...), we pass a new parameter to model = cls(config, *model_args, **model_kwargs) by setting model_kwargs["init_weights"] = False. This sadly means we have to replace every __init__(self, config) in the modeling files with __init__(self, config, init_weights=True), but I think we can do that with a regex. It is a huge change in terms of the number of files touched, but I think it's cleaner than creating a new "use_pretrained_weights" config parameter that the user shouldn't have to learn about (see the sketch after this list). Then we also need to guard self.init_weights() with
if init_weights:
   self.init_weights()
  2. Now we also need to take care of cases where only part of the model is initialized from pre-trained weights; the rest still needs to be initialized. Here we can't simply use self.init_weights(), because it would necessarily run through all modules and initialize them. So I think we should leverage the missing_keys list to extract all the nn.Modules(...) that still need to be initialized and then run self._init_weights(m) for m in uninitialized_modules.
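
To make point 1 concrete, here is a minimal sketch on a toy module (ToyModel and the exact kwarg plumbing are illustrative assumptions, not the actual transformers code):

import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for a model class such as BertForSequenceClassification."""

    def __init__(self, config, init_weights=True):
        super().__init__()
        self.encoder = nn.Linear(config["hidden"], config["hidden"])      # loaded from the checkpoint
        self.classifier = nn.Linear(config["hidden"], config["labels"])   # newly added head
        if init_weights:                     # from_pretrained(...) would pass init_weights=False
            self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            module.bias.data.zero_()

# inside from_pretrained(...), roughly:
model_kwargs = {"init_weights": False}
model = ToyModel({"hidden": 4, "labels": 2}, **model_kwargs)   # no random init pass here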

Contributor:

What do you think? @stas00

Also keen to hear @LysandreJik's and @sgugger's opinions here.

Collaborator:

The init_weights kwarg by itself will not work, as it doesn't deal with point 2. As I said in my comment, the only way to properly deal with this is to pass an uninitialized_weights kwarg (as given by the missing_keys) which would then be used:

if len(uninitialized_weights) > 0:
    self.init_weights(uninitialized_weights)

and of course init_weights then needs to use something different than apply that only applies _init_weights to the uninitialized_weights.
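
A sketch of that idea, written here as a free function over the model to keep it self-contained (on PreTrainedModel it would replace the body of init_weights; the key-to-module matching is an assumption for illustration):

import torch.nn as nn

def init_weights(model: nn.Module, uninitialized_weights=None):
    """Initialize the whole model, or only the modules owning the given parameter names."""
    if uninitialized_weights:
        # map parameter names like "classifier.weight" back to their owning modules
        owners = {name.rsplit(".", 1)[0] for name in uninitialized_weights}
        for name, module in model.named_modules():
            if name in owners:
                model._init_weights(module)
    else:
        model.apply(model._init_weights)     # current behaviour: initialize everything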

Contributor:

The init_weights kwarg on its own won't work, but it's necessary to prevent each model from calling self.init_weights().

The order of operations when doing BertModel.from_pretrained(...) is the following:

  1. Instantiate a random model: cls(config, *model_args, **model_kwargs) => this call already runs self.init_weights(...), since every model class calls self.init_weights() in its __init__(config). So in order to prevent this we need to pass a flag to cls(config, *model_args, **model_kwargs), which I would do with model_kwargs["init_weights"] = False.

  2. Only after the model is instantiated (and the weights already have values) can we know which weights were missing and thus need to be randomly initialized. Here we could retrieve uninitialized_weights, but it would be better to actually retrieve all the nn.Modules that need random initialization, since then we can reuse each model's _init_weights(...) function.

  3. Having retrieved uninitialized_modules, we can run self._init_weights(...) on each module (see the sketch below).
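
Tying this together with the ToyModel and init_weights sketches above, the three steps would read roughly as follows (again just an illustration of the flow, not the real from_pretrained(...) internals):

config = {"hidden": 4, "labels": 2}
checkpoint = {k: v for k, v in ToyModel(config).state_dict().items()
              if k.startswith("encoder")}                      # pretend checkpoint without the head

# 1. instantiate without running the random init pass
model = ToyModel(config, init_weights=False)
# 2. load the pretrained weights; only now do we know what was missing
missing_keys = model.load_state_dict(checkpoint, strict=False).missing_keys
# 3. randomly initialize only the modules that received no pretrained weights
init_weights(model, missing_keys)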

+            logger.info("detected pretrained model - skipping _init_weights")
+        else:
+            self.apply(self._init_weights)

         # Prune heads if needed
         if self.config.pruned_heads:
@@ -1116,8 +1119,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         config.name_or_path = pretrained_model_name_or_path

-        # Instantiate model.
+        # weights are coming from state_dict so tell models not to init weights, since that
+        # randomization will be immediately overwritten by weights from state_dict
+        config.use_pretrained_weights = True
+        # Instantiate model.
         if is_deepspeed_zero3_enabled():
             import deepspeed
