Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354
Conversation
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    if version.parse(fairscale.__version__) >= version.parse("0.3"):
        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
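For context, a minimal sketch of how these fairscale pieces typically fit together, based on fairscale's documented usage; this is illustrative rather than the exact Trainer integration code, and it assumes the distributed process group is already initialized:

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# Assumes torch.distributed.init_process_group(...) has already been called by the launcher.
model = torch.nn.Linear(32, 2).cuda()

# OSS shards the optimizer state across ranks; extra kwargs are forwarded to the wrapped optimizer.
optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=3e-5)

# ShardedDDP takes the place of torch's DistributedDataParallel and is paired with the OSS optimizer.
model = ShardedDDP(model, optimizer)

# ShardedGradScaler replaces torch.cuda.amp.GradScaler for mixed-precision training with OSS.
scaler = ShardedGradScaler()
```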
I think this may introduce confusion here: should we stick to DP and not DDP to match the real name, i.e. `FullyShardedDP` and `ShardedDP`?
Perhaps change the original flag to reflect that as well, e.g. `--sharded_dp`?
OK, I made a request to rename those to match DDP here:
facebookresearch/fairscale#413 (comment)
Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the "device_ids" argument in the future, so that there isn't support for single-process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale.
Thank you for your follow-up, @min-xu-ai
src/transformers/training_args.py
Outdated
    Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
    training only). This is an experimental feature.

    Can take up to six values:
Suggested change:
- Can take up to six values:
+ Can be one of the following values:
- clarifying that the value is one of these
- the total count is of no use to the user
src/transformers/training_args.py
Outdated
    - :obj:`"no"`: for no sharded DataParallelism (default behavior)
    - :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`), similar
      to ZeRO-2.
    - :obj:`"zero_2"`: to use the second instance of sharded DDP released by fairscale (:obj:`FullyShardedDDP`)
We are smashing concepts a bit here: ZeRO is a big territory with many features, and the 3 stages belong to the ZeRO-DP part of ZeRO, so ideally this should be `zero_dp_(1|2|3)` or `zero_dp(1|2|3)`.
This is just a suggestion though; if you strongly feel having just the number is clear enough, that's OK too.
Oh, and that's why they call it DP and not DDP: because it's ZeRO-DP.
Do you feel it's better to hardcode these combinations, rather than a more flexible approach (e.g. passing a list of options)? The latter would enable adding new features in the future without needing to create all possible combinations of options, a set that would double every time a new option is added. This is the cmd API I'm tentatively using for the pipelines.
Yes, we will need to rethink this - the trainer is getting more and more complex.
Happy to explore that design as it seems more flexible and less prone to future breaking changes. Will adapt the PR accordingly once we get the wrapper to work.
Probably whitespace separation is more readable.

Also we need to make sure that we distinguish between the fairscale and DeepSpeed versions of these options, since Deepspeed has a similar one. For the user's sake perhaps we could make things as similar as possible, so it'd be more intuitive for them to switch between fairscale (and eventually pytorch) and deepspeed.

Also note that DeepSpeed exposes other params, like the size of buckets, which actually are very important and need to be user-configurable. I won't be surprised if FSDP also makes those configurable down the road, i.e. more params.
Reworked the API to take your suggestion of a list of options into account, @stas00. I don't think we have to worry about uniformizing with deepspeed or cleaning up more at this stage.
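A minimal sketch of what parsing a whitespace-separated option value could look like (a hypothetical helper, not the code from this PR; option names taken from the discussion above):

```python
# Hypothetical helper: parse a whitespace-separated value such as "zero_dp_3 offload".
KNOWN_OPTIONS = {"simple", "zero_dp_2", "zero_dp_3", "offload"}

def parse_sharded_ddp(value: str) -> set:
    options = set(value.split())
    unknown = options - KNOWN_OPTIONS
    if unknown:
        raise ValueError(f"Unknown --sharded_ddp options: {sorted(unknown)}")
    return options

options = parse_sharded_ddp("zero_dp_3 offload")
use_fully_sharded = "zero_dp_2" in options or "zero_dp_3" in options
cpu_offload = "offload" in options
```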
Awesome work, @sgugger !!!
-    sharded_ddp: ShardedDDPType = field(
-        default="no",
+    sharded_ddp: str = field(
+        default="",
perhaps list the choices here? and perhaps a very small example of combining 2 of them in the value, since it's not a usual pattern - a user might struggle here.
I agree with @stas00!
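A hedged sketch of what spelling out the choices (plus a combination example) in the field's help text could look like; the class name and wording are illustrative, not the text merged in the PR:

```python
from dataclasses import dataclass, field

@dataclass
class TrainingArgumentsSketch:
    # Illustrative help text listing the choices and showing how to combine two of them.
    sharded_ddp: str = field(
        default="",
        metadata={
            "help": (
                "Whether or not to use sharded DDP training (in distributed training only). "
                "Pass a whitespace-separated list of options among 'simple', 'zero_dp_2', "
                "'zero_dp_3' and 'offload', e.g. --sharded_ddp 'zero_dp_3 offload'."
            )
        },
    )
```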
Fantastic! Left only nitpicks.
3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
   to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
   --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
Love the API
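For reference, a sketch of the same combination expressed through `TrainingArguments` directly (option names from this PR's description; the model, dataset and other arguments are placeholders):

```python
from transformers import Trainer, TrainingArguments

# Equivalent of passing `--sharded_ddp "zero_dp_3 offload"` on the command line,
# with the script launched via `python -m torch.distributed.launch --nproc_per_node=N ...`.
args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=8,
    sharded_ddp="zero_dp_3 offload",
)

# `model` and `train_dataset` are assumed to be defined elsewhere.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```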
Co-authored-by: Stas Bekman <[email protected]>
Moving the cl arg naming discussion out of #10354 (review) into the open. So if it's not DDP but DP, then we should probably change the cl arg to `--sharded_dp`. Or perhaps we should just call it …
What does this PR do?
This PR adds support for the new `FullyShardedDataParallel` introduced in fairscale. See this PR for more details.

The PR changes the behavior of the `--sharded_ddp` flag/training argument a tiny bit to support a list of options. You can still use the `TrainingArguments` class with `sharded_ddp=True`, but if launching a script, `--sharded_ddp` has to be replaced with `--sharded_ddp simple`. `--sharded_ddp` was marked as an experimental API, so I think this breaking change is fine if properly documented.

Other values supported are: `zero_dp_2`, `zero_dp_2 offload`, `zero_dp_3` and `zero_dp_3 offload`. To fully take advantage of `zero_dp_3`/`zero_dp_3 offload`, the model passed to the `Trainer` will need to have its internal layers wrapped inside `FullyShardedDataParallel`, but this is out of scope for this particular PR.

For all those new modes, the model simply needs to be wrapped inside `FullyShardedDataParallel`, but the optimizer needs to be created after the model wrapping (to get the parameter shards).

Note that:

- `predict_with_generate` does not work with this integration
- `cpu_offload` does not work for now due to the bug mentioned in this issue. Once the issue is fixed, the option should work with the existing code.

One thing to think about further is that this integration breaks the usual convention that `self.model` is the original model (`FullyShardedDataParallel` consumes the model to use less memory).
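To make the wrap-then-create-optimizer ordering concrete, here is a minimal sketch; the model, hyperparameters and launcher setup are illustrative assumptions, not the actual Trainer code:

```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

# Assumes the distributed process group has already been initialized by the
# launcher (e.g. `python -m torch.distributed.launch --nproc_per_node=N ...`).
model = torch.nn.Linear(32, 2).cuda()

# 1. Wrap the model first. FSDP consumes the wrapped module and shards its parameters,
#    which is also why `self.model` is no longer the original model after this point.
model = FullyShardedDDP(model)

# 2. Only now create the optimizer, so it is built on the sharded/flattened parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
```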