Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Small numerical discrepancy for p-tuning after loading the model #2047

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,13 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
prompts = prompt_encoder(prompt_tokens, task_ids)
else:
if peft_config.inference_mode:
prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
prompts = prompt_encoder.embedding.weight
else:
# Take only one prompt token sample and expand the output instead of expanding the input, see:
# https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577
prompt_tokens = prompt_tokens[:1]
prompts = prompt_encoder(prompt_tokens)
prompts = prompts.repeat(batch_size, 1, 1)
return prompts

def get_nb_trainable_parameters(self) -> tuple[int, int]:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
LoftQConfig,
LoraConfig,
PeftModel,
PromptEncoderConfig,
TaskType,
get_peft_model,
prepare_model_for_kbit_training,
Expand Down Expand Up @@ -3147,3 +3148,42 @@ def test_boft_half_conv(self):
conv = boft.layer.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16)
conv(x) # does not raise


@require_torch_gpu
class TestPTuningReproducibility:
device = infer_device()

def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
# See: https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577
# Ensure that after loading a p-tuning checkpoint, results are exactly reproducible (before the patch, they were
# only _almost_ identical).

# The model must be sufficiently large for the effect to be measurable, which is why this test requires is not
# run on CPU.
model_id = "facebook/opt-125m"
inputs = torch.arange(10).view(-1, 1).to(self.device)

torch.manual_seed(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we're loading from a pre-trained model, we need to set a seed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more of a routine I have to ensure reproducibility in tests. Since we call get_peft_model, which initializes a random model, it makes sense to set a seed here. The seed below is indeed not required, as there is only model loading, so I removed it.

model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config).eval()

with torch.inference_mode():
output_peft = model(inputs).logits
gen_peft = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

model.save_pretrained(tmp_path)
del model
torch.cuda.empty_cache()
gc.collect()

model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
model = PeftModel.from_pretrained(model, tmp_path)

with torch.inference_mode():
output_loaded = model(inputs).logits
gen_loaded = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)
Loading