Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Support FSDP2 #3231

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

[RFC] Support FSDP2 #3231

wants to merge 2 commits into from

Conversation

kmehant
Copy link

@kmehant kmehant commented Nov 8, 2024

What does this PR do?

Prototype implementation for porting from FSDP V1 to FSDP V2. There are couple of open questions in this PR that would need comments and discussion.

  1. Do we want to maintain FSDP V1 as is and add a experimental parallel to FSDP V2?
  2. When we want to maintain 2 versions, should we maintain separate FSDP plugins and distributed types for each versions?
  3. For HF/transformers users, using fsdp_config, how we want to allow them to choose between these versions?
  4. How we want prepare 2D mesh for HSDP, should that be an input from user?

Preliminary run of this PR and results

The current version of the PR has been tested for basic functionality (full shard) and compared with previous FSDP V1 implementation.

Key Value
Model Maykeye/TinyLLama-v0
Mesh size 2 GPUs
sharding full shard

Memory

Screenshot 2024-11-09 at 12 50 10 AM

Loss Parity

Screenshot 2024-11-09 at 12 59 56 AM

Throughput

TODO

Fixes #2873

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr

Signed-off-by: Mehant Kammakomati <[email protected]>
@raghukiran1224
Copy link

@ByronHsu FYI - thoughts?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SumanthRH
Copy link
Contributor

@kmehant thanks for starting this PR! I was looking at FSDP2 support in accelerate and landed here!

Do we want to maintain FSDP V1 as is and add a experimental parallel to FSDP V2?

Yes please! It looks like FSDP2 will be in a public API in the next torch release (2.6): pytorch/pytorch@d815efc , so maybe things are somewhat stable ? But many of the older config parameters (like auto_wrap_policy) are simply not there in V2 so I'd prefer if accelerate users get time to migrate.

When we want to maintain 2 versions, should we maintain separate FSDP plugins and distributed types for each versions?

Hmm if the new API had supported most of the V1 configurations, I would think having only a feature flag would be enough -i.e something like ACCELERATE_FSDP2_ENABLED. But looking at the API differences: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md#fsdp1--fsdp2-api-differences
I think just having a feature flag + a different plugin FullyShardedDataParallelPluginV2 is better.

For HF/transformers users, using fsdp_config, how we want to allow them to choose between these versions?

It looks like accelerate doesn't do much validation for the config parameters in fsdp_config until plugin initialization. So in the accelerate config, specifying enable_fsdp2 should be fine , and users would be expected to list only v2 parameters in fsdp_config - this is validated when FullyShardedDataParallelPluginV2 is initialized.

@SumanthRH
Copy link
Contributor

cc @muellerzr curious to know if this already in the pipeline internally from HF!

# auto_wrap_policy is not yet supported by FSDP2
# therefore manual wrapping has to be done like below
#######
for layer in model.model.layers:
Copy link

@kyleliang919 kyleliang919 Jan 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one doesn't seem to apply to general use case.
Feels like it should be something like below that checks and apply fully_shard from bottom up.

stack = [model]
ordered_modules = []
while stack:
    current_modules = stack.pop()
    for _, attr in current_module.__dict__.items():
        if isinstance(attr, torch.nn.Module):
            stack.append(attr)
    ordered_modules.append(current_module)

for each in ordered_modules[::-1]:
    fully_shard(each, **fsdp2_kwargs)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Plan to support FSDP2?
5 participants