Support MuParametrization and MuTransfer #64

Draft · wants to merge 7 commits into base: main

Changes from 1 commit
21 changes: 21 additions & 0 deletions protein_lm/modeling/models/apt/config.py
@@ -11,11 +11,32 @@ def __init__(
position_embedding="learned",
tokenizer=None,
max_sequence_length = 1024,
use_mup = False,
query_zero_init = True,
n_layer = None,
initializer_range = 0.02,
mup_init_scale = 1.0,
mup_output_temp = 1.0,
mup_attn_mult = 1.0,
mup_embedding_mult = 1.0,
othertea (Collaborator) commented on Dec 29, 2023:

Just checking, some of these config args seem unused in the modeling code as far as I can tell. I assume this is because the PR is still a draft, and they will actually be used in the final PR?

I also believe one is missing: wte_zero_init, which is used here:

Contributor Author replied:

I think this might be the result of starting this implementation from Marco's GPT-NeoX version and later switching to base it on the Mu-scaling one. I will recheck them all and update accordingly.
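
For reference, a minimal sketch of how the missing flag could be threaded through the config; the class name below is a stand-in and the default value is an assumption:

```python
# Hedged sketch, not the PR's code: `APTConfigSketch` stands in for the real
# config class and only the muP-related arguments are shown; the default for
# wte_zero_init is an assumption.
class APTConfigSketch:
    def __init__(self, use_mup=False, query_zero_init=True, wte_zero_init=True, **kwargs):
        self.use_mup = use_mup
        self.query_zero_init = query_zero_init
        # Read by APTModel.__init__ when deciding how to initialize self.wte.
        self.wte_zero_init = wte_zero_init
```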

mup_rp_embedding_mult = 1.0,
mup_width_scale = 2.0,
**kwargs
):
super().__init__(**kwargs)
self.nn_model_type = "APT"
self.position_embedding = position_embedding
self.tokenizer = tokenizer
self.max_sequence_length = max_sequence_length

self.use_mup = use_mup
self.query_zero_init = query_zero_init
self.n_layer = n_layer
self.initializer_range = initializer_range
self.mup_init_scale = mup_init_scale
self.mup_output_temp = mup_output_temp
self.mup_attn_mult = mup_attn_mult
self.mup_embedding_mult = mup_embedding_mult
self.mup_rp_embedding_mult = mup_rp_embedding_mult
self.mup_width_scale = mup_width_scale
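
For context, a hedged usage sketch of these muP options; the class name and import are assumptions based on the file being edited:

```python
# Hypothetical usage; the class name APTConfig is assumed (it is not shown in this diff).
from protein_lm.modeling.models.apt.config import APTConfig

config = APTConfig(
    use_mup=True,
    n_layer=12,
    initializer_range=0.02,
    mup_width_scale=2.0,  # width of this model relative to the muP base (proxy) model
)
```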

131 changes: 123 additions & 8 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -1,4 +1,5 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -41,6 +42,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)

# muP
self.use_mup = config.use_mup

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention
@@ -53,15 +57,41 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)

#muP -- q_attn
if self.use_mup:
self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
self.q_attn.bias.data.zero_()

else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
#muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
if self.use_mup:
if config.query_zero_init:
_, fanout = self.c_attn.weight.shape
self.c_attn.weight.data[:, :fanout//3] = 0
self.c_attn.bias.data.zero_()
othertea (Collaborator) commented on Dec 29, 2023:

I think lines 70-74 should be under the else (line 66) block?

More generally, my understanding is that (self.c_attn, self.q_attn) in the cross-attention case (self.is_cross_attention==True) serves a similar function to self.c_attn in the no-cross-attention case (self.is_cross_attention==False). In particular, self.q_attn's parameters in the cross-attention case play the same role as the first third of the parameters of self.c_attn (the ones that are set in line 73) in the no-cross-attention case. Therefore, the way you initialize the weights of self.q_attn should match the way you initialize the first third of the weights of self.c_attn in the no-cross-attention case.

In addition to moving lines 70-74 into the else block mentioned above, this includes

  • adding zero-initialization of self.q_attn to the cross-attention case when config.query_zero_init==True, and
  • adding muP initialization (lines 62-64) of the first third of the parameters of self.c_attn in the no-cross-attention case.

We may also need to initialize (the rest of) self.c_attn with muP init as well.

Let me know what you think of this reasoning!

Contributor Author replied:

This is extremely helpful, thank you. Updating accordingly.
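
A hedged sketch of the symmetric initialization described above, written as a free helper rather than the final __init__ code; Conv1D is the GPT-2 projection layer from transformers and the config fields follow this diff:

```python
# Sketch only: mirrors the reviewer's suggestion, not the PR's final code.
import math
from transformers.pytorch_utils import Conv1D

def build_attn_projections(embed_dim, config, is_cross_attention):
    """Return (c_attn, q_attn) with muP / query-zero init applied consistently."""
    mup_std = math.sqrt(config.initializer_range ** 2 / config.mup_width_scale)
    if is_cross_attention:
        c_attn = Conv1D(2 * embed_dim, embed_dim)   # key/value projections
        q_attn = Conv1D(embed_dim, embed_dim)       # query projection
        if config.use_mup:
            c_attn.weight.data.normal_(mean=0.0, std=mup_std)
            c_attn.bias.data.zero_()
            if config.query_zero_init:
                q_attn.weight.data.zero_()          # zero-init the query weights
            else:
                q_attn.weight.data.normal_(mean=0.0, std=mup_std)
            q_attn.bias.data.zero_()
        return c_attn, q_attn
    c_attn = Conv1D(3 * embed_dim, embed_dim)       # fused query/key/value projection
    if config.use_mup:
        c_attn.weight.data.normal_(mean=0.0, std=mup_std)
        c_attn.bias.data.zero_()
        if config.query_zero_init:
            _, fanout = c_attn.weight.shape
            c_attn.weight.data[:, : fanout // 3] = 0  # query third zero-initialized
    return c_attn, None
```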


self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

#muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
if self.use_mup:
depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
self.c_proj.bias.data.zero_()

if self.use_mup:
self.attn_dropout = nn.Identity()
Collaborator commented:

Should we consider asserting that the dropout probabilities are set to 0 in this case (in configs)?
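
A minimal sketch of the kind of check being suggested, as a standalone helper (the name and placement are assumptions):

```python
def check_mup_dropout(config):
    """Raise if muP is enabled but any dropout probability is non-zero."""
    if getattr(config, "use_mup", False):
        for name in ("attn_pdrop", "resid_pdrop", "embd_pdrop"):
            p = getattr(config, name, 0.0)
            if p:  # non-zero dropout probability
                raise ValueError(f"use_mup=True requires {name}=0.0, got {p}")
```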

self.resid_dropout = nn.Identity()
else:
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()



self.rot_emb=None
if self.position_embedding == "rope":
self.rot_emb=RotaryEmbedding(dim=self.head_dim)
@@ -76,8 +106,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

if self.scale_attn_weights:

#muP
if self.use_mup:
attn_weights = attn_weights / torch.full(
[], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
)
Collaborator commented:

Just to check, I'm assuming this is where you're planning on using config.mup_attn_mult?

Contributor Author replied:

I think so, yes, though from a configuration point of view I'm unsure how best to handle this (and the same applies to some of the other MuP-specific configuration options). In theory I wanted a high-level configuration option that enables (or disables) MuP, but I also want it to be configurable for experimentation. At the same time, I'm conflicted about this more fine-grained configuration, since MuP simply doesn't work (that is, hparams do not transfer) unless MuP's requirements, including attention scaling, are respected.

Contributor Author added:

For now, I'll require both config options to be set.
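
For illustration, a hedged sketch of the two scalings and where mup_attn_mult could enter (1/d under muP versus 1/sqrt(d) otherwise); this mirrors the branch above rather than the final PR code:

```python
# Sketch under the assumption that mup_attn_mult simply rescales the 1/d factor.
import torch

def scale_attention_logits(attn_weights, value, use_mup, mup_attn_mult=1.0, scale_attn_weights=True):
    """Scale raw attention logits: mup_attn_mult / d under muP, 1 / sqrt(d) otherwise."""
    d = value.size(-1)  # head dimension
    if use_mup:
        return attn_weights * (mup_attn_mult / d)
    if scale_attn_weights:
        return attn_weights / (d ** 0.5)
    return attn_weights

# Example with (batch, heads, seq, head_dim) tensors:
q = k = v = torch.randn(2, 4, 8, 16)
w = torch.matmul(q, k.transpose(-1, -2))
w = scale_attention_logits(w, v, use_mup=True, mup_attn_mult=1.0)
```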

elif self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
@@ -251,10 +286,31 @@ class APTMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size

#muP
use_mup = config.use_mup

self.c_fc = Conv1D(intermediate_size, embed_dim)

#muP -- matrix-like
if use_mup:
self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
self.c_fc.bias.data.zero_()

self.c_proj = Conv1D(embed_dim, intermediate_size)

#muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
if use_mup:
depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
self.c_proj.bias.data.zero_()

self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)

if use_mup:
self.dropout = nn.Identity()
else:
self.dropout = nn.Dropout(config.resid_pdrop)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
@@ -270,14 +326,34 @@ def __init__(self, config, layer_idx=None):
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

#muP
use_mup = config.use_mup

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

#muP -- vector-like
if use_mup:
self.ln_1.weight.data.fill_(1.0)
self.ln_1.bias.data.zero_()

self.attn = APTAttention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

#muP -- vector-like
if use_mup:
self.ln_2.weight.data.fill_(1.0)
self.ln_2.bias.data.zero_()

if config.add_cross_attention:
#muP TO DO: check proper behavior in the cross-attention case
self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

#muP -- vector-like
if use_mup:
self.ln_cross_attn.weight.data.fill_(1.0)
self.ln_cross_attn.bias.data.zero_()

self.mlp = APTMLP(inner_dim, config)

def forward(
@@ -353,26 +429,60 @@ class APTModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.embed_dim = config.hidden_size
use_mup = config.use_mup

self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

#muP -- vector-like, zero if zero init or maintained regardless of width
if use_mup:
if config.wte_zero_init:
self.wte.weight.data.zero_()
else:
self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range)

if self.wte.padding_idx is not None:
self.wte.weight.data[self.wte.padding_idx].zero_()

self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

#muP -- vector-like, constant regardless of width
#muP TO DO: check proper behavior in rope & rerope case
if use_mup:
self.wpe.weight.data.normal_(0.0, std=config.initializer_range)

if self.wpe.padding_idx is not None:
self.wpe.weight.data[self.wpe.padding_idx].zero_()

self.alibi = None
elif self.position_embedding=="alibi":
#muP TO DO: check proper behavior in alibi case

maxpos = config.n_positions
attn_heads = config.n_head
alibi = create_alibi_tensor(attn_heads,maxpos)
self.register_buffer('alibi',alibi)
else:
raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')

self.drop = nn.Dropout(config.embd_pdrop)
#muP
if use_mup:
self.drop = nn.Identity()
else:
self.drop = nn.Dropout(config.embd_pdrop)

self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

#muP -- vector-like
if use_mup:
self.ln_f.weight.data.fill_(1.0)
self.ln_f.bias.data.zero_()

# Model parallel
self.model_parallel = False
self.device_map = None
@@ -474,6 +584,7 @@ def forward(

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
position_embeds = self.wpe(position_ids)
position_embeds.mul_(self.config.mup_rp_embedding_mult)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds
@@ -593,6 +704,10 @@ class APTLMHeadModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = APTModel(config)

#muP TO DO: check proper behavior for LM head, nothing should be done (?)
#see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
#see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
Collaborator commented:

We may want to replace this nn.Linear with MuReadout from the muP repository/package or do the equivalent initialization manually. I think cofe-ai's Mu-scaling repo also uses MuReadout; lm_head is initialized in a different file than the one your comment references: https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/lm_mup.py#L17-L25

Contributor Author replied:

OK, I've looked into it and agree that we need to switch from the nn.Linear layer to MuReadout. Do we plan to use weight tying? Because mup also has a MuSharedReadout (https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L59-L68) that is likely handy in this case. For now I'll add a note about weight tying and just switch to the normal MuReadout.

Contributor Author added:

Thinking about it, we should probably just add a config option about weight tying, and support both cases accordingly.

Contributor Author added:

Actually, the Mu-scaling repo adds a width_mult config option to MuReadout because the original mup repo has it calculated and set on a layer-by-layer basis. I think we might want to follow the Mu-scaling approach of just integrating MuReadout and MuSharedReadout into our code accordingly.

Collaborator replied:

Sorry, I'm not quite sure I follow -- are you suggesting that we write our own MuReadout and MuSharedReadout modules?
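
A hedged sketch of what the swap could look like with MuReadout/MuSharedReadout from the mup package; whether width_mult comes from mup.set_base_shapes or is passed explicitly (the Mu-scaling approach) is left open here:

```python
# Sketch only; assumes the `mup` package (pip install mup) and this PR's config fields.
from torch import nn
from mup import MuReadout, MuSharedReadout

def build_lm_head(config, wte: nn.Embedding) -> nn.Module:
    """Return a muP-aware LM head; falls back to nn.Linear when muP is off."""
    if not config.use_mup:
        return nn.Linear(config.n_embd, config.vocab_size, bias=False)
    if getattr(config, "tie_word_embeddings", False):
        # Weight tying: share the token-embedding matrix with the readout.
        return MuSharedReadout(wte.weight, bias=False)
    # Note: MuReadout expects mup.set_base_shapes(...) to have been called so that
    # width_mult is known (the Mu-scaling repo instead passes it explicitly).
    return MuReadout(config.n_embd, config.vocab_size, bias=False)
```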


# Model parallel