
Explain About Packing Inputs Without Cross-Contamination Attention #265

Open

qibao77 opened this issue Sep 4, 2024 · 10 comments

qibao77 commented Sep 4, 2024

Thanks for your great work! Why does this operation ("overwriting the function _get_unpad_data with a monkey-patched function") implement packing without cross-contamination attention? Could you explain it in more detail or point me to some references? Thank you very much!

khai-meetkai (Collaborator) commented:

@qibao77 in our implementation, we changed 2 things:

  • First, we extend the format of attention_mask to represent the packing, marking the boundary of each packed input. Assume that max_input_length is 10 and padding_token_id is 0. Without packing, we have 2 data points:
    • input_ids1 = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0]; attention_mask = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
    • input_ids2 = [4, 5, 6, 7, 8, 0, 0, 0, 0, 0]; attention_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
      When we pack these 2 data points into 1 data point:
    • input_ids = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]; attention_mask = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]. Here the attention_mask marks the boundary of each individual data point: 1 for data point 1, 2 for data point 2, and 0 for padding (the same as without packing).
  • Second, with this extended attention_mask, the current code (the function _get_unpad_data) doesn't work, because it was implemented to accept only 0 and 1, so we overwrite _get_unpad_data with a version that accepts the extended attention_mask (see the sketch below).
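
For reference, here is a minimal sketch (not the repo's exact code) of what such a monkey-patched _get_unpad_data can look like, assuming the extended mask format above (0 for padding; 1, 2, 3, … for each packed sequence) and FlashAttention's varlen interface:

import torch
import torch.nn.functional as F

def _get_unpad_data_packed(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len); 0 = padding, 1, 2, 3, ... = id of each packed sequence.
    # For every row, count the length of every packed sequence it contains.
    max_id = int(attention_mask.max())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_id + 1)], dim=-1
    )  # shape (batch, max_id)
    seqlens_in_batch = counts.flatten()
    seqlens_in_batch = seqlens_in_batch[seqlens_in_batch > 0].to(torch.int32)

    # Indices of all non-padding tokens in row-major order (same as the stock implementation),
    # so consecutive slices of `indices` line up with the cu_seqlens boundaries below.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    # Cumulative sequence lengths in the format expected by flash_attn_varlen_func:
    # [0, len(seq_1), len(seq_1) + len(seq_2), ...]. The varlen kernel computes attention
    # independently within each [cu_seqlens[k], cu_seqlens[k+1]) slice, which is what prevents
    # tokens of one packed sequence from attending to another.
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

Compared with the stock version, the only real change is that per-sequence lengths are derived from the distinct nonzero values in the mask instead of a plain sum of 1s per row.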

qibao77 (Author) commented Sep 5, 2024


Thank you for your reply! I want to add this feature to my pretraining code (for Llama 3), but I found that there is no change in the loss compared to naive packing. Do you have any advice?

khai-meetkai (Collaborator) commented:

What do you mean by no change in the loss?
Do you mean loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?

qibao77 (Author) commented Sep 6, 2024

Yes, in my experiment loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b)), and I have checked that the _get_unpad_data function was replaced correctly.

khai-meetkai (Collaborator) commented:

@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?

vgoklani commented Sep 6, 2024

@qibao77 were you pre-training or fine-tuning?

curious, was the loss exactly matching step by step, or was that much later?

qibao77 (Author) commented Sep 11, 2024

@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?

For loss(packing_without_cross_contamination(a, b)), the code is as follows:
...

monkey_patch_packing_for_model(self.local_dir)
self.gpt = LlamaForCausalLM.from_pretrained(
    self.local_dir,
    config=self.hf_config,
    trust_remote_code=True,
    revision='main',
    offload_state_dict=True,
    attn_implementation="flash_attention_2",
)

...

attention_mask = generate_attention_mask(
    input_ids,
    special_token_end=self.tokenizer.eos_token_id,
    pad_token_id=self.tokenizer.pad_token_id,
)
model_out = self.gpt(
    input_ids=input_ids, attention_mask=attention_mask, labels=labels
)

For the definition of generate_attention_mask:

import torch

def generate_attention_mask(input_ids, special_token_end=3, pad_token_id=0):
    # Build the extended attention_mask: 0 for padding, and 1, 2, 3, ... for each
    # packed sequence. A sequence ends at special_token_end (e.g. the EOS token).
    batch_size, seq_len = input_ids.shape
    mask = torch.zeros_like(input_ids)
    for i in range(batch_size):
        current_label = 1
        for j in range(seq_len):
            if input_ids[i, j] == special_token_end:
                # The end-of-sequence token still belongs to the current sequence.
                mask[i, j] = current_label
                current_label += 1
            elif input_ids[i, j] == pad_token_id:
                # Everything after the first padding token stays 0.
                break
            else:
                mask[i, j] = current_label
    return mask
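
For example (a hypothetical check, assuming eos_token_id = 9 and pad_token_id = 0), packing [1, 2, 3, 9] and [4, 5, 6, 7, 8, 9] followed by padding would give:

input_ids = torch.tensor([[1, 2, 3, 9, 4, 5, 6, 7, 8, 9, 0, 0]])
mask = generate_attention_mask(input_ids, special_token_end=9, pad_token_id=0)
# mask == tensor([[1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 0, 0]])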

For loss(naive_packing(a, b)), the function generate_attention_mask is not used, and the attention_mask is all 1s except for the padding part.

qibao77 (Author) commented Sep 11, 2024

@qibao77 were you pre-training or fine-tuning?

curious, was the loss exactly matching step by step, or was that much later?

Pre-training, and the loss matches step by step.

vgoklani commented:

@qibao77 it's unclear how the losses could match step by step if the attention masks are different.

khai-meetkai (Collaborator) commented Sep 13, 2024

@qibao77 you can actually run this script to see that naive packing gives a different loss than packing without cross-contamination. In this script, assume that there are 2 data points:
a = [1, 2, 3]
b = [4, 5, 6, 7, 8]
I compare the loss of:

  1. loss(a) + loss(b)
  2. loss(naive_pack(a, b))
  3. loss(packing_without_cross_contamination(a, b))

The result is:

  1. loss(a) + loss(b) = 44.141
  2. loss(naive_pack(a, b)) = 37.55
  3. loss(packing_without_cross_contamination(a, b)) = 44.17

You can see that naive packing is problematic, right? With the naive mask, the tokens of b can attend to the tokens of a, which is why the packed loss no longer matches loss(a) + loss(b).

from transformers import AutoModelForCausalLM
import monkey_patch_packing 
import torch

def main():
    # pad_token = 0
    # max_length = 10
    pretrained_path = "meta-llama/Meta-Llama-3.1-8B"
    input_ids1 = [1, 2, 3] + [0 for _ in range(7)]
    labels1 = [1, 2, 3] + [-100 for _ in range(7)]
    attention1 = [1, 1, 1] + [0 for _ in range(7)]
    
    input_ids2 = [4, 5, 6, 7, 8] + [0 for _ in range(5)]
    labels2 = [4, 5, 6, 7, 8] + [-100 for _ in range(5)]
    attention2 = [1, 1, 1, 1, 1] + [0 for _ in range(5)]
    # packing
    packed_inputs = [1,2,3,4,5,6,7,8, 0, 0]
    # note: 4 is the first token of the second data point, so it is not included when computing the loss
    packed_labels = [1, 2, 3, -100, 5, 6, 7, 8, -100, -100]
    
    naive_attention = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
    correct_packed_attention = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
    
    assert len(input_ids1) == len(input_ids2) == len(attention1) == len(attention2) == len(naive_attention) == len(correct_packed_attention) == len(packed_inputs)
    
    # Load model without monkey-patching
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    # loss without using packing
    loss1, num_tok1 = compute_loss(model, input_ids1, attention1, labels1)
    loss2, num_tok2 = compute_loss(model, input_ids2, attention2, labels2)
    total_original_loss = loss1 + loss2 
    total_original_num_tokens = num_tok1 + num_tok2
    print(f"total original loss: {total_original_loss}; total_original_num_tokens={total_original_num_tokens}")    
    # loss with naive packing
    naive_loss, naive_num_tok = compute_loss(model, packed_inputs, naive_attention, packed_labels)
    print(f"naive loss: {naive_loss}; num_token: {naive_num_tok}")
    
    # loss with packing without cross-contamination
    # need to reload using monkey-patched code
    monkey_patch_packing.monkey_patch_packing_for_model(pretrained_path)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    
    correct_loss, correct_num_tok = compute_loss(model, packed_inputs, correct_packed_attention, packed_labels)
    print(f"correct_loss: {correct_loss}; num_token={correct_num_tok}")

def compute_loss(model, input_ids, attention, labels):    
    inputs = {
        "input_ids": torch.tensor([input_ids]).to(model.device),
        "labels": torch.tensor([labels]).to(model.device),
        "attention_mask": torch.tensor([attention]).to(model.device)
    }
    total_num_loss_tokens = 0
    total_loss = 0
    with torch.no_grad():
        avg_loss = model.forward(**inputs).loss.item()
        # compute number of tokens used for computing loss
        labels = inputs["labels"]
        shift_labels = labels[..., 1:].contiguous()
        shift_labels = shift_labels.view(-1)
        ignore_count = (shift_labels == -100).sum()
        num_tokens = shift_labels.size(0) - ignore_count

        total_num_loss_tokens += num_tokens.item()
        total_loss += avg_loss * num_tokens.item()
    return total_loss, total_num_loss_tokens

if __name__ == "__main__":
    main()
