fsdp fixes and enhancements (#24980)
* fix fsdp prepare to remove the warnings and fix excess memory usage

* Update training_args.py

* parity for FSDP+XLA

* Update trainer.py
pacman100 authored Jul 21, 2023
1 parent ec3dfe5 commit f4eb459
Showing 3 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/main_classes/trainer.md
@@ -441,7 +441,7 @@ as the model saving with FSDP activated is only available with recent fixes.
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either the location of an
  FSDP JSON config file (e.g., `fsdp_config.json`) or an already loaded JSON file as a `dict`.
- If auto wrapping is enabled, you can either use a transformer-based auto wrap policy or a size-based auto wrap policy.
  - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
  - For the transformer-based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available (see the config sketch below).
    This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ....
    This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP-wrapped units.
    Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
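As a rough illustration of the options described above, a transformer-based auto-wrap configuration passed as an already loaded `dict` might look like the sketch below. All concrete values (`output_dir`, `BertLayer`, the prefetch setting) are placeholders, not part of this commit:

```python
from transformers import TrainingArguments

# Illustrative values only; `fsdp_transformer_layer_cls_to_wrap` may be omitted,
# in which case the Trainer falls back to `model._no_split_modules` when available.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],  # case-sensitive class names
    "fsdp_backward_prefetch": "backward_pre",
}

training_args = TrainingArguments(
    output_dir="output",
    fsdp="full_shard auto_wrap",
    fsdp_config=fsdp_config,
)
```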
@@ -482,7 +482,7 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co
This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
`fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
- You can either use a transformer-based auto wrap policy or a size-based auto wrap policy.
  - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
  - For the transformer-based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available (see the XLA config sketch below).
    This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ....
    This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP-wrapped units.
    Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
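A correspondingly minimal XLA variant, again with placeholder values (`T5Block` is only an example class name; the `xla` flag is the one this section describes):

```python
# Illustrative XLA FSDP config dict; pass it via `fsdp_config` the same way as above.
fsdp_config = {
    "xla": True,
    "fsdp_transformer_layer_cls_to_wrap": ["T5Block"],  # placeholder layer class
}
```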
13 changes: 11 additions & 2 deletions src/transformers/trainer.py
@@ -1377,18 +1377,24 @@ def _wrap_model(self, model, training=True, dataloader=None):
    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
    "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["fsdp_min_num_params"] > 0:
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
    )
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
elif fsdp_transformer_layer_cls_to_wrap is not None:
    transformer_cls_to_wrap = set()
    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
    for layer_class in fsdp_transformer_layer_cls_to_wrap:
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
            transformer_cls_to_wrap.add(transformer_cls)

    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        # Transformer layer class to wrap
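For readers unfamiliar with the wrap-policy machinery, here is a standalone sketch of what the truncated `functools.partial(...)` call above typically amounts to. It uses the standard `torch.distributed.fsdp.wrap` helper; the Trainer's XLA code path uses the equivalent helper from `torch_xla`, and the exact call in the commit may differ:

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_auto_wrap_policy(transformer_cls_to_wrap):
    """Illustrative only: wrap every submodule that is an instance of one of the
    resolved transformer layer classes (e.g. {BertLayer}) in its own FSDP unit."""
    return functools.partial(
        transformer_auto_wrap_policy,
        # Layer classes resolved from the config or from `model._no_split_modules`,
        # as in the diff above.
        transformer_layer_cls=transformer_cls_to_wrap,
    )
```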
@@ -1600,6 +1606,7 @@ def _inner_training_loop(
    and self.sharded_ddp != ShardedDDPOption.SIMPLE
    or is_sagemaker_mp_enabled()
    or self.fsdp is not None
    or self.is_fsdp_enabled
)

# We need to reset the scheduler, as its parameters may be different on subsequent calls
@@ -1631,6 +1638,8 @@
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
    if use_accelerator_prepare:
        self.model = self.accelerator.prepare(self.model)
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# prepare using `accelerator` prepare
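The hunks above make the Trainer treat the accelerate-backed FSDP path (`self.is_fsdp_enabled`) like the other sharded setups: optimizer creation is delayed and the model is prepared first. A rough sketch of why that ordering matters, using plain `accelerate` calls and placeholder names (not the Trainer's actual code):

```python
import torch
from accelerate import Accelerator


def prepare_for_fsdp(model, train_dataloader, lr=5e-5):
    """Illustrative ordering only; names and values are placeholders."""
    accelerator = Accelerator()
    # Wrap the model first so the optimizer is built over the sharded (flattened)
    # FSDP parameters rather than the full unsharded ones; creating the optimizer
    # before wrapping is what leads to the excess memory usage the commit message mentions.
    model = accelerator.prepare(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
    return model, optimizer, train_dataloader
```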
2 changes: 1 addition & 1 deletion src/transformers/training_args.py
@@ -1567,14 +1567,14 @@ def __post_init__(self):
    elif fsdp_option == FSDPOption.OFFLOAD:
        os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
    elif fsdp_option == FSDPOption.AUTO_WRAP:
        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
        if self.fsdp_config["fsdp_min_num_params"] > 0:
            os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
            os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
        elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
            os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
                self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
            )
            os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
