Missing modules in prompt-based PEFT when re-loading model #2043

Closed
3 of 4 tasks
martin-wey opened this issue Aug 29, 2024 · 22 comments

Comments

@martin-wey

System Info

python 3.10.10, transformers 4.44.2, peft 0.12.0

Who can help?

@BenjaminBossan @sayak

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

The following is a simplified script to reproduce the bug; I have experienced the same issue using transformers.Trainer. Fine-tuning with PEFT and p-tuning/prompt tuning works perfectly. However, when the model is reloaded from a saved PEFT checkpoint for generation, some modules are missing, and as a result the model does not generate the expected content.

Basic script to save the adapter:

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model, AutoPeftModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
model.save_pretrained("meta-llama-ptuning")
model.print_trainable_parameters()
print(model)

Output:

trainable params: 1,151,232 || all params: 8,031,412,480 || trainable%: 0.0143
PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEncoder(
      (embedding): Embedding(20, 4096)
      (mlp_head): Sequential(
        (0): Linear(in_features=4096, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): ReLU()
        (4): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
  )
  (word_embeddings): Embedding(128256, 4096)
)

Load the PEFT model from checkpoint (method 1):

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
peft_config = PromptEncoderConfig.from_pretrained("meta-llama-ptuning")
model = get_peft_model(model, peft_config)
print(model)

Load the PEFT model from checkpoint (method 2):

model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning")
print(model)

Output: The MLP part of the prompt encoder is missing

PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEncoder(
      (embedding): Embedding(20, 4096)
    )
  )
  (word_embeddings): Embedding(128256, 4096)
)

Perhaps something's wrong here: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/p_tuning/model.py#L82

Expected behavior

I believe the entire prompt encoder should be reloaded from the PEFT checkpoint.
When I reload the model using the aforementioned methods, it generates content that is close to the base model, meaning the prompt encoder is not loaded properly.
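
For illustration, this is roughly how I compare the reloaded PEFT model against the base model (a rough sketch; the prompt is a placeholder, and the extra logit positions coming from the virtual tokens are sliced off before comparing):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("def add(a, b):", return_tensors="pt")

base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
peft_model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning", torch_dtype=torch.bfloat16)

with torch.no_grad():
    base_logits = base(**inputs).logits
    # the prompt encoder prepends num_virtual_tokens positions, so keep only the
    # positions that correspond to the actual input tokens
    peft_logits = peft_model(**inputs).logits[:, -inputs["input_ids"].shape[1]:, :]

# if the prompt encoder had an effect, this difference should be clearly non-zero
print((base_logits - peft_logits).abs().mean())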

@martin-wey martin-wey changed the title Missing modules in prompt-based PEFT when re-loading Missing modules in prompt-based PEFT when re-loading model Aug 29, 2024
@BenjaminBossan
Member

BenjaminBossan commented Aug 29, 2024

Thanks for reporting this issue.

First of all, please don't load a model using get_peft_model; that function is for creating new models. Always use from_pretrained to load models. Second, please ensure that the same dtype is used when loading the model. Initially, you use bfloat16, but when you call model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning"), you don't specify the dtype.
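
For reference, the loading pattern I mean would look something like this (a minimal sketch, reusing the checkpoint path from your script):

import torch
from transformers import AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base, "meta-llama-ptuning")
model.eval()

# or, equivalently, in one step:
# model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning", torch_dtype=torch.bfloat16)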

That said, even with these adjustments, as well as ensuring that the model is in eval mode, I could reproduce the error, i.e., there is a small discrepancy after loading (in my tests, the mean abs diff of the logits was ~0.02, depending on the model). However, this discrepancy is not caused by the missing module. When you check these lines:

elif config.is_prompt_learning:
    to_return = {}
    if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
        to_return["prefix_task_cols"] = model.prompt_encoder[adapter_name].prefix_task_cols
        to_return["prefix_task_rows"] = model.prompt_encoder[adapter_name].prefix_task_rows
        prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
    else:
        if config.inference_mode:
            prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
        else:
            prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
    to_return["prompt_embeddings"] = prompt_embeddings

You can see that the prompt embedding is saved as part of the state_dict. This is an optimization: for pure inference, the parameters are fixed, so this output does not change anyway and there is no need to load the mlp_head.
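
You can verify this by inspecting the keys of the saved adapter file (a quick sketch; it assumes the adapter was saved with safetensors, which is the default):

from safetensors.torch import load_file

state_dict = load_file("meta-llama-ptuning/adapter_model.safetensors")
# only the precomputed prompt embeddings are stored; the mlp_head weights are not part of the checkpoint
print(list(state_dict.keys()))  # ['prompt_embeddings']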

Still, I'm not sure where the difference comes from and will investigate further. Just wanted to share some insights I had so far.

PS: Interestingly, for facebook/opt-125m, I did not find any discrepancy. However, when I checked meta-llama/Meta-Llama-3-8B, bigscience/bloomz-560m, Qwen/Qwen2-1.5B, and microsoft/Phi-3.5-mini-instruct, they all had a small difference.

@martin-wey
Author

martin-wey commented Aug 29, 2024

First of all, please don't load a model using get_peft_model

I may have pasted the wrong snippet. I'm having the same issue when first loading the base model and then using PeftModel to load the model + adapter.

You can see that the prompt embedding is saved as part of the state_dict. This is an optimization: for pure inference, the parameters are fixed, so this output does not change anyway and there is no need to load the mlp_head.

Got it, thanks!

Second, please ensure that the same dtype is used when loading the model. Initially, you use bfloat16

Even so, the PEFT-tuned model's responses to input prompts are very similar to the base model's. Besides, I'm also having trouble with Phi-3.5-mini-128k-instruct and CodeQwen1.5-7B-Chat. However, LoRA-based PEFT methods (including QLoRA and DoRA) work just fine.
The entire fine-tuning phase with prompt-based tuning seems normal, i.e., the validation loss is good and almost matches that of LoRA-tuned models.

Thanks for your quick reply. I will also keep trying with other LLMs.

@martin-wey
Author

martin-wey commented Aug 30, 2024

@BenjaminBossan here are some more details about the forward pass of the PEFT and base models. I hope it helps :)

I checked whether the input goes through all the modules of the network using hooks.

# ...
model = AutoPeftModelForCausalLM.from_pretrained(
    "../runs/codellama/CodeLlama-7b-Instruct-hf_conala_p-tuning_3e-3/checkpoint-198/", 
    torch_dtype="bfloat16", 
    device_map="auto"
)

def forward_hook(module, input, output):
    print(f"Module: {module.__class__.__name__}")
    print(f"Input: {input}")
    print(f"Output: {output}")
    print("-" * 50)

hook_word_embeddings = model.word_embeddings.register_forward_hook(forward_hook)
hook_prompt_embeddings = model.prompt_encoder['default'].embedding.register_forward_hook(forward_hook)

for sample in test_set:
    tokenized_sample = tokenizer.apply_chat_template(
        sample["messages"],
        return_tensors="pt",
    ).to(model.device)

    output = model(tokenized_sample)

Output:

Module: Embedding
Input: (tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 20255, 29889,    13, 29966,   829, 14816, 29903,
          6778,    13,    13,  4563,   680,   278,  1819,   411,  1021,  6611,
           310,  1023,  8600,   421, 29881, 29896, 29952,   322,   421, 29881,
         29906, 29952,   518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069,  0.0031, -0.0013,  ...,  0.0003, -0.0031, -0.0026],
         [ 0.0170, -0.0242,  0.0211,  ..., -0.0048, -0.0002,  0.0183],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         ...,
         [ 0.0142, -0.0182, -0.0065,  ..., -0.0210,  0.0099, -0.0132],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         [-0.0237, -0.0297,  0.0014,  ...,  0.0427, -0.0008, -0.0029]]],
       device='cuda:0', dtype=torch.bfloat16)

The input never goes through model.prompt_encoder.default.embedding. I am not sure that's the expected behaviour.
Therefore, I tried the same thing with the base model by registering a hook for embed_tokens:

# ...
model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-Instruct-hf", 
    torch_dtype="bfloat16", 
    device_map="auto"
)
hook_embeddings = model.model.embed_tokens.register_forward_hook(forward_hook)
# ...

Output (identical):

Module: Embedding
Input: (tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 20255, 29889,    13, 29966,   829, 14816, 29903,
          6778,    13,    13,  4563,   680,   278,  1819,   411,  1021,  6611,
           310,  1023,  8600,   421, 29881, 29896, 29952,   322,   421, 29881,
         29906, 29952,   518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069,  0.0031, -0.0013,  ...,  0.0003, -0.0031, -0.0026],
         [ 0.0170, -0.0242,  0.0211,  ..., -0.0048, -0.0002,  0.0183],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         ...,
         [ 0.0142, -0.0182, -0.0065,  ..., -0.0210,  0.0099, -0.0132],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         [-0.0237, -0.0297,  0.0014,  ...,  0.0427, -0.0008, -0.0029]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<EmbeddingBackward0>)

Basically it seems like the prompt_encoder is bypassed during the forward pass. That's why I'm getting the same output when generating code using both PEFT and base models.
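
One way to check whether the virtual tokens are prepended at all (a sketch, reusing model and tokenized_sample from the snippet above): for p-tuning/prompt tuning, the virtual tokens are prepended to the input embeddings, so the logits of the PEFT model should be longer than the input by num_virtual_tokens, even though the embedding hook never fires.

import torch

with torch.no_grad():
    out = model(tokenized_sample)

# input length vs. logits length; the difference should equal num_virtual_tokens
print(tokenized_sample.shape[1], out.logits.shape[1])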

@BenjaminBossan
Member

Thanks for digging deeper into this.

The input never goes through model.prompt_encoder.default.embedding. I am not sure that's the expected behaviour.

Yes, this is expected, as it relates to the optimization I mentioned above. The prompt embeddings to be prefixed are precomputed, so the embedding's forward is never called.
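
If you want to inspect the precomputed prompts directly, you can call PEFT's (internal) get_prompt helper — just as a debugging sketch:

import torch

# `model` is the loaded PeftModelForCausalLM; get_prompt returns the virtual-token
# embeddings that get prepended to the input embeddings
with torch.no_grad():
    prompts = model.get_prompt(batch_size=1)

print(prompts.shape)  # (1, num_virtual_tokens, hidden_size), e.g. (1, 20, 4096)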

That's why I'm getting the same output when generating code using both PEFT and base models.

I cannot observe this; for me, the outputs of the loaded PEFT model are very close to the outputs of the original PEFT model, but different enough that generations start to diverge at some point.

Investigating this further sent me down a rabbit hole, and I think I have figured out the issue. To cut it short, the issue appears to be that the individual outputs of sending the same input repeated 10 times through the MLP do not equal the output of sending it through the MLP once. Let me illustrate:

import torch
torch.manual_seed(0);
device = 0
input_size = 128
hidden_size = 32
output_size = 64
layers = [
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, output_size),
]
mlp_head = torch.nn.Sequential(*layers)
mlp_head.to(device).eval();
x = torch.randn(20, input_size).to(device)

# x repeats 10x along the batch dimension
x = x.repeat(10, 1, 1)
# output with all 10 identical samples
out0 = mlp_head(x)
# output with only 1 of the samples
out1 = mlp_head(x[:1])

for i in range(10):
    # this should be 0 but it is 2.0311928139449265e-08
    print((out0[i:i+1] - out1).abs().mean().float().item())

When the involved sizes are small enough, the difference is actually 0, which may explain why opt-125m showed no problems but the bigger LLMs did.

So how does this translate to p-tuning? Let's check these lines which are executed during training time:

peft/src/peft/peft_model.py

Lines 658 to 663 in 679bcd8

prompt_tokens = (
    self.prompt_tokens[self.active_adapter]
    .unsqueeze(0)
    .expand(batch_size, -1)
    .to(prompt_encoder.embedding.weight.device)
)

prompts = prompt_encoder(prompt_tokens)

You see that we repeat the same input batch_size times and then send it through the prompt encoder.

Now, when we load the model, we go through a slightly different code path:

prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)

Here, we just take the output of a single sample and repeat it batch_size times. In theory, that should be the same thing but as I showed above, there are slight differences.

To address this, I created the following patch:

@@ -692,7 +692,13 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                 if peft_config.inference_mode:
                     prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
                 else:
-                    prompts = prompt_encoder(prompt_tokens)
+                    prompt_tokens = (
+                        self.prompt_tokens[self.active_adapter]
+                        .unsqueeze(0)
+                        .expand(1, -1)
+                        .to(prompt_encoder.embedding.weight.device)
+                    )
+                    prompts = prompt_encoder(prompt_tokens).repeat(batch_size, 1, 1)
             return prompts
 
     def get_nb_trainable_parameters(self) -> tuple[int, int]:

With this patch, the error vanishes for me. For completeness, here is the script I used to check it:

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel

inputs = torch.arange(10).view(-1, 1).to(0)

model_id = "meta-llama/Meta-Llama-3-8B"
#model_id = "bigscience/bloomz-560m"
#model_id = "Qwen/Qwen2-1.5B"
#model_id = "microsoft/Phi-3.5-mini-instruct"
#model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=0,
)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
model.eval();

torch.manual_seed(0)
with torch.inference_mode():
    output_peft = model(inputs).logits
    gen_peft = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

model.save_pretrained("/tmp/peft/2043")
del model
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained("/tmp/peft/2043", device_map=0, torch_dtype=torch.bfloat16)
# using `model = AutoModelForCausalLM.from_pretrained(...); model = PeftModel.from_pretrained(...)` also works

torch.manual_seed(0)
with torch.inference_mode():
    output_loaded = model(inputs).logits
    gen_loaded = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)

If you can confirm that this patch solves your original issue, I will create a PR to fix this in PEFT.

@martin-wey
Author

martin-wey commented Aug 31, 2024

Thanks a lot for the detailed explanation.

On my end, I am still seeing both models generate exactly the same content for a given prompt, even with this fix. It makes no sense to me, as the training phase using p-tuning or prompt tuning works properly, with the validation loss decreasing.

I use the following dataset: https://huggingface.co/datasets/neulab/docprompting-conala. The model learns to generate 1-2 lines of code for a given instruction.
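
For reference, the raw data looks like this (a quick sketch; the nl/cmd field names are taken from the dataset and the training script further below):

from datasets import load_dataset

dataset = load_dataset("neulab/docprompting-conala")
example = dataset["train"][0]
print(example["nl"])   # natural-language instruction
print(example["cmd"])  # expected 1-2 line code snippet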

I implemented a TrainerCallback that generates code for a given test sample after each epoch, to rule out potential issues when reloading the model from a checkpoint.

class GenerateAfterEpochCallback(TrainerCallback):
    def __init__(self, test_example, tokenizer):
        self.test_example = test_example
        self.tokenizer = tokenizer

    def on_epoch_end(self, args, state, control, **kwargs):
        model = kwargs['model']
        tokenizer = self.tokenizer
        
        model.eval()
        
        inputs = tokenizer.apply_chat_template(
            self.test_example["messages"][:-1],  # remove assistant's solution
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)
    
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=128,
            )
        
        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\nGenerated text after epoch {state.epoch}:\n{generated_text}\n")

        model.train()
        
        return control

Regardless of the PEFT hyperparameter configuration, the model always generates the same content, which is also identical to the base model's output. I tried a few things that all result in identical outputs (see below):

  • higher / lower learning rates -> 3e-4 vs 3e-3 vs 3e-2
  • num_virtual_tokens=20 vs num_virtual_tokens=200
  • encoder_hidden_size=128 vs encoder_hidden_size=512
  • padding left vs padding right
  • p-tuning vs prompt tuning

Output:
(screenshot: base_model generation)
The output makes sense for the prompt, but is not what is expected from a fine-tuned model. It's like p-tuning and prompt tuning have zero impact.

Interestingly, when generating without attention_mask, the generated content and the input prompt change completely:
(screenshot: generation without attention_mask)
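
Concretely, the two generation calls I compare are roughly the following (a sketch; model, tokenizer, and inputs as in the callback above):

import torch

with torch.no_grad():
    # with the attention mask, as in the callback
    with_mask = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=128,
    )
    # without the attention mask
    without_mask = model.generate(input_ids=inputs["input_ids"], max_length=128)

print(tokenizer.decode(with_mask[0], skip_special_tokens=True))
print(tokenizer.decode(without_mask[0], skip_special_tokens=True))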

This is weird. I used prompt-based tuning with PEFT in May/June 2023 for another project (with a similar setup), and I never had such an issue.

@BenjaminBossan
Member

BenjaminBossan commented Sep 2, 2024

Thanks for the additional information. I was afraid that something more could be going on as you wrote that the results are very different.

For me to further assist in this, it would be very helpful if you could share a bit more information. Ideally, you could upload the adapter to HF so that I can try it myself. Please also provide the code that you use to check if the outputs are as expected or not. If you cannot share the checkpoint, would it be possible to share the training code instead?

Another thing that could be going on is that there have been some changes to transformers recently that interfere with prompt tuning methods in PEFT. Therefore, it would be helpful if you could test if your issue resolves when using an older transformers version. Maybe you can figure out which one you used back then? If you downgrade transformers, you may also have to downgrade PEFT to a version that corresponds to that time. If you can determine that your checkpoint works with version X but not Y, this would greatly increase the chances of figuring out what's going wrong.

PS: Just checked, something like transformers v4.29 and PEFT v0.3.0 would correspond to the time you mentioned.
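
(To double-check which versions are actually active in your environment, something like this is enough:)

import peft
import transformers

print("transformers:", transformers.__version__)
print("peft:", peft.__version__)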

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Sep 2, 2024
There is a small numerical discrepancy between the outputs of a p-tuning
model before and after loading. Even though it is small, it can still
affect generations, so this PR eliminates it.

As an example, without the fix, this is the difference in logits for
opt-125m:

>       torch.testing.assert_close(output_loaded, output_peft)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 30 / 10557120 (0.0%)
E       Greatest absolute difference: 1.1086463928222656e-05 at index (0, 9, 9314) (up to 1e-05 allowed)
E       Greatest relative difference: 0.00021288332936819643 at index (0, 9, 9314) (up to 1.3e-06 allowed)

Details about how this comes about are explained here:

huggingface#2043 (comment)

The gist of it is that if we take a single sample, repeat it X times,
and then forward it through a model (which is the training path in
p-tuning), we would expect the same output as if we forwarded this
sample only once and repeated the output X times (the inference path for
p-tuning). However, for sufficiently large models, the two approaches
can have tiny differences.

With the fixed approach, there is no difference between training and
inference code paths when it comes to this. The new code should also be
slightly more compute efficient, but in practice will not make a
noticeable difference.
@martin-wey
Author

@BenjaminBossan thanks again. Please ignore my last reply if you have seen it. I am still working on it to make sure I provide you with 100% accurate information. I'll reply ASAP.

BenjaminBossan added a commit that referenced this issue Sep 3, 2024
…#2047)

@BenjaminBossan
Member

@martin-wey Do you have any updates?

@martin-wey
Author

martin-wey commented Sep 13, 2024

@BenjaminBossan Yes, as you suggested I compared fine-tuning using prompt tuning with 1. Transformers v4.29 / Peft v0.3.0 and 2. Transformers v4.44.2 / Peft v0.12.0. I used deepseek-coder-6.7b-instruct.

Summary of my findings:

  • I found worrying differences in training and validation losses between the two runs.
  • At inference, the model fine-tuned using the older versions of the libraries generates content as expected.
  • At inference, the model fine-tuned using the newer versions of the libraries generates content that is identical to the base model's zero-shot output.

Here's the code, which is compatible with both versions:

import argparse

import torch

from datasets import load_dataset
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorForLanguageModeling,
    TrainerCallback,
    set_seed
)

from collator import CustomDataCollatorForCompletionOnlyLM


def load_model_and_tokenizer(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    peft_config = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        num_virtual_tokens=20, 
        tokenizer_name_or_path=args.model_name_or_path,
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    return model, tokenizer


def main(args):
    dataset = load_dataset("neulab/docprompting-conala")
    model, tokenizer = load_model_and_tokenizer(args)

    def tokenize(example):
        prompt = f"{tokenizer.bos_token}\n"
        prompt += f"### Instruction:\n{example['nl']}\n"
        prompt += f"### Response:\n{example['cmd']}\n<|EOT|>"
        
        model_inputs = tokenizer(prompt, truncation=True, max_length=128, padding="max_length")

        return model_inputs

    tokenized_dataset = dataset.map(
        tokenize,
        batched=False,
        remove_columns=[cn for cn in dataset["train"].column_names if cn not in ["input_ids", "attention_mask"]],
    )

    training_args = TrainingArguments(
        output_dir=args.run_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_strategy="steps",
        bf16=True,
        logging_steps=1,
        save_total_limit=10,
        load_best_model_at_end=True,
        report_to="wandb" if args.use_wandb else "none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=CustomDataCollatorForCompletionOnlyLM("### Response", tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
    )

    trainer.train()
    trainer.model.save_pretrained(f"{args.run_dir}/best_model_checkpoint")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default="deepseek-ai/deepseek-coder-6.7b-instruct", type=str)
    parser.add_argument("--output_dir", default=".", type=str)

    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--patience", type=int, default=2)

    parser.add_argument("--learning_rate", type=float,  default=3e-3)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")

    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    set_seed(args.seed)

    args.model_name = args.model_name_or_path.split('/')[-1]
    args.run_dir = f"{args.output_dir}/{args.model_name}_conala_prompt-tuning_new/"
    main(args)

Data collator (an extension of the TRL completion-only data collator that keeps the EOS token in the labels):

from typing import Any, Dict, List, Union

from trl import DataCollatorForCompletionOnlyLM


class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # ensure the last token is taken into account for loss computation
        # otherwise the model may never stop generating at inference
        batch["labels"][:, -1] = batch["input_ids"][:, -1]

        return batch

Learning curves are drastically different:
(screenshot: training-loss curves)

Eval curves:
(screenshot: validation-loss curves)

@BenjaminBossan
Copy link
Member

Thanks for the reproducer. I made a few small changes, most notably using deepseek-ai/deepseek-coder-1.3b-instruct so that I can run it just like that on my machine. What I found is that, indeed, with the old versions the training loss looks normal, but with the new versions the loss is flat (albeit much lower). Isolating the issue, it looks like the transformers version is responsible:

(screenshot: training-loss comparison between transformers versions)

(I added a small offset for legibility)

So next I ran git bisect and I could identify the exact transformers PR: huggingface/transformers#24653.

Next I wanted to know if the issue is related to PEFT or not. Therefore, I changed the code to run full fine-tuning, no PEFT at all. What I found is that starting from this commit (34d94094279d2c903d9d8a51a65edb265f22c849), the loss looks wrong:

{'loss': 1.1484, 'learning_rate': 0.00029699999999999996, 'epoch': 0.0}                                                                                                                                                                                                             
{'loss': 2.4375, 'learning_rate': 0.000294, 'epoch': 0.01}                                                                                                                                                                                                                          
{'loss': 1.5, 'learning_rate': 0.00029099999999999997, 'epoch': 0.01}                                                                                                                                                                                                               
{'loss': 1.5312, 'learning_rate': 0.00028799999999999995, 'epoch': 0.01}                                                                                                                                                                                                            
{'loss': 3.5625, 'learning_rate': 0.000285, 'epoch': 0.02}                                                                                                                                                                                                                          
{'loss': 1.8906, 'learning_rate': 0.00028199999999999997, 'epoch': 0.02}                                                                                                                                                                                                            
{'loss': 1.4141, 'learning_rate': 0.000279, 'epoch': 0.03}                                                                                                                                                                                                                          
{'loss': 1.6094, 'learning_rate': 0.000276, 'epoch': 0.03}                                                                                                                                                                                                                          
{'loss': 0.9297, 'learning_rate': 0.00027299999999999997, 'epoch': 0.03}                                                                                                                                                                                                            
{'loss': 1.125, 'learning_rate': 0.00027, 'epoch': 0.04}                                                                                                                                                                                                                            
{'loss': 1.5859, 'learning_rate': 0.000267, 'epoch': 0.04}                                                                                                                                                                                                                          
{'loss': 1.3828, 'learning_rate': 0.00026399999999999997, 'epoch': 0.04}                                                                                                                                                                                                            
{'loss': 1.4141, 'learning_rate': 0.000261, 'epoch': 0.05}                                                                                                                                                                                                                          
{'loss': 1.2812, 'learning_rate': 0.000258, 'epoch': 0.05}                                                                                                                                                                                                                          
{'loss': 1.8438, 'learning_rate': 0.00025499999999999996, 'epoch': 0.06}                                                                                                                                                                                                            
{'loss': 1.2188, 'learning_rate': 0.00025199999999999995, 'epoch': 0.06}                                                                                                                                                                                                            
{'loss': 1.3047, 'learning_rate': 0.000249, 'epoch': 0.06}                                                                                                                                                                                                                          
{'loss': 0.9844, 'learning_rate': 0.00024599999999999996, 'epoch': 0.07}                                                                                                                                                                                                            
{'loss': 1.1875, 'learning_rate': 0.000243, 'epoch': 0.07}                                                                                                                                                                                                                          
{'loss': 0.9453, 'learning_rate': 0.00023999999999999998, 'epoch': 0.07} 

When I go to the previous commit (9342c8fb824dfc9a4a374292fa9995ae1bc52da0), the loss looks good:

{'loss': 10.875, 'learning_rate': 0.00029699999999999996, 'epoch': 0.0}                                                                                                                                                                                                             
{'loss': 4.9688, 'learning_rate': 0.000294, 'epoch': 0.01}                                                                                                                                                                                                                          
{'loss': 3.9688, 'learning_rate': 0.00029099999999999997, 'epoch': 0.01}                                                                                                                                                                                                            
{'loss': 6.1562, 'learning_rate': 0.00028799999999999995, 'epoch': 0.01}                                                                                                                                                                                                            
{'loss': 3.5312, 'learning_rate': 0.000285, 'epoch': 0.02}                                                                                                                                                                                                                          
{'loss': 4.0312, 'learning_rate': 0.00028199999999999997, 'epoch': 0.02}                                                                                                                                                                                                            
{'loss': 3.6562, 'learning_rate': 0.000279, 'epoch': 0.03}                                                                                                                                                                                                                          
{'loss': 2.9375, 'learning_rate': 0.000276, 'epoch': 0.03}                                                                                                                                                                                                                          
{'loss': 1.75, 'learning_rate': 0.00027299999999999997, 'epoch': 0.03}                                                                                                                                                                                                              
{'loss': 1.5312, 'learning_rate': 0.00027, 'epoch': 0.04}                                                                                                                                                                                                                           
{'loss': 1.3281, 'learning_rate': 0.000267, 'epoch': 0.04}                                                                                                                                                                                                                          
{'loss': 1.4062, 'learning_rate': 0.00026399999999999997, 'epoch': 0.04}                                                                                                                                                                                                            
{'loss': 1.4844, 'learning_rate': 0.000261, 'epoch': 0.05}                                                                                                                                                                                                                          
{'loss': 1.1797, 'learning_rate': 0.000258, 'epoch': 0.05}                                                                                                                                                                                                                          
{'loss': 1.8281, 'learning_rate': 0.00025499999999999996, 'epoch': 0.06}                                                                                                                                                                                                            
{'loss': 1.0625, 'learning_rate': 0.00025199999999999995, 'epoch': 0.06}                                                                                                                                                                                                            
{'loss': 1.4688, 'learning_rate': 0.000249, 'epoch': 0.06}                                                                                                                                                                                                                          
{'loss': 0.9727, 'learning_rate': 0.00024599999999999996, 'epoch': 0.07}                                                                                                                                                                                                            
{'loss': 1.2344, 'learning_rate': 0.000243, 'epoch': 0.07}                                                                                                                                                                                                                          
{'loss': 0.6797, 'learning_rate': 0.00023999999999999998, 'epoch': 0.07}

I don't know enough about transformers to judge why that could be, so I'm pinging @gante for help.

Training script without PEFT
import argparse
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from datasets import load_dataset

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed
)


# copied from TRL to avoid the need to install it
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    def __init__(
        self,
        response_template: Union[str, List[int]],
        instruction_template: Optional[Union[str, List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)

        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            # The user provides a string, must tokenize
            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            # The user provides a string, must tokenize
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.response_token_ids = response_template

        if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value."
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        if self.instruction_template is None:
            for i in range(len(examples)):
                response_token_ids_start_idx = None

                for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
                    if (
                        self.response_token_ids
                        == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
                    ):
                        response_token_ids_start_idx = idx

                if response_token_ids_start_idx is None:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index
                else:
                    response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)

                    # Make pytorch loss function ignore all tokens up through the end of the response key
                    batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index

        else:
            for i in range(len(examples)):
                response_token_ids_idxs = []
                human_token_ids_idxs = []

                for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # find the indexes of the start of a response.
                    if (
                        self.response_token_ids
                        == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
                    ):
                        response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))

                if len(response_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index

                human_token_ids = self.instruction_token_ids
                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                    # find the indexes of the start of a human answer.
                    if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
                        human_token_ids_idxs.append(human_idx)

                if len(human_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find instruction key `{self.instruction_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index

                if (
                    len(human_token_ids_idxs) > 0
                    and len(response_token_ids_idxs) > 0
                    and human_token_ids_idxs[0] > response_token_ids_idxs[0]
                ):
                    human_token_ids_idxs = [0] + human_token_ids_idxs

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Make pytorch loss function ignore all non response tokens
                    if idx != 0:
                        batch["labels"][i, start:end] = self.ignore_index
                    else:
                        batch["labels"][i, :end] = self.ignore_index

                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index

        if self.padding_free:
            # remove padding, `attention_mask` and add `position_ids`
            attn_mask = batch.pop("attention_mask")
            batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
            batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
            batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
            batch["labels"][batch["position_ids"] == 0] = self.ignore_index

        return batch


class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # ensure the last token is taken into account for loss computation
        # otherwise the model may never stop generating at inference
        batch["labels"][:, -1] = batch["input_ids"][:, -1]

        return batch


def load_model_and_tokenizer(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        #device_map="auto",
        device_map={"": 0},
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    return model, tokenizer


def main(args):
    dataset = load_dataset("neulab/docprompting-conala")
    model, tokenizer = load_model_and_tokenizer(args)

    def tokenize(example):
        prompt = f"{tokenizer.bos_token}\n"
        prompt += f"### Instruction:\n{example['nl']}\n"
        prompt += f"### Response:\n{example['cmd']}\n<|EOT|>"

        model_inputs = tokenizer(prompt, truncation=True, max_length=128, padding="max_length")

        return model_inputs

    tokenized_dataset = dataset.map(
        tokenize,
        batched=False,
        remove_columns=[cn for cn in dataset["train"].column_names if cn not in ["input_ids", "attention_mask"]],
    )

    training_args = TrainingArguments(
        output_dir=args.run_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_strategy="steps",
        bf16=True,
        logging_steps=1,
        save_total_limit=10,
        load_best_model_at_end=False,# BB
        report_to="wandb" if args.use_wandb else "none",
        max_steps=100,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=CustomDataCollatorForCompletionOnlyLM("### Response", tokenizer=tokenizer),
        # callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
    )

    trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default="deepseek-ai/deepseek-coder-1.3b-instruct", type=str)
    parser.add_argument("--output_dir", default=".", type=str)

    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--patience", type=int, default=2)

    parser.add_argument("--learning_rate", type=float,  default=3e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")

    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    set_seed(args.seed)

    args.model_name = args.model_name_or_path.split('/')[-1]
    args.run_dir = f"{args.output_dir}/{args.model_name}_conala_prompt-tuning_new/"
    main(args)

@martin-wey
Author

Thanks a lot for looking into it!

I'll use the latest version of PEFT with an older version of Transformers before it gets resolved :)

@gante
Member

gante commented Sep 17, 2024

@BenjaminBossan thank you for the deep dive and git bisect!

Am I right in saying that training with and without PEFT has been broken on e.g. deepseek-ai/deepseek-coder-1.3b-instruct (i.e. llama architecture) since the PR you linked? At least using the training script you shared.

If so, is it still true even on commits like this one, which were supposed to stabilize training?

Assuming the answer is yes, then it means some RoPE-related change was the culprit 🤔 The highlighted PR only touches RoPE.

@BenjaminBossan
Copy link
Member

BenjaminBossan commented Sep 17, 2024

Thanks for the reply Joao. I checked out that commit and still find the strange loss pattern using normal fine-tuning:

{'loss': 1.213, 'grad_norm': 0.04748094454407692, 'learning_rate': 0.00297, 'epoch': 0.0}                                                                                                                                                                                           
{'loss': 1.1009, 'grad_norm': 0.058126840740442276, 'learning_rate': 0.00294, 'epoch': 0.01}                                                                                                                                                                                        
{'loss': 1.4116, 'grad_norm': 0.06582096219062805, 'learning_rate': 0.00291, 'epoch': 0.01}                                                                                                                                                                                         
{'loss': 1.2287, 'grad_norm': 0.05835267901420593, 'learning_rate': 0.0028799999999999997, 'epoch': 0.01}                                                                                                                                                                           
{'loss': 1.7492, 'grad_norm': 0.06075122952461243, 'learning_rate': 0.00285, 'epoch': 0.02}                                                                                                                                                                                         
{'loss': 1.3371, 'grad_norm': 0.05048394203186035, 'learning_rate': 0.00282, 'epoch': 0.02}                                                                                                                                                                                         
{'loss': 1.5431, 'grad_norm': 0.0814090147614479, 'learning_rate': 0.0027900000000000004, 'epoch': 0.03}                                                                                                                                                                            
{'loss': 1.3134, 'grad_norm': 0.05842101201415062, 'learning_rate': 0.0027600000000000003, 'epoch': 0.03}                                                                                                                                                                           
{'loss': 1.1006, 'grad_norm': 0.0562712661921978, 'learning_rate': 0.0027300000000000002, 'epoch': 0.03}                                                                                                                                                                            
{'loss': 1.249, 'grad_norm': 0.08580105006694794, 'learning_rate': 0.0027, 'epoch': 0.04}                                                                                                                                                                                           
{'loss': 1.2432, 'grad_norm': 0.0493149533867836, 'learning_rate': 0.00267, 'epoch': 0.04}                                                                                                                                                                                          
{'loss': 1.2075, 'grad_norm': 0.04863704741001129, 'learning_rate': 0.00264, 'epoch': 0.04}                                                                                                                                                                                         
{'loss': 0.8575, 'grad_norm': 0.06719737499952316, 'learning_rate': 0.00261, 'epoch': 0.05}

So it looks quite similar to the pattern observed when using the latest transformers.

I can't say for certain that this is "wrong" and the previous pattern is "right", maybe it's the other way round. It's just that the loss is quite low right from the start and does not improve. This, together with what @martin-wey reported, makes it more likely than not that something is indeed broken.

Is there a way to easily switch the type of positional embedding? I guess not but if yes, we can test that.

@gante
Copy link
Member

gante commented Sep 17, 2024

Is there a way to easily switch the type of positional embedding? I guess not but if yes, we can test that.

In the case of RoPE models, no -- the model calls the position embedding's forward to obtain two position-dependent constants, which are mixed with the hidden states. In other words, it's more complex than the typical block and has a unique structure deeply intertwined with the forward of the attention layers.
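
To make that structure a bit more concrete, here is a minimal sketch of the mechanism described above, assuming the standard Llama-style rotary embedding; it is only an illustration, not the actual transformers implementation:

import torch

def rope_constants(seq_len, head_dim, base=10000.0):
    # the two position-dependent constants: a cos table and a sin table
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # inside the attention forward, the constants are mixed into the
    # query/key states rather than being added to the inputs up front
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin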

@gante
Copy link
Member

gante commented Sep 17, 2024

@BenjaminBossan we missed a detail :D It makes sense that the PR that introduced RoPE scaling has a huge impact on the loss -- the test model uses RoPE scaling.

Prior to that PR, RoPE scaling didn't exist -> the fine-tuning script builds the position embeddings differently (because scaling is not applied) -> the loss starts at a much higher value.
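
As a rough illustration of why skipping the scaling changes the model, here is a sketch using "linear" scaling, the simplest variant; the variant a given checkpoint actually configures (e.g. longrope) is more elaborate, but the point is the same: the cos/sin constants depend on the scaling config, so leaving it out gives the attention layers different constants than the ones the model was trained with.

import torch

def scaled_rope_constants(seq_len, head_dim, base=10000.0, scaling_factor=1.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # linear scaling divides the positions by the factor before building the
    # tables; scaling_factor=1.0 (no scaling) therefore yields different
    # constants than the factor the checkpoint was trained with
    positions = torch.arange(seq_len).float() / scaling_factor
    freqs = torch.outer(positions, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()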

@BenjaminBossan
Copy link
Member

Thanks @gante, so it is expected that after the addition of RoPE scaling, the loss would change considerably? In this case, it is actually much lower, not higher.

@martin-wey My takeaway so far is that after the given PR, the model itself changed due to the addition of RoPE scaling, so it is no surprise that the trained PEFT adapter no longer works. However, we should still expect prompt tuning to work even after said PR. I sat down again with the training script, and what I found is that I had to increase the learning rate quite a lot to see training progress. You mentioned you tried up to 3e-2; I would recommend trying even higher. For instance, I got nice progress on the eval loss with 0.1 (for deepseek-ai/deepseek-coder-1.3b-instruct).
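
For reference, a minimal sketch of that kind of setup with the higher learning rate; the PEFT config values, output path, batch size, and epoch count below are illustrative placeholders, and only the model and the much higher learning rate reflect the actual suggestion:

import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import PromptEncoderConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-coder-1.3b-instruct", torch_dtype=torch.bfloat16
)
peft_config = PromptEncoderConfig(
    task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128  # illustrative
)
model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir="runs/deepseek-coder-1.3b-instruct_p-tuning",  # placeholder
    learning_rate=0.1,  # much higher than typical LoRA learning rates
    per_device_train_batch_size=4,  # placeholder
    num_train_epochs=1,  # placeholder
    logging_steps=10,
)
# pass model, training_args, and the tokenized dataset to transformers.Trainer as usual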

@gante
Copy link
Member

gante commented Sep 18, 2024

Thanks @gante, so it is expected that after the addition of RoPE scaling, the loss would change considerably? In this case, it is actually much lower, not higher.

@BenjaminBossan Yes. Before that PR, for that checkpoint in particular, we couldn't properly create the position embeddings with which the model was fine-tuned. That explains the lower starting loss in our most recent versions -- they are creating the position embeddings correctly.

Conversely, if we use a checkpoint without rope_scaling in its model config, that PR should have no impact on the model.

[RoPE is a constant created at load-time, therefore it is not trained. However, it massively influences the attention layer if we don't create the constants correctly]
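
A quick way to check whether a given checkpoint is affected is to inspect its config: if rope_scaling is None, that PR should not change how the position embeddings are built. The model id below is just one of the checkpoints mentioned in this thread:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
# a non-None value means the checkpoint relies on RoPE scaling
print(getattr(config, "rope_scaling", None))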

@martin-wey
Copy link
Author

martin-wey commented Sep 20, 2024

Thanks @gante, so it is expected that after the addition of RoPE scaling, the loss would change considerably? In this case, it is actually much lower, not higher.

@martin-wey My takeaway so far is that after the given PR, the model itself changed due to the addition of RoPE scaling, so it is no surprise that the trained PEFT adapter no longer works. However, we should still expect prompt tuning to work even after said PR. I sat down again with the training script, and what I found is that I had to increase the learning rate quite a lot to see training progress. You mentioned you tried up to 3e-2; I would recommend trying even higher. For instance, I got nice progress on the eval loss with 0.1 (for deepseek-ai/deepseek-coder-1.3b-instruct).

Thanks for the suggestion. I also tried different learning rates, including 3e-1, 3e-2, and 3e-3. The evaluation loss is good and even decreases. However, I am still facing the issue at generation: the model tuned using prompt tuning or p-tuning generates identical content to the model in zero-shot when using the latest version of Transformers. I could solve this issue by downgrading Transformers to v4.29.

@BenjaminBossan
Copy link
Member

However, I am still facing the issue at generation: the model tuned using prompt tuning or p-tuning generates identical content to the model in zero-shot when using the latest version of Transformers.

Could you please provide the exact code you used for checking this?

@martin-wey
Copy link
Author

martin-wey commented Sep 23, 2024

However, I am still facing the issue at generation: the model tuned using prompt tuning or p-tuning generates identical content to the model in zero-shot when using the latest version of Transformers.

Could you please provide the exact code you used for checking this?

I use the fine-tuning script provided above.
To test inference:

import torch

from datasets import load_from_disk
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained("runs/Phi-3-mini-128k-instruct_conala_p-tuning/checkpoint-1335").to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

example = [
  {"role": "user", "content": "divide the values with same keys of two dictionary `d1` and `d2`"}
]
prompt = tokenizer.apply_chat_template(
    example, add_generation_prompt=True, return_dict=True, return_tensors="pt"
)
# move the tokenized prompt to the model's device
inputs = {k: v.to(model.device) for k, v in prompt.items()}

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128
    )
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

For all checkpoints (after 1/2/3/4/5 epochs), the model generates content identical to the base model, although the evaluation loss looks fine for both p-tuning and prompt tuning. For reference, the evaluation losses are roughly similar to those of the model fine-tuned with LoRA. However, I do not get this issue at generation when fine-tuning with LoRA.
The problem persists across models and test examples. There might be something wrong in the generation-related code in peft/transformers. As a reminder, I did not get this issue at inference with an older version of transformers/peft (v4.29 and v0.3.0).

I am also getting the following warning with prompt-based tuned models: `Position ids are not supported for parameter efficient tuning. Ignoring position ids.`

@BenjaminBossan
Copy link
Member

Thanks @martin-wey for the extra info. I tried to reproduce this using the latest PEFT and transformers version. I changed the base model to deepseek-ai/deepseek-coder-1.3b-instruct because of memory, and trained for 100 steps with a learning rate of 0.05.

I modified your inference code slightly, as it's not quite complete. This is what I get:

import gc

import torch

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM

device = 0
seed = 0
model_id = "deepseek-ai/deepseek-coder-1.3b-instruct"
example = [
  {"role": "user", "content": "divide the values with same keys of two dictionary `d1` and `d2`"}
]
prompt = tokenizer.apply_chat_template(example, add_generation_prompt=True, return_dict=True)
inputs = {k: torch.tensor(v).to(device).unsqueeze(0) for k, v in prompt.items()}

# base model
model = AutoModelForCausalLM.from_pretrained(model_id)
model.to(device);
torch.manual_seed(seed)
model.eval();

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128
    )

print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

del model
torch.cuda.empty_cache()
gc.collect()

# trained model
path = "./checkpoint_deepseek-ai--deepseek-coder-1.3b-instruct_peft-0.12.1.dev0_transformers-4.45.0.dev0"
model = AutoPeftModelForCausalLM.from_pretrained(path)
model.to(device);
torch.manual_seed(seed)
model.eval();

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128
    )

print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

For the base model, what I get is (only response part):

### Response:
Sure, I can help with that. Here's a Python example:

\```python
d1 = {'a': 1, 'b': 2, 'c': 3}
d2 = {'a': 4, 'b': 5, 'c': 6}

result = {key: d1[key] / d2[key] for key in d1.keys() & d2.keys()}

print(result)
\```

This will output:

\```
{'a': 0.25, 'b': 0

(note that I escaped the backtick for formatting reasons)

And for the trained model:

{k: d1[k] / d2[k] for k in d1.keys() & d2.keys()}

So we can see that the model did learn something and that the response is much more succinct, similar to what we find in the training data. I tried a few seeds to ensure this was not random but got the same results.

@martin-wey
Copy link
Author

Thank you @BenjaminBossan! Upgrading to the latest versions of transformers and peft solves the issue.
