Explain About Packing Inputs Without Cross-Contamination Attention #265
@qibao77 in our implementation, we changed 2 things:
Thank you for your reply! I want to add this feature to my pretraining code (like Llama 3), but I found that there is no change in the loss compared to naive packing. Is there any advice?
What do you mean by no change in the loss?
Yes, in my experiment, loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b)), and I have checked that the "_get_unpad_data" function was replaced correctly.
@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?
@qibao77 Were you pre-training or fine-tuning? Curious: was the loss exactly matching step by step, or only much later?
For loss(packing_without_cross_contamination(a, b)), the code is shown as follows:
...
For the definition of generate_attention_mask:
For loss(naive_packing(a, b)):
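The snippets referenced above are not preserved in this thread. As a stand-in, here is a minimal sketch of what a generate_attention_mask helper for this style of packing might look like; the function name comes from the comment above, but its signature, body, and the convention of writing a distinct index per packed sub-sequence are illustrative assumptions, not the commenter's actual code.

```python
import torch

def generate_attention_mask(seq_lens, max_len):
    # Sketch: instead of a 0/1 mask, give every packed sub-sequence its own
    # index (1, 2, 3, ...) and leave padding as 0. A patched _get_unpad_data
    # can then recover the length of each sub-sequence from this mask.
    mask = torch.zeros(max_len, dtype=torch.long)
    offset = 0
    for idx, length in enumerate(seq_lens, start=1):
        mask[offset:offset + length] = idx
        offset += length
    return mask.unsqueeze(0)  # shape (1, max_len)

# Packing a (3 tokens) and b (4 tokens) into one row of length 10:
# tensor([[1, 1, 1, 2, 2, 2, 2, 0, 0, 0]])
print(generate_attention_mask([3, 4], 10))
```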
Pre-training, matching step by step.
@qibao77 it's unclear how you could be matching step by step if the attention masks are different.
@qibao77 Actually, you can run this script to see that naive packing will give a different loss compared with packing without cross-contamination. In this script, assume that there are 2 data points:
The result is:
You see that naive packing is problematic, right?
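The script and its printed losses are not reproduced here. The effect it demonstrates can be sketched with plain PyTorch and toy random embeddings; the code below is illustrative only, not the script from this comment. With naive packing, the tokens of the second packed sequence attend to the first one, so their attention outputs, and hence any loss computed on them, cannot match the cross-contamination-free case.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two toy "data points": a (3 tokens) and b (4 tokens) packed into one row.
len_a, len_b = 3, 4
total = len_a + len_b
q = k = v = torch.randn(total, 8)  # random stand-in embeddings, head dim 8

def attend(mask):
    scores = (q @ k.T) / 8 ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

# Naive packing: one causal mask over the whole packed row, so b's tokens see a.
naive_mask = torch.tril(torch.ones(total, total, dtype=torch.bool))

# Packing without cross-contamination: block-diagonal causal mask, b only sees b.
clean_mask = torch.zeros(total, total, dtype=torch.bool)
clean_mask[:len_a, :len_a] = torch.tril(torch.ones(len_a, len_a, dtype=torch.bool))
clean_mask[len_a:, len_a:] = torch.tril(torch.ones(len_b, len_b, dtype=torch.bool))

# b's positions receive different attention outputs under the two masks,
# so any loss computed on b must differ as well.
print(torch.allclose(attend(naive_mask)[len_a:], attend(clean_mask)[len_a:]))  # False
```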
Thanks for your great work! Why does this operation ("overwriting the function _get_unpad_data with a monkey-patched function") implement packing without cross-contamination attention? Can you explain in more detail or point me to some references? Thank you very much!
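For context, a rough picture of why the patch has that effect. The code below is a sketch based on how flash attention is wired up in transformers, not this repository's actual implementation, and the exact module to patch depends on your transformers version. When flash attention 2 is enabled, the model calls _get_unpad_data(attention_mask) to obtain cu_seqlens, the boundaries that flash-attn's variable-length kernel attends within. The stock function assumes a 0/1 mask and returns one length per batch row, so a packed row is treated as a single long sequence and the samples inside it attend to each other. If the data collator instead writes a distinct index per packed sub-sequence into the mask (as in the generate_attention_mask sketch earlier) and _get_unpad_data is replaced so that it returns one length per sub-sequence, the kernel attends within each sub-sequence only, which is exactly packing without cross-contamination.

```python
import torch
import torch.nn.functional as F

def patched_get_unpad_data(attention_mask):
    # Sketch of a drop-in replacement for transformers' _get_unpad_data that
    # expects sub-sequence indices (1, 2, 3, ..., 0 = padding) in the mask.
    seqlens = []
    for row in attention_mask:
        counts = torch.bincount(row[row > 0].to(torch.long))
        seqlens.append(counts[counts > 0])  # length of each packed sub-sequence
    seqlens_in_batch = torch.cat(seqlens).to(torch.int32)

    # Same outputs the stock function produces, but with one entry per
    # sub-sequence instead of one per batch row.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch

# Monkey-patch it in (module path varies across transformers versions; older
# releases keep a per-model copy, e.g. in the Llama modeling file):
import transformers.models.llama.modeling_llama as llama_modeling
llama_modeling._get_unpad_data = patched_get_unpad_data
```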