From 2d3f8f564992639981eb8250a741c02074382186 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 12 May 2021 15:52:54 +0100
Subject: [PATCH] [Lazy init] Force fall back to slow init for composite
 models (#11705)

* fix encoder-decoder & RAG

* finalize

* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

Co-authored-by: Lysandre Debut

* Update src/transformers/models/rag/modeling_rag.py

Co-authored-by: Lysandre Debut

Co-authored-by: Patrick von Platen
Co-authored-by: Lysandre Debut
---
 src/transformers/modeling_utils.py           | 24 ++++++++++++-------
 .../modeling_encoder_decoder.py              |  7 ++++++
 src/transformers/models/rag/modeling_rag.py  |  7 ++++++
 3 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ca8ae2267109d7..9ab8824067c54e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -510,6 +510,12 @@ def get_output_embeddings(self) -> nn.Module:
         """
         return None  # Overwrite for models with output embeddings
 
+    def _init_weights(self, module):
+        """
+        Initialize the weights. This method should be overridden by derived class.
+        """
+        raise NotImplementedError(f"Make sure `_init_weigths` is implemented for {self.__class__}")
+
     def tie_weights(self):
         """
         Tie the weights between the input embeddings and the output embeddings.
@@ -1205,7 +1211,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 )
 
             model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
-                model, state_dict, pretrained_model_name_or_path
+                model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
             )
 
         # make sure token embedding weights are still tied if needed
@@ -1225,7 +1231,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         return model
 
     @classmethod
-    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):
+    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
 
         # Convert old format to new format if needed from a PyTorch state_dict
         old_keys = []
@@ -1273,12 +1279,14 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
-        # tie unintialized modules
-        unintialized_modules = model.retrieve_modules_from_names(
-            missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
-        )
-        for module in unintialized_modules:
-            model._init_weights(module)
+        if _fast_init:
+            # retrieve unintialized modules and initialize
+            unintialized_modules = model.retrieve_modules_from_names(
+                missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
+            )
+            for module in unintialized_modules:
+                model._init_weights(module)
+
         # copy state_dict so _load_from_state_dict can modify it
         metadata = getattr(state_dict, "_metadata", None)
         state_dict = state_dict.copy()
diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
index 3696cf9167b18d..b3bb1eb6036597 100644
--- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -221,6 +221,13 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         return self.decoder.set_output_embeddings(new_embeddings)
 
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super().from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_encoder_decoder_pretrained(
         cls,
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 42c2e16d6ca795..8bbc754d14e825 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -232,6 +232,13 @@ class RagPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rag"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        # At the moment fast initialization is not supported
+        # for composite models
+        kwargs["_fast_init"] = False
+        return super().from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_pretrained_question_encoder_generator(
         cls,
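
For reference, a minimal standalone sketch of the override pattern the two model files use to force slow initialization. The classes and names below are illustrative stand-ins, not the actual transformers API:

# A toy reproduction of the pattern in this patch: a composite subclass
# overrides from_pretrained() to force the slow-init path by setting
# _fast_init=False before delegating to the base class.

class FakePreTrainedModel:
    @classmethod
    def from_pretrained(cls, name, *args, _fast_init=True, **kwargs):
        # In the real library, _fast_init=True re-initializes only the modules
        # whose weights are missing from the checkpoint; _fast_init=False falls
        # back to initializing every module before loading the state dict.
        print(f"loading {name} with _fast_init={_fast_init}")
        return cls()


class FakeCompositeModel(FakePreTrainedModel):
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Composite models delegate _init_weights to their sub-models, so fast
        # initialization is not supported; force the slow path instead.
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)


if __name__ == "__main__":
    FakePreTrainedModel.from_pretrained("single-model")     # _fast_init=True
    FakeCompositeModel.from_pretrained("composite-model")   # _fast_init=False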