Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Small numerical discrepancy for p-tuning after loading the model #2047

Conversation

BenjaminBossan
Copy link
Member

There is a small numerical discrepancy between the outputs of a p-tuning model before and after loading. Even though it is small, it can still affect generations, so this PR eliminates it.

As an example, without the fix, this is the difference in logits for opt-125m:

>       torch.testing.assert_close(output_loaded, output_peft)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 30 / 10557120 (0.0%)
E       Greatest absolute difference: 1.1086463928222656e-05 at index (0, 9, 9314) (up to 1e-05 allowed)
E       Greatest relative difference: 0.00021288332936819643 at index (0, 9, 9314) (up to 1.3e-06 allowed)

Details about how this comes about are explained here:

#2043 (comment)

The gist of it is that if we take a single sample, repeat it X times, and then forward it through a model (which is the training path in p-tuning), we would expect the same output as if we forwarded this sample only once and repeated the output X times (the inference path for p-tuning). However, for sufficiently large models, the two approaches can have tiny differences.

With the fixed approach, there is no difference between training and inference code paths when it comes to this. The new code should also be slightly more compute efficient, but in practice will not make a noticeable difference.

There is a small numerical discrepancy between the outputs of a p-tuning
model before and after loading. Even though it is small, it can still
affect generations, so this PR eliminates it.

As an example, without the fix, this is the difference in logits for
opt-125m:

>       torch.testing.assert_close(output_loaded, output_peft)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 30 / 10557120 (0.0%)
E       Greatest absolute difference: 1.1086463928222656e-05 at index (0, 9, 9314) (up to 1e-05 allowed)
E       Greatest relative difference: 0.00021288332936819643 at index (0, 9, 9314) (up to 1.3e-06 allowed)

Details about how this comes about are explained here:

huggingface#2043 (comment)

The gist of it is that if we take a single sample, repeat it X times,
and then forward it through a model (which is the training path in
p-tuning), we would expect the same output as if we forwarded this
sample only once and repeated the output X times (the inference path for
p-tuning). However, for sufficiently large models, the two approaches
can have tiny differences.

With the fixed approach, there is no difference between training and
inference code paths when it comes to this. The new code should also be
slightly more compute efficient, but in practice will not make a
noticeable difference.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

torch.cuda.empty_cache()
gc.collect()

torch.manual_seed(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we're loading from a pre-trained model, we need to set a seed?

model_id = "facebook/opt-125m"
inputs = torch.arange(10).view(-1, 1).to(self.device)

torch.manual_seed(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we're loading from a pre-trained model, we need to set a seed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more of a routine I have to ensure reproducibility in tests. Since we call get_peft_model, which initializes a random model, it makes sense to set a seed here. The seed below is indeed not required, as there is only model loading, so I removed it.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a single comment but nothing blocking.

@BenjaminBossan BenjaminBossan merged commit 95b3964 into huggingface:main Sep 3, 2024
14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-small-numerical-discrepancy-p-tuning branch September 3, 2024 14:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants