[Lazy init] Force fall back to slow init for composite models #11705

Merged · 5 commits · May 12, 2021
Changes from 3 commits
24 changes: 16 additions & 8 deletions src/transformers/modeling_utils.py
@@ -510,6 +510,12 @@ def get_output_embeddings(self) -> nn.Module:
        """
        return None  # Overwrite for models with output embeddings

+    def _init_weights(self, module):
+        """
+        Initialize the weights. This method should be overridden by the derived class.
+        """
+        raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}")
+
    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
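The new base-class stub only raises, so every concrete model is expected to provide its own `_init_weights`; fast init then calls that override on exactly the sub-modules whose weights are missing from the checkpoint. Below is a minimal sketch of such an override, using a hypothetical `MyModelPreTrainedModel` and the common BERT-style scheme (real models differ in the details):

```python
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel


class MyModelPreTrainedModel(PreTrainedModel):
    """Hypothetical class, shown only to illustrate the override."""

    config_class = PretrainedConfig
    base_model_prefix = "my_model"

    def _init_weights(self, module):
        # The usual BERT-style scheme: small normal init for Linear/Embedding
        # weights, zeroed biases, LayerNorm reset to identity scaling.
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
```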
@@ -1205,7 +1211,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                )

            model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
-                model, state_dict, pretrained_model_name_or_path
+                model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
            )

            # make sure token embedding weights are still tied if needed
@@ -1225,7 +1231,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        return model

    @classmethod
-    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):
+    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):

        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
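`_fast_init` defaults to `True`, and `from_pretrained` simply forwards the flag down to `_load_state_dict_into_model` (the flag itself is popped from `kwargs` earlier in `from_pretrained`, outside this hunk). A hedged usage sketch for a plain, non-composite model:

```python
from transformers import BertModel

# Default: fast init. Modules whose weights come from the checkpoint are not
# randomly initialized first, so nothing is initialized twice.
model = BertModel.from_pretrained("bert-base-uncased")

# Force the slow path, e.g. to rule out init-related differences when debugging.
# Note the leading underscore: this is a private argument and may change.
model = BertModel.from_pretrained("bert-base-uncased", _fast_init=False)
```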
@@ -1273,12 +1279,14 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-        # tie uninitialized modules
-        uninitialized_modules = model.retrieve_modules_from_names(
-            missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
-        )
-        for module in uninitialized_modules:
-            model._init_weights(module)
+        if _fast_init:
+            # retrieve uninitialized modules and initialize them
+            uninitialized_modules = model.retrieve_modules_from_names(
+                missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
+            )
+            for module in uninitialized_modules:
+                model._init_weights(module)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
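The `_fast_init` branch above is the core of fast init: the missing state-dict keys are mapped back to their owning modules, and only those modules go through `_init_weights`. The standalone toy sketch below (made-up module and key names, not the library's actual `retrieve_modules_from_names`) shows the idea:

```python
import torch.nn as nn

# Toy model standing in for a pretrained backbone plus a brand-new head.
model = nn.Module()
model.encoder = nn.Linear(4, 4)      # weights present in the checkpoint
model.classifier = nn.Linear(4, 2)   # newly added head -> its keys are missing

missing_keys = ["classifier.weight", "classifier.bias"]

def needs_init(name: str) -> bool:
    # A module is "uninitialized" if any missing key belongs to it.
    return any(k == name or k.startswith(name + ".") for k in missing_keys)

uninitialized = [m for n, m in model.named_modules() if n and needs_init(n)]
for module in uninitialized:
    nn.init.zeros_(module.weight)    # stand-in for model._init_weights(module)
```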
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -221,6 +221,13 @@ def get_output_embeddings(self):
    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super(EncoderDecoderModel, cls).from_pretrained(*args, **kwargs)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
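With this override in place nothing changes for callers: reloading a saved encoder-decoder checkpoint just takes the slow-init path internally. A usage sketch, where `./bert2bert` is a placeholder output directory:

```python
from transformers import EncoderDecoderModel

# Build a composite model from two single checkpoints and save it...
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.save_pretrained("./bert2bert")  # placeholder path

# ...then reload it. EncoderDecoderModel.from_pretrained silently injects
# _fast_init=False, so the composite model falls back to slow initialization.
model = EncoderDecoderModel.from_pretrained("./bert2bert")
```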
7 changes: 7 additions & 0 deletions src/transformers/models/rag/modeling_rag.py
@@ -232,6 +232,13 @@ class RagPreTrainedModel(PreTrainedModel):
    base_model_prefix = "rag"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super(RagPreTrainedModel, cls).from_pretrained(*args, **kwargs)

    @classmethod
    def from_pretrained_question_encoder_generator(
        cls,
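The RAG override follows the same pattern, so the fallback is just as transparent there. For example, loading the public facebook/rag-token-nq weights (retriever omitted, since only weight loading matters here):

```python
from transformers import RagTokenForGeneration

# RagTokenForGeneration inherits from RagPreTrainedModel, so this call also
# goes through the overridden from_pretrained and uses _fast_init=False.
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
```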