
Set FSDP transformer_layer_cls_to_wrap to model._no_split_modules? #24568

Closed
apoorvkh opened this issue Jun 29, 2023 · 9 comments · Fixed by huggingface/accelerate#1753 or #24980

apoorvkh commented Jun 29, 2023

Feature request

Currently, when training with FSDP, the Trainer expects to receive an fsdp_config argument specifying fsdp_transformer_layer_cls_to_wrap.

```python
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
    transformer_cls_to_wrap = set()
    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
            transformer_cls_to_wrap.add(transformer_cls)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        # Transformer layer class to wrap
        transformer_layer_cls=transformer_cls_to_wrap,
    )
```

I am wondering if we can set this automatically when the model has a _no_split_modules attribute, e.g.

```python
_no_split_modules = ["OPTDecoderLayer"]
```
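
For illustration, a minimal sketch of the fallback I have in mind (get_module_class_from_name is re-implemented here in simplified form, and build_auto_wrap_policy is a hypothetical helper, not the actual Trainer code):

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def get_module_class_from_name(model, name):
    # Simplified stand-in for the Trainer helper: resolve a class name
    # (e.g. "OPTDecoderLayer") to the class object by scanning submodules.
    for module in model.modules():
        if module.__class__.__name__ == name:
            return module.__class__
    return None


def build_auto_wrap_policy(model, fsdp_config):
    # Hypothetical fallback: prefer an explicit fsdp_transformer_layer_cls_to_wrap,
    # otherwise use the model's _no_split_modules attribute if it exists.
    layer_classes = fsdp_config.get("fsdp_transformer_layer_cls_to_wrap") or getattr(
        model, "_no_split_modules", None
    )
    if not layer_classes:
        return None

    transformer_cls_to_wrap = set()
    for layer_class in layer_classes:
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        transformer_cls_to_wrap.add(transformer_cls)

    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_cls_to_wrap,
    )
```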

Motivation

It would be convenient to have this set automatically. The argument is model-specific, so being able to define training arguments independently of a particular model type would be nice.

Your contribution

Happy to help make a PR. Would be great if you can confirm whether this would be desirable or if I am misunderstanding something. Thanks!

sgugger commented Jun 29, 2023

cc @pacman100

@apoorvkh (Contributor, Author)

Any thoughts about this? Maybe also cc @stas00?

stas00 commented Jul 20, 2023

Unfortunately I don't have experience with FSDP to contribute to this discussion.

sgugger commented Jul 20, 2023

@pacman100 Friendly ping

pacman100 commented Jul 21, 2023

Hello @apoorvkh, the code you highlighted is now only used for FSDP+XLA. For general FSDP, everything is handled internally by Accelerate. It happens here:

```python
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
    os.environ["ACCELERATE_USE_FSDP"] = "true"
    from accelerate.utils.constants import (
        FSDP_AUTO_WRAP_POLICY,
        FSDP_SHARDING_STRATEGY,
    )
    for fsdp_option in self.fsdp:
        if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
            # set environment variable for FSDP sharding strategy
            os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
        elif fsdp_option == FSDPOption.OFFLOAD:
            os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
        elif fsdp_option == FSDPOption.AUTO_WRAP:
            if self.fsdp_config["fsdp_min_num_params"] > 0:
                os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
                os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
            elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
                    self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
                )
                os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
    prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
    os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
```

fsdp_transformer_layer_cls_to_wrap supports specifying multiple modules, but most of the time it is enough to specify the _no_split_modules. So we can use _no_split_modules as the default when the user doesn't specify it while passing --fsdp full_shard auto_wrap.
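
As a rough usage sketch (output_dir and the commented-out config are placeholders, not verbatim project code), the fallback would let a user rely on full_shard auto_wrap alone, without the model-specific class name:

```python
from transformers import TrainingArguments

# Sketch only: with the fallback in place, fsdp_transformer_layer_cls_to_wrap can be
# omitted and the model's _no_split_modules (e.g. ["OPTDecoderLayer"] for OPT) is used.
training_args = TrainingArguments(
    output_dir="output",          # placeholder
    fsdp="full_shard auto_wrap",  # equivalent to --fsdp full_shard auto_wrap on the CLI
    # fsdp_config={"fsdp_transformer_layer_cls_to_wrap": ["OPTDecoderLayer"]},  # no longer required
)
```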

@pacman100 (Contributor)

PRs huggingface/accelerate#1753 and #24980 should add this capability: they fall back to model._no_split_modules if fsdp_transformer_layer_cls_to_wrap isn't specified. Can you try it out?

@apoorvkh (Contributor, Author)

Very cool, thanks a ton! I will try it out and let you know.

@apoorvkh (Contributor, Author)

Just circling back, works on my end -- thanks again!

xhluca commented Nov 28, 2023

@pacman100 I want to better understand the mechanism of FSDP's wrapping.

Do you know why transformer_layer_cls_to_wrap can be automatically assigned to _no_split_modules by default?

My understanding of the latter comes from this post:

> Actually using this device map later on won't work, because the layers composing this model have residual connections (where the input of the block is added to the output of the block) so all of a given layer should be on the same device. We can indicate this to Accelerate by passing a list of module names that shouldn't be split with the no_split_module_classes keyword argument:

I understand this means that the module should not be split during the forward pass. However, I am not sure I see the connection with transformer_layer_cls_to_wrap, which seems to be a way to indicate which class should be wrapped by FSDP (this is based on my limited understanding of FSDP).

Is there a connection between those two variables, or is it simply a convenient way to find the name of the transformer layers (since they typically follow a {model_name}DecoderLayer naming convention, but not always consistently)?
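
To make my mental model concrete, here is a standalone sketch of what I understand transformer_layer_cls_to_wrap to control in plain PyTorch FSDP (DecoderLayer and TinyModel are made-up stand-ins, not Transformers classes): every instance of the listed class becomes its own FSDP unit, whose parameter shards are gathered and freed as one block. My guess is that a layer whose internals must stay together (the _no_split_modules case) is usually also a sensible wrapping boundary, which would explain reusing it as the default, but I'd appreciate confirmation.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class DecoderLayer(nn.Module):  # made-up stand-in for e.g. OPTDecoderLayer
    def __init__(self, dim=512):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out  # residual connection: the whole layer belongs together
        return x + self.mlp(x)


class TinyModel(nn.Module):
    def __init__(self, dim=512, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(DecoderLayer(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Each DecoderLayer instance becomes its own FSDP unit: its shards are all-gathered
# for that layer's forward/backward and then freed, instead of treating the whole
# model as one flat unit. (Requires an initialized torch.distributed process group.)
policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={DecoderLayer})
model = FSDP(TinyModel(), auto_wrap_policy=policy)
```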
