Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Support FSDP2 #3231

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 103 additions & 5 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,13 +1464,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy, FSDPModule, OffloadPolicy
from torch.distributed.fsdp.api import ShardingStrategy
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
# is a FSDP model, don't wrap it again
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
# We check for FSDPModule instead of FSDP class for FSDP v2
is_type_fsdp = isinstance(model, FSDPModule) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
)

if not is_type_fsdp:
Expand Down Expand Up @@ -1498,8 +1500,100 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": self.device,
}
#######
# fsdp2_kwargs holds all the args supported by
# FSDP2 through fully_shard API
# Most of FSDP2 args can be deduced from the existing FSDP1 args
# Some of the existing FSDP1 args not supported or by default set to True
# information can be found here - https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
#######
fsdp2_kwargs = {
"reshard_after_forward": True,
"mesh": None,
"mp_policy": MixedPrecisionPolicy(),
"offload_policy": OffloadPolicy()
# shard_placement_fn has been a feature quite recently
# "shard_placement_fn": None
}
model = FSDP(model, **kwargs)

#######
# Preparation of mesh and reshard_after_forward
# Both of these params may be exposed directly to user to be passed through FSDP config
# However, otherway could be to hide them and set them based on sharding strategy

# Deduction of the mesh and reshard_after_forward from sharding strategy analogy
# borrowed from https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
# 1 process_group + FULL_SHARD ==> 1D mesh + reshard_after_forward=True
# 1 process_group + SHARD_GRAD_OP ==> 1D mesh + reshard_after_forward=False
# 2 process_groups/2D device_mesh + HYBRID_SHARD ==> 2D mesh + reshard_after_forward=True
# 2 process_groups/2D device_mesh + _HYBRID_SHARD_ZERO2 ==> 2D mesh + reshard_after_forward=False
#######

if kwargs["sharding_strategy"] == ShardingStrategy.FULL_SHARD:
# mesh
# no need to prepare mesh and go with default

# reshard_after_forward = True
fsdp2_kwargs["reshard_after_forward"]=True
elif kwargs["sharding_strategy"] == ShardingStrategy.SHARD_GRAD_OP:
# mesh
# no need to prepare mesh and go with default

# # reshard_after_forward = False
fsdp2_kwargs["reshard_after_forward"]=False
elif kwargs["sharding_strategy"] == ShardingStrategy.HYBRID_SHARD:
# mesh
# at this point, pytorch does not set 2 d mesh by default based on inter and intra node assumption
# https://github.com/pytorch/pytorch/issues/140102
# reshard_after_forward = True
fsdp2_kwargs["reshard_after_forward"]=True
elif kwargs["sharding_strategy"] == ShardingStrategy._HYBRID_SHARD_ZERO2:
# mesh
# at this point, pytorch does not set 2 d mesh by default based on inter and intro node assumption
# https://github.com/pytorch/pytorch/issues/140102
# reshard_after_forward = False
fsdp2_kwargs["reshard_after_forward"]=False

#######
# mixed precision policy can be mapped from FSDP1 to FSDP2 arg classes
# except for output_dtype new to FSDP2 and has to come from user
#######

if kwargs["mixed_precision"] is not None:
# MixedPrecisionPolicy is from the new _composable design
fsdp2_kwargs["mp_policy"] = MixedPrecisionPolicy(
param_dtype=kwargs["mixed_precision"].param_dtype,
reduce_dtype=kwargs["mixed_precision"].reduce_dtype,
cast_forward_inputs=kwargs["mixed_precision"].cast_forward_inputs
# output_dtype cannot be deduced from FSDP1 args and has to come from user
# buffer_dtype is not available, is it not required for FSDP2?
)

#######
# offload policy can be mapped from FSDP1 to FSDP2 arg classes
# pinning memory seems to be a new feature to FSDP2
# offloading params is the default behaviour
#######

if kwargs["cpu_offload"] is not None and kwargs["cpu_offload"].offload_params:
# CPUOffloadPolicy is from the new _composable design
fsdp2_kwargs["mp_policy"] = CPUOffloadPolicy(
# pin_memory= cannot be deduced from FSDP1 args and has to come from user
# offloads params is the default behaviour
)

#######
# auto_wrap_policy is not yet supported by FSDP2
# therefore manual wrapping has to be done like below
#######
for layer in model.model.layers:
Copy link

@kyleliang919 kyleliang919 Jan 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one doesn't seem to apply to general use case.
Feels like it should be something like below that checks and apply fully_shard from bottom up.

stack = [model]
ordered_modules = []
while stack:
    current_modules = stack.pop()
    for _, attr in current_module.__dict__.items():
        if isinstance(attr, torch.nn.Module):
            stack.append(attr)
    ordered_modules.append(current_module)

for each in ordered_modules[::-1]:
    fully_shard(each, **fsdp2_kwargs)

fully_shard(layer, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)

#######
# does existing activation_checkpointing API work out of the box with FSDP2?
#######
if fsdp_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
Expand Down Expand Up @@ -2364,7 +2458,11 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
parameters = [p for p in parameters]
for model in self._models:
if parameters == [p for p in model.parameters()]:
return model.clip_grad_norm_(max_norm, norm_type)
#######
# gradient clipping function is not part of the FSDP class object like in FSDP v1
# rather is removed
#######
return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)
elif self.distributed_type == DistributedType.DEEPSPEED:
# `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
# We cannot return the gradient norm because DeepSpeed does it.
Expand Down
Loading