From e415ad1b0ffd88b72ffb5eea49689a3ab7b751f6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 2 Sep 2024 16:19:09 +0200 Subject: [PATCH 1/2] FIX: Small numerical discrepancy for p-tuning There is a small numerical discrepancy between the outputs of a p-tuning model before and after loading. Even though it is small, it can still affect generations, so this PR eliminates it. As an example, without the fix, this is the difference in logits for opt-125m: > torch.testing.assert_close(output_loaded, output_peft) E AssertionError: Tensor-likes are not close! E E Mismatched elements: 30 / 10557120 (0.0%) E Greatest absolute difference: 1.1086463928222656e-05 at index (0, 9, 9314) (up to 1e-05 allowed) E Greatest relative difference: 0.00021288332936819643 at index (0, 9, 9314) (up to 1.3e-06 allowed) Details about how this comes about are explained here: https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577 The gist of it is that if we take a single sample, repeat it X times, and then forward it through a model (which is the training path in p-tuning), we would expect the same output as if we forwarded this sample only once and repeated the output X times (the inference path for p-tuning). However, for sufficiently large models, the two approaches can have tiny differences. With the fixed approach, there is no difference between training and inference code paths when it comes to this. The new code should also be slightly more compute efficient, but in practice will not make a noticeable difference. --- src/peft/peft_model.py | 6 +++++- tests/test_gpu_examples.py | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index b7d270b4aa..a26339e432 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -690,9 +690,13 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) - prompts = prompt_encoder(prompt_tokens, task_ids) else: if peft_config.inference_mode: - prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1) + prompts = prompt_encoder.embedding.weight else: + # Take only one prompt token sample and expand the output instead of expanding the input, see: + # https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577 + prompt_tokens = prompt_tokens[:1] prompts = prompt_encoder(prompt_tokens) + prompts = prompts.repeat(batch_size, 1, 1) return prompts def get_nb_trainable_parameters(self) -> tuple[int, int]: diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 5a45bbb91b..eaad624a60 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -52,6 +52,7 @@ LoftQConfig, LoraConfig, PeftModel, + PromptEncoderConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, @@ -3147,3 +3148,43 @@ def test_boft_half_conv(self): conv = boft.layer.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16) x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16) conv(x) # does not raise + + +@require_torch_gpu +class TestPTuningReproducibility: + device = infer_device() + + def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path): + # See: https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577 + # Ensure that after loading a p-tuning checkpoint, results are exactly reproducible (before the patch, they were + # only _almost_ identical). + + # The model must be sufficiently large for the effect to be measurable, which is why this test requires is not + # run on CPU. + model_id = "facebook/opt-125m" + inputs = torch.arange(10).view(-1, 1).to(self.device) + + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) + peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128) + model = get_peft_model(model, peft_config).eval() + + with torch.inference_mode(): + output_peft = model(inputs).logits + gen_peft = model.generate(inputs, min_new_tokens=10, max_new_tokens=10) + + model.save_pretrained(tmp_path) + del model + torch.cuda.empty_cache() + gc.collect() + + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) + model = PeftModel.from_pretrained(model, tmp_path) + + with torch.inference_mode(): + output_loaded = model(inputs).logits + gen_loaded = model.generate(inputs, min_new_tokens=10, max_new_tokens=10) + + torch.testing.assert_close(output_loaded, output_peft) + torch.testing.assert_close(gen_loaded, gen_peft) From 511213ec7ab3c46dc653cb356d25035484b68873 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 3 Sep 2024 11:41:06 +0200 Subject: [PATCH 2/2] Remove unnecessary seeding --- tests/test_gpu_examples.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index eaad624a60..cbaee5f5ec 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -3178,7 +3178,6 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path): torch.cuda.empty_cache() gc.collect() - torch.manual_seed(0) model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) model = PeftModel.from_pretrained(model, tmp_path)