Commit

automatically set sync_module_states if low_cpu_mem is set
fabianlim committed Apr 30, 2024
1 parent f8aba3c commit 19cfab4
Showing 3 changed files with 9 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/source/concept_guides/fsdp_and_deepspeed.md
@@ -108,7 +108,8 @@ While FSDP require an explicit `--fsdp_cpu_ram_efficient_loading true` to activa

<Tip>

-For FSDP, whenever setting `--fsdp_cpu_ram_efficient_loading true`, please also set `--fsdp_sync_module_states true`, otherwise the model will not load properly.
+For FSDP, whenever setting `--fsdp_cpu_ram_efficient_loading true`, 🤗 `accelerate` will automatically set `sync_module_states` to true.
+For RAM efficient loading the weights will be loaded only in a single rank, and thus requires `sync_module_states` to broadcast weights to other ranks.

</Tip>

7 changes: 7 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -1143,6 +1143,13 @@ def __post_init__(self):
         self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
         self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

+        if str_to_bool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 and not self.sync_module_states:
+            warnings.warn(
+                "sync_module_states cannot be False since efficient cpu ram loading enabled. "
+                "Setting sync_module_states to True."
+            )
+            self.sync_module_states = True
+
         if self.sync_module_states:
             if is_npu_available():
                 device = torch.npu.current_device()
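
The override added in this hunk can be exercised in isolation. The following is a minimal sketch, not the library's code: `resolve_sync_module_states` is a hypothetical helper written for illustration, and `str_to_bool` here is a simplified stand-in for the `accelerate` utility of the same name. Only the environment variable name and the warning text are taken from the diff.

```python
import os
import warnings


def str_to_bool(value: str) -> int:
    """Simplified stand-in for accelerate's helper: map truthy/falsy strings to 1/0."""
    value = value.lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if value in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {value!r}")


def resolve_sync_module_states(sync_module_states: bool, env=None) -> bool:
    """Hypothetical helper mirroring the check added to __post_init__.

    If RAM-efficient loading is requested through the environment, the
    returned value is forced to True (with a warning), because the weights
    are loaded on a single rank and must be broadcast to the others.
    """
    env = os.environ if env is None else env
    ram_efficient = str_to_bool(env.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    if ram_efficient and not sync_module_states:
        warnings.warn(
            "sync_module_states cannot be False since efficient cpu ram loading enabled. "
            "Setting sync_module_states to True."
        )
        return True
    # Otherwise the caller's choice is left unchanged.
    return sync_module_states
```

Passing a plain dict as `env` keeps the sketch independent of the real process environment; in the actual plugin the value is read from `os.environ` and written to `self.sync_module_states`.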
2 changes: 0 additions & 2 deletions src/accelerate/utils/launch.py
@@ -225,8 +225,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:

     if args.use_fsdp:
         current_env["ACCELERATE_USE_FSDP"] = "true"
-        if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
-            raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")

         current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
         current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
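
The effect of removing this check at launch time can be sketched as a before/after comparison. This is an illustrative sketch only: `fsdp_env_old`, `fsdp_env_new`, and `_export` are hypothetical names, and the `FSDP_SYNC_MODULE_STATES` key is an assumption about how the launcher exports the flag (only `FSDP_CPU_RAM_EFFICIENT_LOADING` appears in this commit).

```python
import argparse
from typing import Dict


def _export(args: argparse.Namespace) -> Dict[str, str]:
    """Export the two flags as lowercase strings (key names partly assumed)."""
    return {
        "FSDP_CPU_RAM_EFFICIENT_LOADING": str(args.fsdp_cpu_ram_efficient_loading).lower(),
        "FSDP_SYNC_MODULE_STATES": str(args.fsdp_sync_module_states).lower(),
    }


def fsdp_env_old(args: argparse.Namespace) -> Dict[str, str]:
    """Behaviour before the commit: reject the flag combination at launch."""
    if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
        raise ValueError(
            "When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`"
        )
    return _export(args)


def fsdp_env_new(args: argparse.Namespace) -> Dict[str, str]:
    """Behaviour after the commit: pass the flags through unchanged.

    The plugin's __post_init__ is then responsible for forcing
    sync_module_states to True when RAM-efficient loading is enabled.
    """
    return _export(args)
```

With `fsdp_cpu_ram_efficient_loading=True` and `fsdp_sync_module_states=False`, the old path raises `ValueError` while the new path returns the environment as given, deferring the correction to the plugin.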
