Data Concatenation: how to avoid sample contamination during training #25

Closed
xiabingquan opened this issue Sep 11, 2024 · 5 comments

@xiabingquan

Hello VITA team.

Thanks for this great work. After reading through the code and the preprint paper, I have a question about data concatenation.

The paper mentions that you use a "Data Concatenation" technique to concatenate different samples into one sequence. In this case, the causal mask should be modified so that different samples do not attend to each other.

For example, suppose a sequence of length 3 contains two samples with lengths 1 and 2.
The original causal mask is:

[
  [1, 0, 0],
  [1, 1, 0],
  [1, 1, 1],
]

And the modified mask should be:

[
  [1, 0, 0],
  [0, 1, 0],
  [0, 1, 1],
]
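
As a side note, here is a minimal sketch of how such a block-diagonal causal mask could be built in PyTorch. The helper name build_packed_causal_mask and its interface are just my own illustration, not anything from the VITA codebase:

```python
import torch

def build_packed_causal_mask(sample_lengths):
    """Causal mask for a packed sequence: tokens attend only within their own sample.

    sample_lengths: per-sample token counts of the concatenated sequence.
    Returns a (total_len, total_len) bool tensor where True means "may attend".
    """
    total_len = sum(sample_lengths)
    mask = torch.zeros(total_len, total_len, dtype=torch.bool)
    start = 0
    for length in sample_lengths:
        end = start + length
        # Lower-triangular (causal) block restricted to this sample's span.
        mask[start:end, start:end] = torch.tril(
            torch.ones(length, length, dtype=torch.bool)
        )
        start = end
    return mask

# For the example above, build_packed_causal_mask([1, 2]) gives
# [[1, 0, 0],
#  [0, 1, 0],
#  [0, 1, 1]]
```

A boolean mask like this can be passed to the eager or SDPA implementations, but not to Flash Attention, which is exactly the problem below.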

I noticed that this codebase supports three attention implementations: eager, SDPA, and Flash Attention.
However, a customized attention mask is not supported in Flash Attention (see issue).

My questions are:
Which attention implementation do you use for training? Do you take the sample contamination problem into account when using data concatenation? Thanks!

@xiabingquan xiabingquan changed the title How to avoid sample contamination during training Data Concatenation: how to avoid sample contamination during training Sep 11, 2024
@xiabingquan
Author

I noticed that Flash Attention is used when training mixtral-8x7b:

VITA/vita/train/train.py

Lines 255 to 263 in b546cd0

if model_args.model_type == "mixtral-8x7b":
    torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16
    model = VITAMixtralForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
        **bnb_model_from_pretrained_args,
    )

Is this setting also used in training smaller models?

Currently, the model checkpoints are not immediately available to me. Any reply would help, thanks!

@linhaojia13
Collaborator

Thanks for your attention. It is a good idea to modify the attention mask to prevent different samples from attending to each other. However, we have not implemented such a mechanism, because we found the negative impact of simple concatenation to be relatively small in our experiments.

We still believe that modifying the attention mask is necessary. If you come up with a solution for Flash Attention, a pull request is welcome!

@xiabingquan
Author

Got it. Thanks for your quick reply. A workaround is to use customized flash attention kernels based on Triton, such as flashattention2-custom-mask, but its correctness has not been tested.

@xiabingquan
Author

@linhaojia13 The function flash_attn_varlen_func in Flash Attention can meet our needs (with some modifications to cu_seqlens).

I've verified its correctness. Hope it helps! 😀
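
For anyone landing here later, a minimal sketch of what this looks like; the tensor sizes below are made up for illustration, and flash_attn_varlen_func requires the flash-attn package plus fp16/bf16 CUDA tensors:

```python
import itertools

import torch
from flash_attn import flash_attn_varlen_func

# Two samples of lengths 1 and 2 packed into one sequence of 3 tokens
# (same example as in the first post).
sample_lengths = [1, 2]
total_tokens = sum(sample_lengths)
nheads, headdim = 8, 64  # illustrative sizes

# flash_attn_varlen_func expects (total_tokens, nheads, headdim) tensors.
q = torch.randn(total_tokens, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths mark the sample boundaries: [0, 1, 3].
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(sample_lengths)),
                          dtype=torch.int32, device="cuda")
max_seqlen = max(sample_lengths)

# With causal=True, attention is causal within each sample and never crosses
# the boundaries given by cu_seqlens, i.e. the block-diagonal mask discussed above.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,
)
print(out.shape)  # torch.Size([3, 8, 64])
```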

@linhaojia13
Collaborator

Thank you!
