diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e537e307b6190c..59621572790b1c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -786,16 +786,16 @@ def init_weights(self): """ Maybe initializes and prunes weights if needed. """ + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + if not _init_weights: return # Initialize weights self.apply(self._init_weights) - # Prune heads if needed - if self.config.pruned_heads: - self.prune_heads(self.config.pruned_heads) - # Tie weights if needed self.tie_weights() @@ -1163,9 +1163,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P try: from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model - model, missing_keys = load_tf2_checkpoint_in_pytorch_model( - model, resolved_archive_file, allow_missing_keys=True - ) + model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) except ImportError: logger.error( "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. 
Please see " @@ -1241,9 +1239,11 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if not has_prefix_module and expects_prefix_module: - expected_keys = [s.split(prefix)[-1] for s in expected_keys if s.startswith(prefix)] + expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys] elif has_prefix_module and not expects_prefix_module: - expected_keys = [".".join([prefix, s]) for s in expected_keys if ".".join([prefix, s]) in set(loaded_keys)] + expected_keys = [ + ".".join([prefix, s]) if ".".join([prefix, s]) in set(loaded_keys) else s for s in expected_keys + ] missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d193a9e7a47862..68057a69a04901 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -176,6 +176,49 @@ def test_save_load__keys_to_ignore_on_save(self): for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved) + def test_save_load_fast_init_from_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + + # save the base model so it can be reloaded as each derived class + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + torch.manual_seed(0) + model_fast_init = model_class.from_pretrained(tmpdirname) + model_slow_init = model_class.from_pretrained(tmpdirname, _no_fast_init=True) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).abs().max().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") 
+ + def test_save_load_fast_init_to_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + + # save the derived model so it can be reloaded as the base class + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + torch.manual_seed(0) + model_fast_init = base_class.from_pretrained(tmpdirname) + model_slow_init = base_class.from_pretrained(tmpdirname, _no_fast_init=True) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).abs().max().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -635,7 +678,7 @@ def test_head_pruning_integration(self): if not self.test_pruning: return - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[:1]: ( config, inputs_dict,