Support MuParametrization and MuTransfer #64
Hey @NZ99 thanks a ton for this MuParametrization PR!! I'm very excited about getting this working 🙂
I have a few high-level comments:
- Coordinate checking sounds like a great idea! I'm fully in support of doing this before merging a final PR.
- On the other hand, I wonder if we should consider saving the integration into Optuna for a subsequent PR. This first MuP PR already has a lot of changes, and I would be happy if we just got a version that was verified by coordinate checking. But I do see that an advantage with integrating Optuna in this PR might be that we can try a hyperparam search with it and verify that hyperparams do transfer successfully. So I see the pros and cons of each, and of course, it's your call!
- Regarding the organization of the mup code: I noticed that in both the repo that you refer to in this PR and the muTransformers repo, they have an initialization function that is called on the modules, but in this PR, you chose to initialize them within the `__init__` function of each module. I'm curious what the pros and cons of this choice are. Again, I leave it up to you whether or not you want a more global init function. (Though personally I think that the LayerNorm init happens often enough that you may want to factor that out either way.)
- Are we planning on using the MuP repo? It can be useful if we want to use their `MuReadout` module (a potential location for using this is in an in-line comment below). And I believe we can't use the default PyTorch optimizers, so we either need to use the MuP optimizers or implement our own with the correct scaling.
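On the "implement our own" option in the last point: a minimal sketch, assuming Adam-style optimizers and assuming we can list which parameters are hidden weight matrices (fan-in and fan-out both grow with width). Those get the base LR divided by the width multiplier; embeddings, biases, LayerNorm gains and a readout whose width dependence is handled by an output multiplier (as `MuReadout` does) keep the base LR. The `mup` package's `MuAdam` / `MuSGD` derive these groups automatically from the shapes recorded by `set_base_shapes`; `mup_adam_param_groups` and `hidden_matrix_names` are hypothetical names.

```python
import torch
from torch import nn


def mup_adam_param_groups(model, hidden_matrix_names, base_lr, width_mult):
    """Hypothetical helper: split parameters into muP learning-rate groups for Adam.

    Parameters named in `hidden_matrix_names` (hidden weight matrices) get
    base_lr / width_mult; everything else keeps base_lr.
    """
    hidden, other = [], []
    for name, param in model.named_parameters():
        (hidden if name in hidden_matrix_names else other).append(param)
    groups = [
        {"params": hidden, "lr": base_lr / width_mult},
        {"params": other, "lr": base_lr},
    ]
    # Drop empty groups so the optimizer never receives an empty parameter list.
    return [g for g in groups if g["params"]]


# Toy model: input embedding, one hidden matrix, readout (assumed to use an output multiplier).
model = nn.Sequential(nn.Embedding(10, 32), nn.Linear(32, 32), nn.Linear(32, 10))
groups = mup_adam_param_groups(model, {"1.weight"}, base_lr=1e-3, width_mult=4.0)
optimizer = torch.optim.AdamW(groups)
```

Keeping the grouping explicit makes it easy to check against the muP tables, at the cost of maintaining the list of hidden matrices by hand.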
And of course, @pascalnotin please feel free to chime in on any thoughts regarding the above comments or anything else!
I've also left some (not-100%-comprehensive) in-line comments. Let me know what you think!
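For the init-organization point, a global init function could look roughly like this: a factory returning a function to pass to `model.apply(...)`, with the LayerNorm case handled in one place. Assumptions: it only matches `nn.Linear` (APT uses transformers' `Conv1D`, which would need its own branch), it reuses the `sqrt(initializer_range**2 / mup_width_scale)` std from the PR's `q_attn` init, and it does not special-case embeddings or the readout. `make_mup_init_fn` is a hypothetical name.

```python
import math

import torch
from torch import nn


def make_mup_init_fn(initializer_range, mup_width_scale):
    """Return a hypothetical init function for use with model.apply(...)."""
    hidden_std = math.sqrt(initializer_range ** 2 / mup_width_scale)

    @torch.no_grad()
    def init_fn(module):
        if isinstance(module, nn.LayerNorm):
            # LayerNorm reset factored out: gain 1, bias 0 (either may be absent).
            if module.weight is not None:
                module.weight.fill_(1.0)
            if module.bias is not None:
                module.bias.zero_()
        elif isinstance(module, nn.Linear):
            # Hidden weights: variance shrinks with the width scale.
            module.weight.normal_(mean=0.0, std=hidden_std)
            if module.bias is not None:
                module.bias.zero_()

    return init_fn


model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
model.apply(make_mup_init_fn(initializer_range=0.02, mup_width_scale=4.0))
```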
```python
mup_init_scale = 1.0,
mup_output_temp = 1.0,
mup_attn_mult = 1.0,
mup_embedding_mult = 1.0,
```
Just checking: some of these config args seem unused in the modeling code as far as I can tell. I assume this is because the PR is still a draft, and they will actually be used in the final PR?
I also believe one is missing, `wte_zero_init`, which is used here:
```python
if config.wte_zero_init:
```
I think this might be the result of starting this implementation from Marco's one in GPT-NeoX and later switching to base it on the Mu-scaling one. I will recheck them all and update accordingly.
```python
if self.is_cross_attention:
    self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
    self.q_attn = Conv1D(self.embed_dim, self.embed_dim)

    #muP -- q_attn
    if self.use_mup:
        self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
        self.q_attn.bias.zero_()

else:
    self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
#muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
if self.use_mup:
    if config.query_zero_init:
        _, fanout = self.c_attn.weight.shape
        self.c_attn.weight.data[:, :fanout//3] = 0
        self.c_attn.bias.zero_()
```
I think lines 70-74 should be under the `else` block (line 66)?
More generally, my understanding is that (`self.c_attn`, `self.q_attn`) in the cross-attention case (`self.is_cross_attention == True`) serve a similar function to `self.c_attn` in the no-cross-attention case (`self.is_cross_attention == False`). In particular, the parameters of `self.q_attn` in the cross-attention case have the same function as the first third of the parameters of `self.c_attn` (the ones that are set in line 73) in the no-cross-attention case. Therefore, `self.q_attn` should be initialized the same way as the first third of the weights of `self.c_attn` in the no-cross-attention case.
In addition to moving lines 70-74 into the `else` block mentioned above, this includes:
- adding zero-initialization of `self.q_attn` to the cross-attention case when `config.query_zero_init == True`, and
- adding muP initialization (lines 62-64) of the first third of the parameters of `self.c_attn` in the no-cross-attention case.
We may also need to initialize (the rest of) `self.c_attn` with muP init as well.
Let me know what you think of this reasoning!
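A sketch of the restructuring suggested in this thread, written as a standalone helper so the self-attention and cross-attention branches follow one rule. Assumptions: `c_attn` / `q_attn` use transformers' `Conv1D` layout (weight of shape `(in_features, out_features)`, with the query occupying the first third of `c_attn`'s outputs in self-attention), and the std formula is the one the snippet above uses for `q_attn`. `init_attention_projections` and `Conv1DStandIn` are hypothetical names. Running under `torch.no_grad()` also sidesteps the error PyTorch raises for in-place ops on leaf tensors that require grad.

```python
import math

import torch
from torch import nn


class Conv1DStandIn(nn.Module):
    """Minimal stand-in for transformers' Conv1D(nf, nx): weight shape (nx, nf)."""

    def __init__(self, nf, nx):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.empty(nf))


@torch.no_grad()
def init_attention_projections(c_attn, q_attn, *, is_cross_attention, query_zero_init,
                               initializer_range, mup_width_scale):
    """Hypothetical helper: muP-init all attention input projections, then zero the query part."""
    std = math.sqrt(initializer_range ** 2 / mup_width_scale)
    modules = (c_attn, q_attn) if is_cross_attention else (c_attn,)
    for module in modules:
        module.weight.normal_(mean=0.0, std=std)
        module.bias.zero_()
    if query_zero_init:
        if is_cross_attention:
            q_attn.weight.zero_()  # the query projection lives in its own module
        else:
            fanout = c_attn.weight.shape[1]
            c_attn.weight[:, : fanout // 3] = 0  # query = first third of c_attn's outputs


embed_dim = 8
self_c_attn = Conv1DStandIn(3 * embed_dim, embed_dim)
init_attention_projections(self_c_attn, None, is_cross_attention=False, query_zero_init=True,
                           initializer_range=0.02, mup_width_scale=2.0)
```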
This is extremely helpful, thank you. Updating accordingly.
```python
if self.use_mup:
    attn_weights = attn_weights / torch.full(
        [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
    )
```
Just to check, I'm assuming this is where you're planning on using `config.mup_attn_mult`?
I think so, yes, though from a configuration point of view I'm unsure how best to handle this (and the same applies to some of the other MuP-specific configuration options). In principle I wanted a high-level configuration option that enables (or disables) MuP, but I also want it to be configurable for experimentation. At the same time, I'm conflicted about this more fine-grained configuration, since MuP simply doesn't work (that is, hparams do not transfer) unless MuP's requirements, including attention scaling, are respected.
For now, I'll require both config options to be set.
```python
#muP TO DO: check proper behavior for LM head, nothing should be done (?)
#see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
#see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
```
We may want to replace this `nn.Linear` with `MuReadout` from the muP repository/package, or do the equivalent initialization manually. I think cofe-ai's Mu-scaling repo also uses `MuReadout`; `lm_head` is initialized in a different file than the one your comment references: https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/lm_mup.py#L17-L25
OK, I've looked into it and agree that we need to switch from the `nn.Linear` layer to `MuReadout`. Do we plan to use weight tying? If so, `mup` also has a `MuSharedReadout` (https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L59-L68) that is likely handy in this case. For now I'll add a note about weight tying and just switch to the normal `MuReadout`.
Thinking about it, we should probably just add a config option about weight tying, and support both cases accordingly.
Actually, the Mu-scaling repo adds a `width_mult` config option to `MuReadout`, because the original `mup` repo calculates and sets it on a layer-by-layer basis. I think we might want to follow the Mu-scaling approach and just integrate `MuReadout` and `MuSharedReadout` into our code accordingly.
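A minimal sketch of that approach, assuming `width_mult` is passed in explicitly (for example, model width divided by base width) instead of being inferred by `mup.set_base_shapes`. Like `mup`'s `MuReadout`, the forward pass scales the input by `output_mult / width_mult`, and `readout_zero_init` mirrors `mup`'s option of that name. The class names are hypothetical and this is not the `mup` package's API; Mu-scaling's exact version may differ.

```python
import torch
from torch import nn


class ExplicitWidthMuReadout(nn.Linear):
    """Hypothetical muP readout with an explicit width multiplier."""

    def __init__(self, in_features, out_features, width_mult, bias=True,
                 output_mult=1.0, readout_zero_init=False):
        # Set before nn.Linear.__init__, which calls reset_parameters().
        self.width_mult = width_mult
        self.output_mult = output_mult
        self.readout_zero_init = readout_zero_init
        super().__init__(in_features, out_features, bias=bias)

    def reset_parameters(self):
        if self.readout_zero_init:
            nn.init.zeros_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)
        else:
            super().reset_parameters()

    def forward(self, x):
        # Shrink the readout's input by 1 / width_mult (times a tunable multiplier).
        return super().forward(x * self.output_mult / self.width_mult)


class ExplicitWidthMuSharedReadout(ExplicitWidthMuReadout):
    """Hypothetical tied variant: reuses an nn.Embedding weight of shape (vocab, d_model)."""

    def __init__(self, embedding_weight, width_mult, bias=True, **kwargs):
        vocab_size, d_model = embedding_weight.shape
        super().__init__(d_model, vocab_size, width_mult, bias=bias, **kwargs)
        self.weight = embedding_weight  # replaces the placeholder weight created above


emb = nn.Embedding(10, 4)
tied_head = ExplicitWidthMuSharedReadout(emb.weight, width_mult=2.0)
```

In the tied variant, `readout_zero_init` only affects the discarded placeholder weight, so zero-initializing the readout and weight tying do not combine.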
Sorry, I'm not quite sure I'm understanding -- are you suggesting that we write our own MuReadout and MuSharedReadout modules ourselves?
Hey @othertea, thank you very much for these comments, they are really appreciated! Also, sorry for the late reply; being sick with COVID over the holidays, I have only got
Also, thank you very much for the detailed inline comments! I will take a look soon.
Made a mistake and pushed from an older account. Anyway, I tried
Added the mup refactor and the coordinate checking integration discussed previously on Discord. You can find a Colab showcasing mup coordinate checking results on the model here.
Given that MuP now passes coordinate checking in my tests, maybe we can review (cc @othertea? I would love to have a second pair of eyes, given that I really don't trust myself much) and consider merging, with follow-up work going into a to-be-opened second PR for Optuna. Let me know what you think :)
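The idea behind coordinate checking can be sketched without the `mup` package: train a few steps at several widths and record the average absolute value of each layer's output. Under a correct muP setup these curves should stay roughly flat as width grows, while under standard parametrization some of them typically grow with width. This is a hand-rolled sketch on a toy MLP with default PyTorch init (`coord_check` and `toy_mlp` are hypothetical); `mup` ships its own helpers in `mup.coord_check`, which this does not reproduce.

```python
import functools

import torch
from torch import nn


def coord_check(make_model, widths, steps=3, lr=0.1, batch=16, in_dim=8, seed=0):
    """Return {width: [{layer_name: mean |output|} for each training step]}."""
    results = {}
    for width in widths:
        torch.manual_seed(seed)
        model = make_model(width)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        current = {}

        def record(module, inputs, output, name):
            # Average coordinate size of this layer's output at the current step.
            current[name] = output.detach().abs().mean().item()

        handles = [
            module.register_forward_hook(functools.partial(record, name=name))
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
        ]
        x = torch.randn(batch, in_dim)
        y = torch.randn(batch, 1)
        history = []
        for _ in range(steps):
            loss = nn.functional.mse_loss(model(x), y)
            history.append(dict(current))  # snapshot right after the forward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        for handle in handles:
            handle.remove()
        results[width] = history
    return results


def toy_mlp(width, in_dim=8):
    return nn.Sequential(nn.Linear(in_dim, width), nn.ReLU(), nn.Linear(width, 1))


results = coord_check(toy_mlp, widths=[64, 256, 1024])
for width, history in results.items():
    print(width, [round(step["0"], 4) for step in history])
```

Swapping `toy_mlp` for a muP-initialized model (with the muP optimizer scaling) is what the comparison in this thread is about.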
Thanks a lot for the updates @NZ99! I can confirm that I can reproduce your coord check plots and they look good, and that if you do not apply `_init_weights`, the plots look bad, which is great!
I added some more in-line comments below. I'll take another look, but for now, the only thing I really notice is that we may be missing an `attn_mult` parameter.
And to prepare for a merge, could you resolve the merge conflicts (e.g., by rebasing onto main or by merging main)? I think we should be able to merge this very soon 🙂
```python
from models import *
```
Suggested change:
```diff
-from models import *
+from .models import *
```
I think you need these to be relative imports? They didn't work as-is for me.
Alternatively, instead of changing all of these to relative imports, we can remove these lines and import them by specifying the full module paths in test_coord_check.py.
```python
from apt import *
```
Suggested change:
```diff
-from apt import *
+from .apt import *
```
```python
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
if self.use_mup:
    self.attn_dropout = nn.Identity()
```
Should we consider asserting that the dropout probabilities are set to 0 in this case (in the configs)?
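One way to make that check concrete, as a sketch: validate the config when `use_mup` is set and fail loudly instead of silently replacing dropout with `nn.Identity()`. Assumptions: the config exposes `use_mup` and `attn_pdrop` as attributes; only `attn_pdrop` is checked because that is the dropout the snippet above replaces, and other fields (such as the option flagged as unsupported further down) could be added to the tuple. `validate_mup_config` is a hypothetical name.

```python
from types import SimpleNamespace

# Config fields that must be 0.0 when muP is enabled (assumption: extend as needed).
MUP_ZERO_DROPOUT_FIELDS = ("attn_pdrop",)


def validate_mup_config(config):
    """Hypothetical check to call from the config's or the model's constructor."""
    if not getattr(config, "use_mup", False):
        return
    for field in MUP_ZERO_DROPOUT_FIELDS:
        prob = getattr(config, field, 0.0)
        if prob != 0.0:
            raise ValueError(f"use_mup=True requires {field}=0.0, got {prob}")


validate_mup_config(SimpleNamespace(use_mup=True, attn_pdrop=0.0))  # passes
```

Raising `ValueError` rather than using `assert` keeps the check active when Python runs with `-O`, which strips assertions.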
```python
        [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
    )
if self.use_mup:
    attn_weights = attn_weights / torch.full(
```
Should we be multiplying by some `attn_mult` here that we add as a config option (as in Mu-Scaling or mutransformers)?
Good catch, yes, thank you! Will update accordingly.
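A sketch of where such a multiplier would enter, assuming it is read from the existing `mup_attn_mult` config field: under muP the attention scores are scaled by `attn_mult / head_dim` instead of `1 / sqrt(head_dim)`. `attention_scale` is a hypothetical helper, and `self.attn_mult` in the comment is a hypothetical attribute.

```python
import math


def attention_scale(head_dim, use_mup, attn_mult=1.0):
    """Hypothetical helper: factor applied to q @ k^T before the softmax.

    Standard parametrization: 1 / sqrt(head_dim).
    muP: attn_mult / head_dim, with attn_mult a tunable constant.
    """
    if use_mup:
        return attn_mult / head_dim
    return 1.0 / math.sqrt(head_dim)


# In the attention forward, this would replace the torch.full(...) division, e.g.:
# attn_weights = torch.matmul(query, key.transpose(-1, -2))
# attn_weights = attn_weights * attention_scale(value.size(-1), self.use_mup, self.attn_mult)
```

Choosing `attn_mult = sqrt(d_base)` makes the muP scale equal the standard one at the base head dimension `d_base`, while wider heads are then scaled by `1 / head_dim` rather than `1 / sqrt(head_dim)`.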
```python
    prepend_bos=True,
    append_eos=True,
    eos_idx=2)
# mup implementation does not currently support this
```
Similar to the dropout case, should we consider adding an assertion that we are not using mup with this in the configs?
Yes, I think this is a good idea!
```python
# note that this has to be run after mup.set_base_shapes for it to work
# see https://github.com/microsoft/mup#basic-usage
# not sure if this is required here
self.apply(self._init_weights)
```
So it seems to me like we shouldn't call this here? As in your coordinate check example, you would have to call it again anyway (and only if you're using mup?).
Right, I think this was left over from some earlier testing that I forgot to remove. Indeed, this shouldn't have an effect, so there's no reason to keep it. Thanks!
```python
if __name__ == "__main__":
    delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True))
    delta_model.apply(delta_model._init_weights)
```
Suggested change (remove this line):
```diff
-    delta_model.apply(delta_model._init_weights)
```
I think only the actual model needs to have `_init_weights` applied? I checked, but you should double-check too!
I had a look at the mup Transformer example, and it seems they do not call init_weights when mup is active.
See https://github.com/microsoft/mup/blob/main/examples/Transformer/main.py#L189 (line 189, and line 310 onward). They only call a weight initialization at the end of the definition of their transformer model:
https://github.com/microsoft/mup/blob/main/examples/Transformer/model.py#L105C9-L105C28 (line 105).
We also call initialization:
```python
# Initialize weights and apply final processing
self.post_init()
```
The difference is that ours relies on the `post_init` inherited from transformers' `PreTrainedModel` (line 1243):
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L1243C5-L1243C8
If I understood this correctly, you're right! So we have to skip `_init_weights` when using MuP.
```python
    delta_model.apply(delta_model._init_weights)

    base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True))
    base_model.apply(base_model._init_weights)
```
Suggested change (remove this line):
```diff
-    base_model.apply(base_model._init_weights)
```
Co-authored-by: othertea <[email protected]>
This PR adds initial modeling-related changes to support MuP. It is incomplete, as the interaction between some modeling-related techniques is not yet clear to me (e.g., do we need any modifications to make MuP compatible with ALiBi?), but it is shared now to facilitate discussion and collaboration, as agreed during community project meetings.
I'm uploading my code as-is after leaving for NeurIPS; over the holidays I'll be checking whether anything important is missing or needs fixing in the code so far. Apologies in case there are indeed issues with the code as pushed.
Important TODOs: