diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7b1f477af5280f..906ebd99411407 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -769,7 +769,10 @@ def init_weights(self): Initializes and prunes weights if needed. """ # Initialize weights - self.apply(self._init_weights) + if getattr(self.config, "use_pretrained_weights", False): + logger.info("detected pretrained model - skipping _init_weights") + else: + self.apply(self._init_weights) # Prune heads if needed if self.config.pruned_heads: @@ -1116,8 +1119,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config.name_or_path = pretrained_model_name_or_path - # Instantiate model. + # The weights come from the state_dict, so tell the model to skip _init_weights: that + # random initialization would be immediately overwritten by the state_dict weights. + config.use_pretrained_weights = True + # Instantiate model. if is_deepspeed_zero3_enabled(): import deepspeed