From b0d1639153bb8b926cea96411d6c7b44ad879eac Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 26 Apr 2021 17:59:47 -0700
Subject: [PATCH 1/2] don't init weights for pretrained models

---
 src/transformers/modeling_utils.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7b1f477af5280f..67ee6c6368e259 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -769,7 +769,10 @@ def init_weights(self):
         Initializes and prunes weights if needed.
         """
         # Initialize weights
-        self.apply(self._init_weights)
+        if not self.config.use_pretrained_weights:
+            self.apply(self._init_weights)
+        else:
+            logger.info("detected pretrained model - skipping _init_weights")
 
         # Prune heads if needed
         if self.config.pruned_heads:
@@ -1116,8 +1119,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         config.name_or_path = pretrained_model_name_or_path
 
-        # Instantiate model.
+        # weights are coming from state_dict so tell models not to init weights, since that
+        # randomization will be immediately overwritten by weights from state_dict
+        config.use_pretrained_weights = True
 
+        # Instantiate model.
         if is_deepspeed_zero3_enabled():
             import deepspeed
 

From 2cf051dc2a5978c528efb2ca10ecae61ed123d44 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 26 Apr 2021 18:08:19 -0700
Subject: [PATCH 2/2] fix

---
 src/transformers/modeling_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 67ee6c6368e259..906ebd99411407 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -769,10 +769,10 @@ def init_weights(self):
         Initializes and prunes weights if needed.
         """
         # Initialize weights
-        if not self.config.use_pretrained_weights:
-            self.apply(self._init_weights)
-        else:
+        if getattr(self.config, "use_pretrained_weights", False):
             logger.info("detected pretrained model - skipping _init_weights")
+        else:
+            self.apply(self._init_weights)
 
         # Prune heads if needed
         if self.config.pruned_heads:
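
Note (not part of the patch): a minimal sketch of the behaviour the two patches aim for. SimpleConfig and SimpleModel below are hypothetical stand-ins, not transformers classes. The point of the second patch is that configs created outside from_pretrained() never carry the flag, so getattr() must default to False and random initialization runs as before, while from_pretrained() sets the flag and the redundant _init_weights pass is skipped.

    import logging

    logger = logging.getLogger(__name__)

    class SimpleConfig:
        pass  # may or may not have `use_pretrained_weights` set on it

    class SimpleModel:
        def __init__(self, config):
            self.config = config

        def _init_weights(self, module):
            pass  # model-specific random init would go here

        def apply(self, fn):
            fn(self)  # stand-in for nn.Module.apply over all submodules

        def init_weights(self):
            # mirrors the logic after PATCH 2/2: only init when the flag is absent/False
            if getattr(self.config, "use_pretrained_weights", False):
                logger.info("detected pretrained model - skipping _init_weights")
            else:
                self.apply(self._init_weights)

    # fresh model: flag absent -> random init runs
    SimpleModel(SimpleConfig()).init_weights()

    # from_pretrained() path: flag set -> init skipped, state_dict weights kept
    cfg = SimpleConfig()
    cfg.use_pretrained_weights = True
    SimpleModel(cfg).init_weights()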