Fix _resize_token_embeddings setting the lm head size to 0 when DeepSpeed ZeRO-3 is enabled (#26024)

kai01ai authored and LysandreJik committed Sep 15, 2023
1 parent 8160e42 commit 2ba46c1
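
Editor's note on why the gather is needed, as a minimal sketch rather than transformers code: under ZeRO-3, DeepSpeed partitions every parameter across ranks, so the locally visible weight tensor is empty and weight.shape[0] reads as 0. When pad_to_multiple_of pads the embedding matrix past the requested size, the lm head must therefore be resized from the gathered, full-size weight, not from the local shard. The helper below (gathered_num_rows, with an illustrative embedding module, both hypothetical) shows the read pattern the patch adopts:

# Minimal sketch of the ZeRO-3 pitfall this commit fixes. `embedding` is a
# hypothetical nn.Embedding owned by a ZeRO-3 partitioned model; only the
# deepspeed.zero.GatheredParameters call mirrors the actual patch.
import deepspeed


def gathered_num_rows(embedding):
    # Under ZeRO-3 each rank holds only a shard of the parameter, and the
    # locally visible tensor is empty, so embedding.weight.shape[0] == 0.
    # GatheredParameters temporarily reassembles the full tensor on every
    # rank; modifier_rank=None marks the access as read-only.
    with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=None):
        return embedding.weight.shape[0]  # the true vocabulary size

Because modifier_rank=None signals that no rank writes to the parameter, DeepSpeed can discard the gathered copy as soon as the context exits, keeping the read cheap.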
Showing 1 changed file with 11 additions and 1 deletion: src/transformers/modeling_utils.py
@@ -1451,10 +1451,20 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
             add_hook_to_module(new_embeddings, hook)
         self.set_input_embeddings(new_embeddings)
 
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
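
For context, a hedged sketch of the call path that exposed the bug, assuming a DeepSpeed ZeRO-3 training run and a model whose config.tie_word_embeddings is False; the checkpoint name is illustrative only, not taken from the commit:

# Illustrative reproduction path, not a test from the repository.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# With pad_to_multiple_of set, the embedding matrix may grow past
# len(tokenizer). Before this fix, under ZeRO-3 the lm head of a model with
# untied embeddings was resized to the local (empty) shard size, i.e. 0 rows.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)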
