Skip to content

Commit

Permalink
[Lazy init] Force fall back to slow init for composite models (#11705)
Browse files Browse the repository at this point in the history
* fix encoder-decoder & RAG

* finalize

* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

Co-authored-by: Lysandre Debut <[email protected]>

* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Lysandre Debut <[email protected]>

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
  • Loading branch information
3 people authored May 12, 2021
1 parent 5c1cda9 commit fd6204b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,12 @@ def get_output_embeddings(self) -> nn.Module:
"""
return None # Overwrite for models with output embeddings

def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class.
"""
raise NotImplementedError(f"Make sure `_init_weigths` is implemented for {self.__class__}")

def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
Expand Down Expand Up @@ -1205,7 +1211,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)

model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
model, state_dict, pretrained_model_name_or_path
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
)

# make sure token embedding weights are still tied if needed
Expand All @@ -1225,7 +1231,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
return model

@classmethod
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):

# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
Expand Down Expand Up @@ -1273,12 +1279,14 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

# tie unintialized modules
unintialized_modules = model.retrieve_modules_from_names(
missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
)
for module in unintialized_modules:
model._init_weights(module)
if _fast_init:
# retrieve unintialized modules and initialize
unintialized_modules = model.retrieve_modules_from_names(
missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
)
for module in unintialized_modules:
model._init_weights(module)

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)

@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported
# for composite models
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)

@classmethod
def from_encoder_decoder_pretrained(
cls,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/rag/modeling_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ class RagPreTrainedModel(PreTrainedModel):
base_model_prefix = "rag"
_keys_to_ignore_on_load_missing = [r"position_ids"]

@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported
# for composite models
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)

@classmethod
def from_pretrained_question_encoder_generator(
cls,
Expand Down

0 comments on commit fd6204b

Please sign in to comment.