Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354

Merged 6 commits into master on Feb 25, 2021
Conversation

@sgugger (Collaborator) commented Feb 23, 2021

What does this PR do?

This PR adds support for the new FullyShardedDataParallel introduced in fairscale. See this PR for more details.

The PR slightly changes the behavior of the --sharded_ddp flag/training argument to support a list of options. You can still use the TrainingArguments class with sharded_ddp=True, but if launching a script, --sharded_ddp has to be replaced with --sharded_ddp simple. The --sharded_ddp flag was marked as an experimental API, so I think this breaking change is fine as long as it is properly documented.

Other supported values are: zero_dp_2, zero_dp_2 offload, zero_dp_3 and zero_dp_3 offload. To fully take advantage of zero_dp_3/zero_dp_3 offload, the model passed to the Trainer will need to have its internal layers wrapped inside FullyShardedDataParallel, but this is out of scope for this particular PR.
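For illustration, a minimal sketch of passing one of these values through TrainingArguments (the option strings are the ones listed above; the output directory is a hypothetical placeholder):

from transformers import TrainingArguments

# Space-separated list of sharding options: ZeRO-DP stage 3 plus CPU offload.
training_args = TrainingArguments(
    output_dir="output",  # hypothetical output directory
    sharded_ddp="zero_dp_3 offload",
)

On the command line, this corresponds to --sharded_ddp "zero_dp_3 offload".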

For all those new modes, the model simply needs to be wrapped inside FullyShardedDataParallel, but the optimizer needs to be created after the model is wrapped (so that it sees the parameter shards).
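A minimal sketch of that ordering constraint (not the Trainer's actual code; it assumes a distributed process group is already initialized, a GPU is available, and fairscale >= 0.3 is installed):

import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

# Toy stand-in for the real model passed to the Trainer.
model = nn.Linear(32, 32).cuda()

# Wrap first: FSDP flattens and shards the parameters across ranks.
model = FullyShardedDDP(model)

# Create the optimizer only after wrapping, so it sees the parameter shards.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)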

Note that:

  • predict_with_generate does not work with this integration
  • cpu_offload does not work for now due to the bug mentioned in this issue. Once the issue is fixed, the option should work with the existing code.

One thing to think about further is that this integration breaks the usual convention that self.model is the original model (FullyShardedDataParallel consumes the model to use less memory).

from packaging import version

import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# FullyShardedDataParallel was only added in fairscale 0.3
if version.parse(fairscale.__version__) >= version.parse("0.3"):
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
Contributor
I think this may introduce some confusion here. Should we stick to DP and not DDP to match the real names, i.e. FullyShardedDP and ShardedDP?

Perhaps change the original flag to reflect that as well? --sharded_dp?

@stas00 (Contributor) Feb 23, 2021

OK, I made a request to rename those to match DDP here:
facebookresearch/fairscale#413 (comment)


Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the "device_ids" argument in the future, so that there isn't support for single-process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale.

Contributor

Thank you for your follow up, @min-xu-ai

Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.

Can take up to six values:
Contributor

Suggested change:
- Can take up to six values:
+ Can be one of the following values:

  • clarifying that it's one of them
  • the total count is of no use to the user

- :obj:`"no"`: for no sharded DataParallelism (default behavior)
- :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`), similar to ZeRO-2.
- :obj:`"zero_2"`: to use the second instance of sharded DDP released by fairscale (:obj:`FullyShardedDDP`)
Contributor

We are smashing concepts a bit here. ZeRO is a big territory with many features, and the 3 stages belong to the ZeRO-DP part of ZeRO, so ideally this should be zero_dp_(1|2|3) or zero_dp(1|2|3).

This is just a suggestion though; if you feel strongly that having just the number is clear enough, that's OK too.

@stas00 (Contributor) Feb 23, 2021

Oh, and that's why they call it DP and not DDP, because it's ZeRO-DP.

@stas00 (Contributor) commented Feb 23, 2021

Other values supported are: zero2, zero2_offload, zero3 and zero3_offload. To fully take advantage of the zero3/zero3_offload the model passed to the Trainer will need to have its internal layers wrapped inside the FullyShardedDataParallel, but this out of scope for this particular PR.

Do you feel it's better to hardcode these combinations and not have a more flexible approach of:

--sharded_ddp "zero2;offload;future_option"

or

--sharded_ddp "zero2 offload future_option"

which would enable adding new features in the future without needing to create all possible combinations of options, which would otherwise double every time a new option is added.

This is the cmd API I'm tentatively using for the pipelines --pipeline "chunks=5 device_map=0:0-5,1:5-10 ...."

One thing to think further is that this integration breaks the usual convention that self.model is the original model (FullyShardedDataParallel consumes the model to use less memory).

Yes, we will need to rethink this - the trainer is getting more and more complex.

@sgugger (Collaborator, Author) commented Feb 23, 2021

Do you feel it's better to hardcode these combinations and not have a more flexible approach of:

--sharded_ddp "zero2;offload;future_option"

Happy to explore that design as it seems more flexible and less prone to future breaking changes. Will adapt the PR accordingly once we get the wrapper to work.

@stas00 (Contributor) commented Feb 23, 2021

Probably whitespace separation is more readable: --sharded_ddp "zero2 offload future_option"
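A minimal sketch of how such a whitespace-separated value could be split and checked (hypothetical helper; the option names are taken from this discussion, not from the PR's actual code):

from typing import List

KNOWN_OPTIONS = {"simple", "zero_dp_2", "zero_dp_3", "offload"}

def parse_sharded_ddp(value: str) -> List[str]:
    """Split a --sharded_ddp value such as "zero_dp_3 offload" into options."""
    options = value.split()
    unknown = [opt for opt in options if opt not in KNOWN_OPTIONS]
    if unknown:
        raise ValueError(f"Unknown --sharded_ddp option(s): {unknown}")
    return options

parse_sharded_ddp("zero_dp_3 offload")  # -> ["zero_dp_3", "offload"]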

Also, we need to make sure that we distinguish between FullyShardedDataParallel and ShardedDataParallel since, as noted in the review comments, they aren't quite the same. Perhaps not_full for ShardedDataParallel? Both should correspond to stage 2, but they don't work in the same way.

DeepSpeed has a stage param which goes from 0 to 3, where stage=0 doesn't enable ZeRO and each other number matches the corresponding ZeRO stage.

For the user's sake perhaps we could make things as similar as possible so it'd be more intuitive for them to switch between fairscale (and eventually pytorch) and deepspeed.

Also note that DeepSpeed exposes other params, like the sizes of buckets, which are actually very important and need to be user-configurable. I won't be surprised if FSDP also makes those configurable down the road, i.e. more params.
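For reference, a sketch of the DeepSpeed knobs mentioned here, written as the Python equivalent of its JSON config (the numbers are illustrative, not recommendations):

deepspeed_zero_config = {
    "zero_optimization": {
        "stage": 2,                    # 0 disables ZeRO; 1-3 select the ZeRO-DP stage
        "allgather_bucket_size": 5e8,  # bucket sizes are user-configurable
        "reduce_bucket_size": 5e8,
    }
}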

@sgugger (Collaborator, Author) commented Feb 25, 2021

Reworked the API to take your suggestion of a list of options into account, @stas00. I don't think we have to worry about aligning with deepspeed or cleaning up more at this stage as:

  • this API will evolve in the future (ShardedDataParallel might very well disappear if FullyShardedDataParallel is better, and this might change again on the road to be merged in PyTorch)
  • we don't know yet all the options we will have between deepspeed/fairscale/PyTorch
  • this is an experimental API and while we won't break it just for fun, we can make slight changes down the road.

@sgugger changed the title from "[WIP] Add support for ZeRO-2/3 and ZeRO-offload in fairscale" to "Add support for ZeRO-2/3 and ZeRO-offload in fairscale" on Feb 25, 2021
@stas00 (Contributor) left a comment

Awesome work, @sgugger !!!

src/transformers/trainer.py (outdated; resolved)
-    sharded_ddp: ShardedDDPType = field(
-        default="no",
+    sharded_ddp: str = field(
+        default="",
Contributor

Perhaps list the choices here? And perhaps add a very small example of combining two of them in the value, since it's not a usual pattern; a user might struggle here.
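A rough sketch of what that could look like in training_args.py (the help text is a hypothetical illustration, not a suggested final wording):

from dataclasses import dataclass, field

@dataclass
class TrainingArgumentsSketch:
    # Hypothetical help text only; exact wording is up to the PR author.
    sharded_ddp: str = field(
        default="",
        metadata={
            "help": (
                "Whether or not to use sharded DDP training (in distributed training only). "
                "The value is a space-separated list of options, e.g. `zero_dp_2 offload`; "
                "base options are `simple`, `zero_dp_2` and `zero_dp_3`, and `offload` adds ZeRO-offload."
            )
        },
    )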

Member

I agree with @stas00!

src/transformers/training_args.py (outdated; resolved)
@LysandreJik (Member) left a comment

Fantastic! Left only nitpicks.

Comment on lines +285 to +287
3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
Member

Love the API

docs/source/main_classes/trainer.rst (resolved)

@sgugger sgugger merged commit 9d14be5 into master Feb 25, 2021
@sgugger sgugger deleted the zero_3_offload branch February 25, 2021 16:07
@stas00 (Contributor) commented Feb 26, 2021

Moving the cl arg naming discussion out of #10354 (review) into the open.

So if it's not DDP but DP, then we should probably change the cl arg to _dp as I suggested above, so that it's consistently either DP or DDP all the way through.

Or perhaps we should just call it --sharded? The dp part is already inside the value anyway, as in: --sharded zero_dp_3.
