Support MuParametrization and MuTransfer #64
@@ -1,4 +1,5 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -41,6 +42,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # muP
        self.use_mup = config.use_mup

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention
@@ -53,15 +57,41 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)

            #muP -- q_attn
            if self.use_mup:
                self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
                self.q_attn.bias.data.zero_()

        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
        if self.use_mup:
            if config.query_zero_init:
                _, fanout = self.c_attn.weight.shape
                self.c_attn.weight.data[:, :fanout//3] = 0
            self.c_attn.bias.data.zero_()
Review comment: I think lines 70-74 should be under the [...]. More generally, my understanding is that [...]. In addition to moving lines 70-74 within the [...], we may also need to initialize (the rest of) [...]. Let me know what you think of this reasoning!

Reply: This is extremely helpful, thank you. Updating accordingly.
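The reviewer's inline code references were lost in extraction, so the exact suggestion is not fully recoverable here. One plausible reading (an assumption, not necessarily what was meant) is that the query_zero_init block belongs with the fused c_attn built in the self-attention branch, and that the remaining key/value slices should still receive the muP matrix-like initialization. A minimal sketch of that reading, written as a hypothetical standalone helper rather than as the PR's actual code:

    import math

    from transformers.pytorch_utils import Conv1D


    def init_fused_qkv_for_mup(c_attn: Conv1D, initializer_range: float,
                               mup_width_scale: float, query_zero_init: bool) -> None:
        """Hypothetical helper: muP init for a fused QKV Conv1D (weight shape is (nx, 3*nx))."""
        # matrix-like init for Q, K and V: std shrinks with the width multiplier
        mup_std = math.sqrt((initializer_range ** 2) / mup_width_scale)
        c_attn.weight.data.normal_(mean=0.0, std=mup_std)
        c_attn.bias.data.zero_()
        if query_zero_init:
            # zero only the query slice (first third of the output dimension)
            _, fanout = c_attn.weight.shape
            c_attn.weight.data[:, :fanout // 3] = 0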
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
        if self.use_mup:
            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
            self.c_proj.bias.data.zero_()

        if self.use_mup:
            self.attn_dropout = nn.Identity()
Review comment: Should we consider asserting that the dropout probabilities are set to 0 in this case (in the configs)?
            self.resid_dropout = nn.Identity()
        else:
            self.attn_dropout = nn.Dropout(config.attn_pdrop)
            self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

        self.rot_emb=None
        if self.position_embedding == "rope":
            self.rot_emb=RotaryEmbedding(dim=self.head_dim)
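On the reviewer's question above about dropout probabilities: rather than silently swapping in nn.Identity(), the requirement could be made explicit at config-validation time. A minimal sketch, assuming the config exposes use_mup, attn_pdrop, resid_pdrop and embd_pdrop as in this diff; this illustrates the suggestion and is not the PR's implementation:

    def validate_mup_dropout(config) -> None:
        """Hypothetical check: muP hyperparameter transfer assumes dropout probabilities of 0."""
        if not getattr(config, "use_mup", False):
            return
        for name in ("attn_pdrop", "resid_pdrop", "embd_pdrop"):
            p = getattr(config, name, 0.0)
            assert p == 0.0, (
                f"use_mup=True but {name}={p}; set all dropout probabilities to 0 "
                f"(or drop the corresponding nn.Dropout modules) so the muP assumptions hold."
            )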
@@ -76,8 +106,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

    def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:

        #muP
        if self.use_mup:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
            )
Review comment: Just to check, I'm assuming this is where you're planning on using [...]?

Reply: I think so, yes, though from a configuration point of view I'm unsure how best to handle this (and the same applies to some of the other muP-specific configuration options). In theory I wanted a high-level configuration option that enables (or disables) muP, but I also want it to be configurable for experimentation. At the same time, I'm conflicted about this more fine-grained configuration, since muP simply doesn't work (that is, hyperparameters do not transfer) unless muP's requirements, including attention scaling, are respected.

Reply: For now, I'll require both config options to be set.
        elif self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
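The thread above lands on requiring both options, so that the attention logits are divided by the full head dimension (1/d, the muP prescription) when use_mup is set, and by sqrt(d) otherwise, with scale_attn_weights still acting as a gate. A small standalone sketch of that gating, under those assumptions (the helper name is hypothetical):

    import torch


    def scale_attention_logits(attn_weights: torch.Tensor, head_dim: int,
                               use_mup: bool, scale_attn_weights: bool) -> torch.Tensor:
        """Hypothetical helper: muP divides attention logits by d_head instead of sqrt(d_head)."""
        if not scale_attn_weights:
            return attn_weights
        divisor = float(head_dim) if use_mup else float(head_dim) ** 0.5
        return attn_weights / torch.full(
            [], divisor, dtype=attn_weights.dtype, device=attn_weights.device
        )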
@@ -251,10 +286,31 @@ class APTMLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size

        #muP
        use_mup = config.use_mup

        self.c_fc = Conv1D(intermediate_size, embed_dim)

        #muP -- matrix-like
        if use_mup:
            self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
            self.c_fc.bias.data.zero_()

        self.c_proj = Conv1D(embed_dim, intermediate_size)

        #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
        if use_mup:
            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
            self.c_proj.bias.data.zero_()

        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

        if use_mup:
            self.dropout = nn.Identity()
        else:
            self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
@@ -270,14 +326,34 @@ def __init__(self, config, layer_idx=None):
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        #muP
        use_mup = config.use_mup

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        #muP -- vector-like
        if use_mup:
            self.ln_1.weight.data.fill_(1.0)
            self.ln_1.bias.data.zero_()

        self.attn = APTAttention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        #muP -- vector-like
        if use_mup:
            self.ln_2.weight.data.fill_(1.0)
            self.ln_2.bias.data.zero_()

        if config.add_cross_attention:
            #muP TO DO: check proper behavior in case of cross-attention
            self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

            #muP -- vector-like
            if use_mup:
                self.ln_cross_attn.weight.data.fill_(1.0)
                self.ln_cross_attn.bias.data.zero_()

        self.mlp = APTMLP(inner_dim, config)

    def forward(
@@ -353,26 +429,60 @@ class APTModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        use_mup = config.use_mup

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

        #muP -- vector-like, zero if zero init or maintained regardless of width
        if use_mup:
            if config.wte_zero_init:
                self.wte.weight.data.zero_()
            else:
                self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range)

            if self.wte.padding_idx is not None:
                self.wte.weight.data[self.wte.padding_idx].zero_()

        self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

        if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

            #muP -- vector-like, constant regardless of width
            #muP TO DO: check proper behavior in rope & rerope case
            if use_mup:
                self.wpe.weight.data.normal_(0.0, std=config.initializer_range)

                if self.wpe.padding_idx is not None:
                    self.wpe.weight.data[self.wpe.padding_idx].zero_()
            self.alibi = None
        elif self.position_embedding=="alibi":
            #muP TO DO: check proper behavior in alibi case

            maxpos = config.n_positions
            attn_heads = config.n_head
            alibi = create_alibi_tensor(attn_heads,maxpos)
            self.register_buffer('alibi',alibi)
        else:
            raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')

        self.drop = nn.Dropout(config.embd_pdrop)
        #muP
        if use_mup:
            self.drop = nn.Identity()
        else:
            self.drop = nn.Dropout(config.embd_pdrop)

        self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        #muP -- vector-like
        if use_mup:
            self.ln_f.weight.data.fill_(1.0)
            self.ln_f.bias.data.zero_()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
@@ -474,6 +584,7 @@ def forward(

        if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
            position_embeds = self.wpe(position_ids)
            position_embeds.mul_(self.mup_rp_embedding_mult)
            hidden_states = inputs_embeds + position_embeds
        else:
            hidden_states = inputs_embeds
@@ -593,6 +704,10 @@ class APTLMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = APTModel(config)

        #muP TO DO: check proper behavior for LM head, nothing should be done (?)
        #see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
        #see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
Review comment: We may want to replace this [...].

Reply: OK, I've looked into it and agree that we need to switch from the [...].

Reply: Thinking about it, we should probably just add a config option about weight tying, and support both cases accordingly.

Reply: Actually, the MuScaling repo adds a [...].

Reply: Sorry, I'm not quite sure I'm understanding -- are you suggesting that we write our own MuReadout and MuSharedReadout modules ourselves?
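For reference, a sketch of what in-repo MuReadout/MuSharedReadout modules could look like: a readout that divides by the width multiplier so the output layer scales the muP way, plus a weight-tied variant for the case discussed above. The class bodies, the width_mult/output_mult arguments and the tying mechanics are assumptions for illustration only; the mup package's own classes infer the multiplier differently (from recorded base shapes):

    import torch
    from torch import nn


    class MuReadout(nn.Linear):
        """Hypothetical muP readout: output logits are scaled down by the width multiplier."""

        def __init__(self, hidden_size: int, vocab_size: int, width_mult: float,
                     output_mult: float = 1.0, bias: bool = False):
            super().__init__(hidden_size, vocab_size, bias=bias)
            self.width_mult = width_mult
            self.output_mult = output_mult

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return super().forward(hidden_states * self.output_mult / self.width_mult)


    class MuSharedReadout(MuReadout):
        """Hypothetical weight-tied variant: reuses the token embedding matrix."""

        def __init__(self, embedding_weight: nn.Parameter, width_mult: float,
                     output_mult: float = 1.0):
            vocab_size, hidden_size = embedding_weight.shape
            super().__init__(hidden_size, vocab_size, width_mult, output_mult, bias=False)
            self.weight = embedding_weight  # tie to wte.weight

Here width_mult would come from something like config.mup_width_scale, and APTLMHeadModel.__init__ would choose between the two classes based on a weight-tying config option.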
        # Model parallel
Review comment: Just checking, some of these config args seem unused in the modeling code as far as I can tell. I assume this is because the PR is still a draft, and they will actually be used in the final PR? I also believe one is missing: wte_zero_init, which is used in protein-lm-scaling/protein_lm/modeling/models/apt/model_pytorch.py (line 439 in b588300).

Reply: I think this might be the result of starting to work on this implementation based on the one from Marco in GPT-NeoX and later on switching to base it on the MuP-scaling one. I will recheck them all and update accordingly.
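To make this kind of config/modeling drift easier to audit, the muP options that the modeling diff actually reads could be kept in one place. A sketch, assuming a dataclass-style grouping; the field names are taken from the diff (plus mup_rp_embedding_mult, which the forward pass references but which is not defined in the hunks shown), while the class name and defaults are illustrative assumptions:

    from dataclasses import dataclass


    @dataclass
    class APTMuPOptions:
        """Hypothetical grouping of the muP-related options referenced by the modeling code."""
        use_mup: bool = False               # master switch: muP init, 1/d attention scaling, no dropout
        mup_width_scale: float = 1.0        # width multiplier relative to the base (proxy) model
        query_zero_init: bool = True        # zero the query slice of c_attn at init
        wte_zero_init: bool = False         # zero-init token embeddings (flagged as missing above)
        mup_rp_embedding_mult: float = 1.0  # multiplier applied to position embeddings in forward
        initializer_range: float = 0.02     # base std shared with the non-muP init path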