float8 delayed scaling: private API to fix user overriding buffers #1292

Open. Wants to merge 1 commit into base: main.
Conversation

@vkuzo (Contributor) commented Nov 15, 2024

Summary:

Context: pytorch/torchtitan#654

If the user has delayed scaling and FSDP float8 all-gather enabled, a subtle bug can occur when the user calls
model.to_empty(device="cuda"):

  1. to_empty recreates the registered buffers that track the weight amax and scale
  2. step (1) leaves the tensors still referenced by Float8Linear.weight._amax_buffer, etc., orphaned, because those references do not participate in to_empty
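
A minimal sketch of the orphaning mechanism, using plain torch.nn rather than torchao; TinyModule and extra_ref below are illustrative stand-ins, not torchao code:

import torch
import torch.nn as nn

class TinyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("amax", torch.zeros(1))
        # keep a second Python reference to the same tensor, analogous to
        # Float8Linear.weight._amax_buffer (illustrative only)
        self.extra_ref = self.amax

m = TinyModule()
assert m.extra_ref is m.amax       # both names alias one tensor

# to_empty replaces every registered buffer with a freshly allocated tensor
m.to_empty(device="cpu")
assert m.extra_ref is not m.amax   # extra_ref now points at the orphaned old tensor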

I couldn't think of an easy and clean way to auto-fix this, since we can't expect torch.nn.Module to know that our logic holds multiple references to the same buffer, so this PR exposes a private API for now, until we think of something better.

With the current fix, the user can call _maybe_fixup_delayed_scaling_buffers manually after to_empty to relink these references to the newly created buffers.
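
A possible calling pattern, extrapolated from the single-layer test in this PR; the model name, the loop over all Float8Linear modules, and the import path are assumptions, not part of this change:

from torchao.float8.float8_linear import Float8Linear  # import path assumed

# model: an nn.Module whose linear layers were converted to Float8Linear,
# with delayed scaling and FSDP float8 all-gather enabled (placeholder name)
model.to_empty(device="cuda")
for mod in model.modules():
    if isinstance(mod, Float8Linear):
        # relink weight._amax_buffer / weight._scale_buffer to the recreated buffers
        mod._maybe_fixup_delayed_scaling_buffers()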

Test Plan: CI



pytorch-bot bot commented Nov 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1292

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit faa1593 with merge base 56bf2e8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Nov 15, 2024
@vkuzo added the topic: not user facing label on Nov 15, 2024
@vkuzo requested a review from weifengpy on November 15, 2024 06:00
# buffers alias the same tensor before to_empty
assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer

m_fp8.to_empty(device="cuda")
# relink the orphaned references to the newly created buffers
m_fp8[0]._maybe_fixup_delayed_scaling_buffers()
A contributor commented on the lines above:
we would need to call this inside torchtitan’s training loop?
