Data Concatenation: how to avoid sample contamination during training #25
I noticed that Flash Attention is used when training Mixtral-8x7B (lines 255 to 263 in commit b546cd0).
Is this setting also used in training smaller models?
Thanks for your attention. It is a good idea to modify the attention mask so that different samples do not attend to each other. However, we have not implemented such a mechanism, because we found that the negative impact of simple concatenation is relatively small in our experiments. We still believe that modifying the attention mask is necessary. If you come up with a solution for Flash Attention, a pull request is welcome!
Got it, thanks for your quick reply. A workaround is to use customized Flash Attention kernels based on Triton, such as flashattention2-custom-mask, but its correctness has not been tested.
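For reference, another workaround avoids materializing a mask entirely: FlashAttention's variable-length interface takes the cumulative boundaries of the packed samples instead. Below is a minimal sketch, assuming the flash-attn 2.x package on a CUDA GPU; the sample lengths, head counts, and dtype are illustrative only, not taken from this repository.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn 2.x, CUDA only

# Two samples of lengths 1 and 2 packed into a single sequence of 3 tokens.
# cu_seqlens holds the cumulative sample boundaries; attention is computed
# per segment, so tokens never attend across sample boundaries.
sample_lengths = [1, 2]
cu_seqlens = torch.tensor([0, 1, 3], dtype=torch.int32, device="cuda")
max_seqlen = max(sample_lengths)

total_tokens, num_heads, head_dim = sum(sample_lengths), 8, 64
q = torch.randn(total_tokens, num_heads, head_dim,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,  # causal masking is applied within each segment only
)
```

Because attention is restricted to the segments between consecutive cu_seqlens entries, no explicit block-diagonal mask is needed.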
@linhaojia13 I have verified the correctness of that function. Hope it helps! 😀
Thank you!
Hello VITA team.
Thanks for this great work. After reading through the code and the preprint paper, I have a question about data concatenation.
The paper mentions that you use a "Data Concatenation" technique to concatenate different samples into a single sequence. In this case, to prevent different samples from attending to each other, the causal mask should be modified.
For example, suppose a sequence of length 3 contains two samples with lengths 1 and 2. The original causal mask is
[[1, 0, 0],
 [1, 1, 0],
 [1, 1, 1]],
and the modified mask should be
[[1, 0, 0],
 [0, 1, 0],
 [0, 1, 1]],
so that tokens of the second sample cannot attend to the first sample.
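For concreteness, here is a minimal sketch (my own illustration, not code from this repository) of how such a block-diagonal causal mask could be built in PyTorch; the helper name block_diagonal_causal_mask and the sample lengths are assumptions for the example.

```python
import torch

def block_diagonal_causal_mask(sample_lengths):
    """Boolean attention mask for several samples packed into one sequence.

    `sample_lengths` lists the lengths of the concatenated samples,
    e.g. [1, 2] for a total sequence length of 3. True means "may attend":
    each token only attends to earlier tokens of its own sample.
    """
    total_len = sum(sample_lengths)
    mask = torch.zeros(total_len, total_len, dtype=torch.bool)
    start = 0
    for length in sample_lengths:
        end = start + length
        # Causal (lower-triangular) block restricted to this sample.
        mask[start:end, start:end] = torch.tril(
            torch.ones(length, length, dtype=torch.bool)
        )
        start = end
    return mask

print(block_diagonal_causal_mask([1, 2]).int())
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 1, 1]])
```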
I noticed that this codebase supports three attention implementations: eager, SDPA, and Flash Attention.
However, a customized attention mask is not supported in Flash Attention (see issue).
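For eager and SDPA attention, a mask like the one above can be passed directly. A minimal sketch with illustrative tensor shapes (not code from this repository):

```python
import torch
import torch.nn.functional as F

# The modified mask from the example above: samples of lengths 1 and 2
# packed into a length-3 sequence (True = may attend).
mask = torch.tensor([[1, 0, 0],
                     [0, 1, 0],
                     [0, 1, 1]], dtype=torch.bool)

num_heads, head_dim = 8, 64
q = torch.randn(1, num_heads, 3, head_dim)   # (batch, heads, seq_len, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# SDPA (like eager attention) accepts an explicit boolean mask, so the packed
# samples stay isolated; Flash Attention does not accept an arbitrary mask.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```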
My question is:
What attention implementation do you use for training? Do you consider the sample contamination problem when using data concatenation? Thanks!