diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8d56d82d77d972..1a7c0a9a453632 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4751,6 +4751,7 @@ def create_accelerator_and_postprocess(self):
 
         args = {
             "deepspeed_plugin": self.args.deepspeed_plugin,
+            "fsdp_plugin": self.args.fsdp_plugin,
             "gradient_accumulation_plugin": gradient_accumulation_plugin,
         }
         if is_accelerate_available("0.28.0"):
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index a270754a26abe5..f098ffa3d30519 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1906,6 +1906,7 @@ def __post_init__(self):
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
 
         # accelerate integration for FSDP
+        self.fsdp_plugin = None
         if len(self.fsdp) > 0 and is_accelerate_available("0.28.0"):
             os.environ["ACCELERATE_USE_FSDP"] = "true"
             from accelerate.utils.constants import (
@@ -1948,6 +1949,13 @@ def __post_init__(self):
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
 
+            fsdp_plugin_kwargs = {}
+            if self.fsdp_config.get("device_mesh", None) and is_accelerate_available("0.34.0"):
+                fsdp_plugin_kwargs["device_mesh"] = self.fsdp_config["device_mesh"]
+
+            from accelerate.utils import FullyShardedDataParallelPlugin
+            self.fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_kwargs)
+
         if self.tpu_metrics_debug:
             warnings.warn(
                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
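
A minimal usage sketch (not part of the diff) of the new code path: a `device_mesh` entry in `fsdp_config` should now land on a prebuilt `FullyShardedDataParallelPlugin`, which `create_accelerator_and_postprocess()` passes to `Accelerator` under the `"fsdp_plugin"` key. The `output_dir`, the `fsdp` string, and the `[2, 4]` mesh shape below are illustrative assumptions; the accepted `device_mesh` format is whatever `accelerate>=0.34.0` expects, and the snippet is meant to run under a distributed launcher such as `torchrun` or `accelerate launch`.

```python
# Hypothetical sketch, not part of the PR: exercise the new `device_mesh` key.
# Assumes accelerate >= 0.34.0 and a distributed launch (torchrun / accelerate launch);
# the [2, 4] shape is only an illustration of a possible mesh layout.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                     # illustrative
    fsdp="full_shard auto_wrap",          # non-empty `fsdp` enables the FSDP branch in __post_init__
    fsdp_config={"device_mesh": [2, 4]},  # forwarded as a kwarg to FullyShardedDataParallelPlugin
)

# __post_init__ now prebuilds the plugin; create_accelerator_and_postprocess()
# later hands it to Accelerator instead of letting accelerate build one from env vars.
print(type(args.fsdp_plugin))  # expected: FullyShardedDataParallelPlugin
```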