Missing modules in prompt-based PEFT when re-loading model #2043
Comments
Thanks for reporting this issue. First of all, please don't load a model using [...]. That said, even with these adjustments, as well as ensuring that the model is in eval mode, I could reproduce the error, i.e. there is a small discrepancy after loading (in my tests, the mean abs diff of the logits was ~0.02, depending on the model). However, this discrepancy is not caused by the missing module. When you check these lines:
peft/src/peft/utils/save_and_load.py, lines 144 to 155 in 679bcd8
You can see that the prompt embedding is saved as part of the [...]. Still, I'm not sure where the difference comes from and will investigate further. I just wanted to share the insights I have so far. PS: Interestingly, for [...]
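To see concretely what ends up in a prompt-learning checkpoint, a small self-contained check along these lines should work (a sketch added for illustration, not part of the original discussion; opt-125m is only used as a small stand-in model):

from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model
from peft.utils import get_peft_model_state_dict

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
peft_model = get_peft_model(base, config)
# Only the (precomputed) prompt embeddings are stored, not the MLP head of the prompt encoder.
state_dict = get_peft_model_state_dict(peft_model)
print(list(state_dict.keys()))  # expected: ['prompt_embeddings']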
I may have pasted a wrong snippet. I'm having the same issue when first loading the base model and then using [...]
Got it, thanks!
Even so, the PEFT-tuned model's responses to input prompts are very similar to the base model's. Besides, I'm also having trouble with Phi-3.5-mini-128k-instruct and CodeQwen1.5-7B-Chat. However, LoRA-based PEFT methods (including QLoRA and DoRA) work just fine. Thanks for your quick reply. I will also keep trying with other LLMs.
@BenjaminBossan here are some more details about the forward pass of the PEFT and base models. I hope it helps :) I checked whether the input goes through all the modules of the network using hooks.

# ...
model = AutoPeftModelForCausalLM.from_pretrained(
"../runs/codellama/CodeLlama-7b-Instruct-hf_conala_p-tuning_3e-3/checkpoint-198/",
torch_dtype="bfloat16",
device_map="auto"
)
def forward_hook(module, input, output):
print(f"Module: {module.__class__.__name__}")
print(f"Input: {input}")
print(f"Output: {output}")
print("-" * 50)
hook_word_embeddings = model.word_embeddings.register_forward_hook(forward_hook)
hook_prompt_embeddings = model.prompt_encoder['default'].embedding.register_forward_hook(forward_hook)
for sample in test_set:
tokenized_sample = tokenizer.apply_chat_template(
sample["messages"],
return_tensors="pt",
).to(model.device)
output = model(tokenized_sample)

Output:

Module: Embedding
Input: (tensor([[ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492,
526, 263, 8444, 20255, 29889, 13, 29966, 829, 14816, 29903,
6778, 13, 13, 4563, 680, 278, 1819, 411, 1021, 6611,
310, 1023, 8600, 421, 29881, 29896, 29952, 322, 421, 29881,
29906, 29952, 518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069, 0.0031, -0.0013, ..., 0.0003, -0.0031, -0.0026],
[ 0.0170, -0.0242, 0.0211, ..., -0.0048, -0.0002, 0.0183],
[-0.0312, -0.0112, -0.0510, ..., 0.0070, 0.0281, -0.0294],
...,
[ 0.0142, -0.0182, -0.0065, ..., -0.0210, 0.0099, -0.0132],
[-0.0312, -0.0112, -0.0510, ..., 0.0070, 0.0281, -0.0294],
[-0.0237, -0.0297, 0.0014, ..., 0.0427, -0.0008, -0.0029]]],
device='cuda:0', dtype=torch.bfloat16)

The input never goes through the prompt_encoder.

# ...
model = AutoModelForCausalLM.from_pretrained(
"codellama/CodeLlama-7b-Instruct-hf",
torch_dtype="bfloat16",
device_map="auto"
)
hook_embeddings = model.model.embed_tokens.register_forward_hook(forward_hook)
# ...

Output (identical):

Module: Embedding
Input: (tensor([[ 1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492,
526, 263, 8444, 20255, 29889, 13, 29966, 829, 14816, 29903,
6778, 13, 13, 4563, 680, 278, 1819, 411, 1021, 6611,
310, 1023, 8600, 421, 29881, 29896, 29952, 322, 421, 29881,
29906, 29952, 518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069, 0.0031, -0.0013, ..., 0.0003, -0.0031, -0.0026],
[ 0.0170, -0.0242, 0.0211, ..., -0.0048, -0.0002, 0.0183],
[-0.0312, -0.0112, -0.0510, ..., 0.0070, 0.0281, -0.0294],
...,
[ 0.0142, -0.0182, -0.0065, ..., -0.0210, 0.0099, -0.0132],
[-0.0312, -0.0112, -0.0510, ..., 0.0070, 0.0281, -0.0294],
[-0.0237, -0.0297, 0.0014, ..., 0.0427, -0.0008, -0.0029]]],
device='cuda:0', dtype=torch.bfloat16, grad_fn=<EmbeddingBackward0>)

Basically, it seems like the prompt_encoder is bypassed during the forward pass. That's why I'm getting the same output when generating code with both the PEFT and the base model.
Thanks for digging deeper into this.
Yes, this is expected, as it relates to the optimization I mentioned above. The prompt embeddings to be prefixed are precomputed; therefore, the embedding's forward is never called at inference time.
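To make this concrete, here is a minimal, self-contained sketch of the idea (simplified names and shapes, not PEFT's actual code):

import torch

batch_size, num_virtual_tokens, token_dim = 4, 20, 64
# At save time, the prompt encoder's MLP output is computed once and stored as a plain weight.
precomputed_prompts = torch.randn(num_virtual_tokens, token_dim)
# At inference time after loading, the stored tensor is simply repeated along the batch dimension,
# so the prompt encoder's forward (and any hook registered on it) never runs.
prompts = precomputed_prompts.repeat(batch_size, 1, 1)
print(prompts.shape)  # torch.Size([4, 20, 64])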
I cannot observe this; for me, the outputs of the loaded PEFT model are very close to the outputs of the original PEFT model, but different enough that generations start to differ at some point. Investigating this further sent me down a rabbit hole, and I think I have figured out the issue. In short, the individual outputs of sending the same input repeated 10 times through the MLP do not equal the output of sending it through the MLP once. Let me illustrate:

import torch
torch.manual_seed(0);
device = 0
input_size = 128
hidden_size = 32
output_size = 64
layers = [
torch.nn.Linear(input_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, output_size),
]
mlp_head = torch.nn.Sequential(*layers)
mlp_head.to(device).eval();
x = torch.randn(20, input_size).to(device)
# x repeats 10x along the batch dimension
x = x.repeat(10, 1, 1)
# output with all 10 identical samples
out0 = mlp_head(x)
# output with only 1 of the samples
out1 = mlp_head(x[:1])
for i in range(10):
# this should be 0 but it is 2.0311928139449265e-08
print((out0[i:i+1] - out1).abs().mean().float().item())

When the involved sizes are small enough, the difference is actually 0, which may explain why [...]. So how does this translate to p-tuning? Let's check these lines, which are executed at training time:
Lines 658 to 663 in 679bcd8
Line 695 in 679bcd8
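For readers following along without the permalinks, the training-time code being referenced looks roughly like this (a paraphrased sketch based on the patch below, not a verbatim copy of peft_model.py):

# Paraphrased sketch of the training-time path (see the patch below for the actual code):
prompt_tokens = (
    self.prompt_tokens[self.active_adapter]
    .unsqueeze(0)
    .expand(batch_size, -1)  # the same virtual token ids, repeated batch_size times
    .to(prompt_encoder.embedding.weight.device)
)
prompts = prompt_encoder(prompt_tokens)  # the whole repeated batch goes through the MLP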
You can see that we repeat the same input batch_size times and send the whole batch through the prompt encoder. Now, when we load the model, we go through a slightly different code path:
Line 693 in 679bcd8
Here, we just take the (precomputed) output of a single sample and repeat it batch_size times. To address this, I created the following patch:

@@ -692,7 +692,13 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
if peft_config.inference_mode:
prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
else:
- prompts = prompt_encoder(prompt_tokens)
+ prompt_tokens = (
+ self.prompt_tokens[self.active_adapter]
+ .unsqueeze(0)
+ .expand(1, -1)
+ .to(prompt_encoder.embedding.weight.device)
+ )
+ prompts = prompt_encoder(prompt_tokens).repeat(batch_size, 1, 1)
return prompts
def get_nb_trainable_parameters(self) -> tuple[int, int]:

With this patch, the error vanishes for me. For completeness, here is the script I used to check it:

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
inputs = torch.arange(10).view(-1, 1).to(0)
model_id = "meta-llama/Meta-Llama-3-8B"
#model_id = "bigscience/bloomz-560m"
#model_id = "Qwen/Qwen2-1.5B"
#model_id = "microsoft/Phi-3.5-mini-instruct"
#model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map=0,
)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
model.eval();
torch.manual_seed(0)
with torch.inference_mode():
output_peft = model(inputs).logits
gen_peft = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)
model.save_pretrained("/tmp/peft/2043")
del model
torch.cuda.empty_cache()
model = AutoPeftModelForCausalLM.from_pretrained("/tmp/peft/2043", device_map=0, torch_dtype=torch.bfloat16)
# using `model = AutoModelForCausalLM.from_pretrained(...); model = PeftModel.from_pretrained(...)` also works
torch.manual_seed(0)
with torch.inference_mode():
output_loaded = model(inputs).logits
gen_loaded = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)
torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)

If you can confirm that this patch solves your original issue, I will create a PR to fix this in PEFT.
Thanks a lot for the detailed explanation. On my end, I am still experiencing both models generating exactly the same content for a given prompt, even with this fix. It makes no sense to me, as the training phase using p-tuning or prompt tuning works properly, with the validation loss decreasing. I use the following dataset: https://huggingface.co/datasets/neulab/docprompting-conala. The model learns to generate 1-2 lines of code for a given instruction. To rule out potential issues when reloading the model from a checkpoint, I implemented a TrainerCallback that generates code for a given test sample after each epoch.

class GenerateAfterEpochCallback(TrainerCallback):
def __init__(self, test_example, tokenizer):
self.test_example = test_example
self.tokenizer = tokenizer
def on_epoch_end(self, args, state, control, **kwargs):
model = kwargs['model']
tokenizer = self.tokenizer
model.eval()
inputs = tokenizer.apply_chat_template(
self.test_example["messages"][:-1] # remove assistant's solution,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True
).to(model.device)
with torch.no_grad():
output_ids = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=128,
)
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\nGenerated text after epoch {state.epoch}:\n{generated_text}\n")
model.train()
return control

Regardless of the PEFT hyperparameter configuration, the model generates the same content all the time, which is also identical to the base model's. I tried a few things that all result in identical outputs (see below):
Output: [...]

Interestingly, when generating without [...]. This is weird. I used prompt-based tuning with PEFT in May/June 2023 for another project (with a similar setup), and I never had such an issue.
Thanks for the additional information. I was afraid that something more could be going on, as you wrote that the results are very different. For me to further assist, it would be very helpful if you could share a bit more information. Ideally, you could upload the adapter to HF so that I can try it myself. Please also provide the code that you use to check whether the outputs are as expected or not. If you cannot share the checkpoint, would it be possible to share the training code instead?

Another thing that could be going on is that there have been some recent changes to transformers that interfere with prompt tuning methods in PEFT. Therefore, it would be helpful if you could test whether your issue resolves when using an older transformers version. Maybe you can figure out which one you used back then? If you downgrade transformers, you may also have to downgrade PEFT to a version from that time. If you can determine that your checkpoint works with version X but not Y, this would greatly increase the chances of figuring out what's going wrong.

PS: I just checked, something like transformers v4.29 and PEFT v0.3.0 would correspond to the time you mentioned.
There is a small numerical discrepancy between the outputs of a p-tuning model before and after loading. Even though it is small, it can still affect generations, so this PR eliminates it. As an example, without the fix, this is the difference in logits for opt-125m:

> torch.testing.assert_close(output_loaded, output_peft)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 30 / 10557120 (0.0%)
E Greatest absolute difference: 1.1086463928222656e-05 at index (0, 9, 9314) (up to 1e-05 allowed)
E Greatest relative difference: 0.00021288332936819643 at index (0, 9, 9314) (up to 1.3e-06 allowed)

Details about how this comes about are explained here: huggingface#2043 (comment)

The gist of it is that if we take a single sample, repeat it X times, and then forward it through a model (which is the training path in p-tuning), we would expect the same output as if we forwarded this sample only once and repeated the output X times (the inference path for p-tuning). However, for sufficiently large models, the two approaches can have tiny differences. With the fixed approach, there is no difference between the training and inference code paths when it comes to this. The new code should also be slightly more compute efficient, but in practice will not make a noticeable difference.
@BenjaminBossan thanks again. Please ignore my last reply if you have seen it. I am still working on it to make sure I provide you with 100% accurate information. I'll reply ASAP.
@martin-wey Do you have any updates?
@BenjaminBossan Yes, as you suggested, I compared fine-tuning using prompt tuning with 1. Transformers v4.29 / PEFT v0.3.0 and 2. Transformers v4.44.2 / PEFT v0.12.0. I used deepseek-coder-6.7b-instruct. Summary of my findings: [...]
Here's the code, compatible with both versions:

import argparse
import torch
from datasets import load_dataset
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
EarlyStoppingCallback,
DataCollatorForLanguageModeling,
TrainerCallback,
set_seed
)
from collator import CustomDataCollatorForCompletionOnlyLM
def load_model_and_tokenizer(args):
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
peft_config = PromptTuningConfig(
task_type=TaskType.CAUSAL_LM,
prompt_tuning_init=PromptTuningInit.RANDOM,
num_virtual_tokens=20,
tokenizer_name_or_path=args.model_name_or_path,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
if getattr(tokenizer, "pad_token_id") is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
return model, tokenizer
def main(args):
dataset = load_dataset("neulab/docprompting-conala")
model, tokenizer = load_model_and_tokenizer(args)
def tokenize(example):
prompt = f"{tokenizer.bos_token}\n"
prompt += f"### Instruction:\n{example['nl']}\n"
prompt += f"### Response:\n{example['cmd']}\n<|EOT|>"
model_inputs = tokenizer(prompt, truncation=True, max_length=128, padding="max_length")
return model_inputs
tokenized_dataset = dataset.map(
tokenize,
batched=False,
remove_columns=[cn for cn in dataset["train"].column_names if cn not in ["input_ids", "attention_mask"]],
)
training_args = TrainingArguments(
output_dir=args.run_dir,
evaluation_strategy="epoch",
save_strategy="epoch",
num_train_epochs=args.num_epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
lr_scheduler_type=args.lr_scheduler_type,
logging_strategy="steps",
bf16=True,
logging_steps=1,
save_total_limit=10,
load_best_model_at_end=True,
report_to="wandb" if args.use_wandb else "none"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
tokenizer=tokenizer,
data_collator=CustomDataCollatorForCompletionOnlyLM("### Response", tokenizer=tokenizer),
callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
)
trainer.train()
trainer.model.save_pretrained(f"{args.run_dir}/best_model_checkpoint")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default="deepseek-ai/deepseek-coder-6.7b-instruct", type=str)
parser.add_argument("--output_dir", default=".", type=str)
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--patience", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=3e-3)
parser.add_argument("--lr_scheduler_type", type=str, default="linear")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
set_seed(args.seed)
args.model_name = args.model_name_or_path.split('/')[-1]
args.run_dir = f"{args.output_dir}/{args.model_name}_conala_prompt-tuning_new/"
main(args)

Data collator (an extension of the TRL completion-only data collator that keeps the EOS token in the labels):

class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
# ensure the last token is taken into account for loss computation
# otherwise the model may never stop generating at inference
batch["labels"][:, -1] = batch["input_ids"][:, -1]
return batch
Thanks for the reproducer. I made a few small changes, most notably using [...] (I added a small offset for legibility). Next, I ran git bisect and could identify the exact transformers PR: huggingface/transformers#24653.

I then wanted to know whether the issue is related to PEFT or not, so I changed the code to run full fine-tuning, with no PEFT at all. What I found is that, starting from this commit ([...]), the loss looks like this:

{'loss': 1.1484, 'learning_rate': 0.00029699999999999996, 'epoch': 0.0}
{'loss': 2.4375, 'learning_rate': 0.000294, 'epoch': 0.01}
{'loss': 1.5, 'learning_rate': 0.00029099999999999997, 'epoch': 0.01}
{'loss': 1.5312, 'learning_rate': 0.00028799999999999995, 'epoch': 0.01}
{'loss': 3.5625, 'learning_rate': 0.000285, 'epoch': 0.02}
{'loss': 1.8906, 'learning_rate': 0.00028199999999999997, 'epoch': 0.02}
{'loss': 1.4141, 'learning_rate': 0.000279, 'epoch': 0.03}
{'loss': 1.6094, 'learning_rate': 0.000276, 'epoch': 0.03}
{'loss': 0.9297, 'learning_rate': 0.00027299999999999997, 'epoch': 0.03}
{'loss': 1.125, 'learning_rate': 0.00027, 'epoch': 0.04}
{'loss': 1.5859, 'learning_rate': 0.000267, 'epoch': 0.04}
{'loss': 1.3828, 'learning_rate': 0.00026399999999999997, 'epoch': 0.04}
{'loss': 1.4141, 'learning_rate': 0.000261, 'epoch': 0.05}
{'loss': 1.2812, 'learning_rate': 0.000258, 'epoch': 0.05}
{'loss': 1.8438, 'learning_rate': 0.00025499999999999996, 'epoch': 0.06}
{'loss': 1.2188, 'learning_rate': 0.00025199999999999995, 'epoch': 0.06}
{'loss': 1.3047, 'learning_rate': 0.000249, 'epoch': 0.06}
{'loss': 0.9844, 'learning_rate': 0.00024599999999999996, 'epoch': 0.07}
{'loss': 1.1875, 'learning_rate': 0.000243, 'epoch': 0.07}
{'loss': 0.9453, 'learning_rate': 0.00023999999999999998, 'epoch': 0.07}

When I go to the previous commit ([...]), I get:

{'loss': 10.875, 'learning_rate': 0.00029699999999999996, 'epoch': 0.0}
{'loss': 4.9688, 'learning_rate': 0.000294, 'epoch': 0.01}
{'loss': 3.9688, 'learning_rate': 0.00029099999999999997, 'epoch': 0.01}
{'loss': 6.1562, 'learning_rate': 0.00028799999999999995, 'epoch': 0.01}
{'loss': 3.5312, 'learning_rate': 0.000285, 'epoch': 0.02}
{'loss': 4.0312, 'learning_rate': 0.00028199999999999997, 'epoch': 0.02}
{'loss': 3.6562, 'learning_rate': 0.000279, 'epoch': 0.03}
{'loss': 2.9375, 'learning_rate': 0.000276, 'epoch': 0.03}
{'loss': 1.75, 'learning_rate': 0.00027299999999999997, 'epoch': 0.03}
{'loss': 1.5312, 'learning_rate': 0.00027, 'epoch': 0.04}
{'loss': 1.3281, 'learning_rate': 0.000267, 'epoch': 0.04}
{'loss': 1.4062, 'learning_rate': 0.00026399999999999997, 'epoch': 0.04}
{'loss': 1.4844, 'learning_rate': 0.000261, 'epoch': 0.05}
{'loss': 1.1797, 'learning_rate': 0.000258, 'epoch': 0.05}
{'loss': 1.8281, 'learning_rate': 0.00025499999999999996, 'epoch': 0.06}
{'loss': 1.0625, 'learning_rate': 0.00025199999999999995, 'epoch': 0.06}
{'loss': 1.4688, 'learning_rate': 0.000249, 'epoch': 0.06}
{'loss': 0.9727, 'learning_rate': 0.00024599999999999996, 'epoch': 0.07}
{'loss': 1.2344, 'learning_rate': 0.000243, 'epoch': 0.07}
{'loss': 0.6797, 'learning_rate': 0.00023999999999999998, 'epoch': 0.07}

I don't know enough about transformers to judge why that could be, so I'm pinging @gante for help.

Training script without PEFT:

import argparse
import warnings
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
set_seed
)
# copied from TRL to avoid the need to install it
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
def __init__(
self,
response_template: Union[str, List[int]],
instruction_template: Optional[Union[str, List[int]]] = None,
*args,
mlm: bool = False,
ignore_index: int = -100,
padding_free: bool = False,
**kwargs,
):
super().__init__(*args, mlm=mlm, **kwargs)
self.instruction_template = instruction_template
if isinstance(instruction_template, str):
# The user provides a string, must tokenize
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.instruction_token_ids = instruction_template
self.response_template = response_template
if isinstance(response_template, str):
# The user provides a string, must tokenize
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.response_token_ids = response_template
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
warnings.warn(
"The pad_token_id and eos_token_id values of this tokenizer are identical. "
"If you are planning for multi-turn training, "
"it can result in the model continuously generating questions and answers without eos token. "
"To avoid this, set the pad_token_id to a different value."
)
self.ignore_index = ignore_index
self.padding_free = padding_free
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
if self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
if (
self.response_token_ids
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
):
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
else:
for i in range(len(examples)):
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# find the indexes of the start of a response.
if (
self.response_token_ids
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
):
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
human_token_ids = self.instruction_token_ids
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
if (
len(human_token_ids_idxs) > 0
and len(response_token_ids_idxs) > 0
and human_token_ids_idxs[0] > response_token_ids_idxs[0]
):
human_token_ids_idxs = [0] + human_token_ids_idxs
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
batch["labels"][i, start:end] = self.ignore_index
else:
batch["labels"][i, :end] = self.ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
if self.padding_free:
# remove padding, `attention_mask` and add `position_ids`
attn_mask = batch.pop("attention_mask")
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index
return batch
class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
# ensure the last token is taken into account for loss computation
# otherwise the model may never stop generating at inference
batch["labels"][:, -1] = batch["input_ids"][:, -1]
return batch
def load_model_and_tokenizer(args):
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
#device_map="auto",
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
if getattr(tokenizer, "pad_token_id") is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
return model, tokenizer
def main(args):
dataset = load_dataset("neulab/docprompting-conala")
model, tokenizer = load_model_and_tokenizer(args)
def tokenize(example):
prompt = f"{tokenizer.bos_token}\n"
prompt += f"### Instruction:\n{example['nl']}\n"
prompt += f"### Response:\n{example['cmd']}\n<|EOT|>"
model_inputs = tokenizer(prompt, truncation=True, max_length=128, padding="max_length")
return model_inputs
tokenized_dataset = dataset.map(
tokenize,
batched=False,
remove_columns=[cn for cn in dataset["train"].column_names if cn not in ["input_ids", "attention_mask"]],
)
training_args = TrainingArguments(
output_dir=args.run_dir,
evaluation_strategy="epoch",
save_strategy="epoch",
num_train_epochs=args.num_epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
lr_scheduler_type=args.lr_scheduler_type,
logging_strategy="steps",
bf16=True,
logging_steps=1,
save_total_limit=10,
load_best_model_at_end=False,# BB
report_to="wandb" if args.use_wandb else "none",
max_steps=100,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
tokenizer=tokenizer,
data_collator=CustomDataCollatorForCompletionOnlyLM("### Response", tokenizer=tokenizer),
# callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
)
trainer.train()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default="deepseek-ai/deepseek-coder-1.3b-instruct", type=str)
parser.add_argument("--output_dir", default=".", type=str)
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--patience", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--lr_scheduler_type", type=str, default="linear")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
set_seed(args.seed)
args.model_name = args.model_name_or_path.split('/')[-1]
args.run_dir = f"{args.output_dir}/{args.model_name}_conala_prompt-tuning_new/"
main(args)
Thanks a lot for looking into it! I'll use the latest version of PEFT with an older version of Transformers until this gets resolved :)
@BenjaminBossan thank you for the deep dive and git bisect! Am I right in saying that training with and without PEFT has been broken on e.g. [...]? If so, is it still true even on commits like this one, which were supposed to stabilize training? Assuming the answer is yes, it means some RoPE-related change was the culprit 🤔 The highlighted PR only touches RoPE.
Thanks for the reply, Joao. I checked out that commit and still see the strange loss pattern with normal fine-tuning: [...]
So it looks quite similar to the pattern observed when using the latest transformers. I can't say for certain that this is "wrong" and the previous pattern is "right"; maybe it's the other way round. It's just that the loss is quite low right from the start and does not improve. This, together with what @martin-wey reported, makes it more likely than not that something is indeed broken. Is there a way to easily switch the type of positional embedding? I guess not, but if yes, we can test that.
In the case of RoPE models, no -- the model calls the position embedding's [...]
@BenjaminBossan we missed a detail :D It makes sense that the PR that introduced RoPE scaling has a huge impact on the loss -- the test model uses rope scaling [...]. Prior to that PR, rope scaling didn't exist -> the fine-tuning script changes the position embeddings (because scaling is not done) -> the loss starts at a much higher value.
Thanks @gante, so it is expected that after the addition of RoPE scaling the loss would change considerably? In this case, it is actually much lower, not higher. @martin-wey My takeaway so far is that after the given PR, the model itself changed due to the addition of RoPE scaling, so it is no surprise that the trained PEFT adapter would not work anymore. However, we should still expect prompt tuning to work even after said PR. I sat down again with the training script, and what I found is that I had to increase the learning rate quite a lot to see training progress. You mentioned you tried up to 3e-2; I would recommend trying even higher values, for instance I got nice progress on the eval loss with 0.1 (for [...]).
@BenjaminBossan Yes. Before that PR, for that checkpoint in particular, we couldn't properly create the position embeddings with which the model was fine-tuned. That explains the lower starting loss in our most recent versions -- they are creating the position embeddings correctly. Conversely, if we use a checkpoint without [...]. [RoPE is a constant created at load time, therefore it is not trained. However, it massively influences the attention layer if we don't create the constants correctly.]
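As a side note, whether a checkpoint relies on RoPE scaling can be checked directly from its config (a hedged sketch; the exact scaling values depend on the checkpoint):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/deepseek-coder-6.7b-instruct")
# For checkpoints fine-tuned with an extended context this is typically a dict,
# e.g. {"type": "linear", "factor": 4.0}; on transformers versions predating RoPE-scaling
# support the setting is ignored and the unscaled position embeddings are created instead.
print(getattr(config, "rope_scaling", None))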
Thanks for the suggestion. I also tried different learning rates, including 3e-1, 3e-2, and 3e-3. The evaluation loss is good and even decreases. However, I am still facing the issue at generation: with the latest version of Transformers, the model tuned using prompt tuning or p-tuning generates identical content to the model in zero-shot. I could solve this issue by downgrading Transformers to v4.29.
Could you please provide the exact code you used for checking this?
I use the fine-tuning script provided above.

import torch
from datasets import load_from_disk
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
model = AutoPeftModelForCausalLM.from_pretrained("runs/Phi-3-mini-128k-instruct_conala_p-tuning/checkpoint-1335").to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
example = [
{"role": "user", "content": "divide the values with same keys of two dictionary `d1` and `d2`"}
]
prompt = tokenizer.apply_chat_template(example, add_generation_prompt=True, return_dict=True)
with torch.no_grad():
inputs = {k: v.to(model.device) for k, v in inputs.items()}
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=128
)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

For all checkpoints (after 1/2/3/4/5 epochs), the model generates identical content to the base model, although the evaluation loss seems OK for both p-tuning and prompt tuning. For reference, the evaluation losses are roughly similar to those of the model fine-tuned with LoRA. However, I do not get this issue at generation when fine-tuning with LoRA. I am also getting the following warning: [...]
Thanks @martin-wey for the extra info. I tried to reproduce this using the latest PEFT and transformers versions. I changed the base model to deepseek-ai/deepseek-coder-1.3b-instruct. I modified your inference code slightly, as it's not quite complete. This is what I get:

import gc
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
device = 0
seed = 0
model_id = "deepseek-ai/deepseek-coder-1.3b-instruct"
example = [
{"role": "user", "content": "divide the values with same keys of two dictionary `d1` and `d2`"}
]
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = tokenizer.apply_chat_template(example, add_generation_prompt=True, return_dict=True)
inputs = {k: torch.tensor(v).to(device).unsqueeze(0) for k, v in prompt.items()}
# base model
model = AutoModelForCausalLM.from_pretrained(model_id)
model.to(device);
torch.manual_seed(seed)
model.eval();
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=128
)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
del model
torch.cuda.empty_cache()
gc.collect()
# trained model
path = "./checkpoint_deepseek-ai--deepseek-coder-1.3b-instruct_peft-0.12.1.dev0_transformers-4.45.0.dev0"
model = AutoPeftModelForCausalLM.from_pretrained(path)
model.to(device);
torch.manual_seed(seed)
model.eval();
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=128
)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

For the base model, what I get is (only the response part): [...]
(Note that I escaped the backtick for formatting reasons.) And for the trained model: [...]
So we can see that the model did learn something and that the response is much more succinct, similar to what we find in the training data. I tried a few seeds to ensure this was not random, but got the same results.
Thank you @BenjaminBossan! Upgrading to the latest versions of transformers and peft solves the issue.
System Info
python 3.10.10, transformers 4.44.2, peft 0.12.0
Who can help?
@BenjaminBossan @sayak
Information
Tasks
examples folder
Reproduction
The following is a simplified script to reproduce the bug. I have experienced the same issue using transformers.Trainer. Fine-tuning with PEFT and p-tuning/prompt tuning works perfectly. However, when reloading the model from a saved PEFT checkpoint for generation, some modules are missing. In turn, the model does not generate the expected content.
Basic script to save the adapter:
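(The original snippet did not survive formatting. A minimal sketch of what such a script could look like is shown below; the base model, training loop, and output path are placeholders, not the reporter's exact code.)

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-Instruct-hf",  # placeholder base model
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
peft_config = PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
print(model.prompt_encoder)  # both the embedding and the MLP head are present here
# ... fine-tuning with transformers.Trainer would go here ...
model.save_pretrained("checkpoint-198")  # placeholder path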
Output: [...]
Load the PEFT model from checkpoint (method 1):
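(Original snippet lost in formatting; based on the code shown earlier in the thread, method 1 presumably looked roughly like this:)

import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "checkpoint-198",  # placeholder path to the saved adapter
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model.prompt_encoder)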
Load the PEFT model from checkpoint (method 2):
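(Likewise reconstructed; method 2 presumably loaded the base model first and then attached the adapter with PeftModel.from_pretrained, roughly:)

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-Instruct-hf",  # placeholder base model
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "checkpoint-198")  # placeholder path
print(model.prompt_encoder)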
Output: The MLP part of the prompt encoder is missing
Perhaps something's wrong here: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/p_tuning/model.py#L82
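(For context, the linked line sits roughly in the following structure; this is a paraphrased, runnable sketch of PEFT's MLP-based prompt encoder, not the library's exact code. The mlp_head is only created when the config is not in inference mode, which is exactly the situation when re-loading a saved adapter.)

import torch

class PromptEncoderSketch(torch.nn.Module):
    """Simplified stand-in for PEFT's p-tuning PromptEncoder with MLP reparameterization."""

    def __init__(self, total_virtual_tokens=20, token_dim=4096, hidden_size=128, inference_mode=False):
        super().__init__()
        self.embedding = torch.nn.Embedding(total_virtual_tokens, token_dim)
        if not inference_mode:
            # Skipped when the adapter is re-loaded (inference_mode=True),
            # which is why the MLP layers do not show up in the printed model.
            self.mlp_head = torch.nn.Sequential(
                torch.nn.Linear(token_dim, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, token_dim),
            )

print(PromptEncoderSketch(inference_mode=True))   # no mlp_head, matching the report above
print(PromptEncoderSketch(inference_mode=False))  # embedding + mlp_head, as during training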
Expected behavior
I believe the entire prompt encoder should be reloaded from the PEFT checkpoint.
When I reload the model using the aforementioned methods, it generates content that is close to the base model's, meaning the prompt encoder is not loaded properly.