
support fsdp plugin created in training args
winglian committed Aug 22, 2024
1 parent 975b988 commit 746154d
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/trainer.py
@@ -4751,6 +4751,7 @@ def create_accelerator_and_postprocess(self):

         args = {
             "deepspeed_plugin": self.args.deepspeed_plugin,
+            "fsdp_plugin": self.args.fsdp_plugin,
             "gradient_accumulation_plugin": gradient_accumulation_plugin,
         }
         if is_accelerate_available("0.28.0"):
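For context, the `args` dict above is later unpacked into the `Accelerator` constructor inside `create_accelerator_and_postprocess`. A minimal standalone sketch of the same hand-off (not the Trainer internals), assuming Accelerate is installed and the script is started with a distributed launcher such as `accelerate launch` or `torchrun`:

    # Sketch only: an FSDP plugin built ahead of time is handed straight to Accelerator,
    # which is what routing self.args.fsdp_plugin through the `args` dict achieves.
    import torch.nn as nn
    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin()  # default settings; TrainingArguments fills it from fsdp_config
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    model = accelerator.prepare(nn.Linear(8, 8))  # FSDP wrapping happens during prepare() on a multi-GPU launch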
8 changes: 8 additions & 0 deletions src/transformers/training_args.py
@@ -1906,6 +1906,7 @@ def __post_init__(self):
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

         # accelerate integration for FSDP
+        self.fsdp_plugin = None
         if len(self.fsdp) > 0 and is_accelerate_available("0.28.0"):
             os.environ["ACCELERATE_USE_FSDP"] = "true"
             from accelerate.utils.constants import (
@@ -1948,6 +1949,13 @@

             os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()

+            fsdp_plugin_kwargs = {}
+            if self.fsdp_config.get("device_mesh", None) and is_accelerate_available("0.34.0"):
+                fsdp_plugin_kwargs["device_mesh"] = self.fsdp_config["device_mesh"]
+
+            from accelerate.utils import FullyShardedDataParallelPlugin
+            self.fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_kwargs)
+
         if self.tpu_metrics_debug:
             warnings.warn(
                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
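The net effect is that `TrainingArguments.__post_init__` now owns the plugin object: `self.fsdp_plugin` stays `None` unless FSDP is requested, and when `fsdp_config` supplies a `device_mesh` and Accelerate >= 0.34.0 is installed, the mesh is forwarded to the plugin. A hedged usage sketch, with illustrative `fsdp_config` keys:

    from transformers import TrainingArguments

    # Requesting FSDP makes __post_init__ build a FullyShardedDataParallelPlugin
    # and store it on the args object, instead of relying on environment variables alone.
    args = TrainingArguments(
        output_dir="out",
        fsdp="full_shard auto_wrap",
        fsdp_config={"use_orig_params": "true"},  # illustrative; a "device_mesh" entry would also be forwarded
    )
    print(args.fsdp_plugin)  # a FullyShardedDataParallelPlugin here; None when fsdp is not set

The Trainer then picks this plugin up through the trainer.py change above and passes it to `Accelerator`.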
