Add Upcasting for FSDP in Mixed Precision. Add Concept Guide for FSDP and DeepSpeed. #2674
Conversation
Force-pushed from 93c9e76 to 2bf8b9c.
Thank you @fabianlim for the concept guides 📄 on switching between FSDP and DS, and for adding logic to upcast the loaded model to FP32 when using AMP, given that a model loaded in lower precision needs this for proper convergence 📉! Left a few comments and suggestions.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is very good so far! A recommendation:
Let's please also show doing the same with `FullyShardedDataParallelPlugin` and `DeepSpeedPlugin` directly, rather than just with the config file, as many users don't know this can be done.
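For instance, a minimal sketch of configuring the plugins in code rather than via `accelerate config` could look like the following (the field values here are illustrative assumptions, not the PR's final doc example):

```python
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator, DeepSpeedPlugin, FullyShardedDataParallelPlugin

# FSDP configured directly in code instead of via a yaml config file
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # roughly the FSDP analogue of ZeRO-3
)
accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)

# DeepSpeed configured directly in code (use one plugin or the other, not both)
# ds_plugin = DeepSpeedPlugin(zero_stage=3)
# accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)
```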
Co-authored-by: Zach Mueller <[email protected]>
Force-pushed from ffc2299 to 0b8e97c.
@muellerzr thanks for your comments, I have addressed most of them except for the one above, which requires a bit more effort. Will work on it. In the meantime, if you have any suggestions on my responses please let me know.
Is there a way to disable that behavior? i.e. not upcast the parameters even when both the conditions are met? And to confirm, the parameters would not be upcasted when using just bf16 training (i.e. no mixed precision training)?
@shagunsodhani before this PR, if the model is in low precision, then regardless of whether there is mixed precision or not, we will not upcast. This PR is proposing that when the model is in low precision and mixed precision is enabled, we upcast the sharded parameters.
The argument is that there is no need to support the "do not upcast" case when the model is in low precision and mixed precision == true. Firstly, it contradicts the docs. Secondly, if the model is already in low precision, then you should perform forward, reduce, all in that low precision, which is equivalent to turning mixed precision off.
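In pseudo-logic (not the PR's actual code, just a restatement of the rule described above), the upcast decision is:

```python
import torch
from typing import Optional

def should_upcast(param_dtype: torch.dtype, mixed_precision: Optional[str]) -> bool:
    """Upcast sharded params to fp32 only when the model was loaded in low
    precision AND mixed precision (AMP) is enabled; otherwise leave them alone."""
    loaded_in_low_precision = param_dtype in (torch.float16, torch.bfloat16)
    amp_enabled = mixed_precision in ("fp16", "bf16")
    return loaded_in_low_precision and amp_enabled
```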
Thank you for the context. Let me try rephrasing the question: if I want to train a model using bf16 and FSDP, this is what I do:
Following #2624, it seems DeepSpeed upcasts the parameters in the optimizer, and this PR is applying that behavior to FSDP as well. So if a user wants to opt out of this behaviour, they could do the following:
Is that correct?
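(The code snippets referenced in this exchange were not preserved. A hedged reconstruction of the two setups being discussed could look like the sketch below, with the model name and exact arguments as placeholders.)

```python
import torch
from transformers import AutoModelForCausalLM
from accelerate import Accelerator

# Setup 1: model loaded in bf16, trained with bf16 mixed precision enabled.
# Per the discussion above, after this PR the FSDP sharded params get upcast to fp32.
model = AutoModelForCausalLM.from_pretrained("some-org/some-model", torch_dtype=torch.bfloat16)
accelerator = Accelerator(mixed_precision="bf16")
model = accelerator.prepare(model)

# Setup 2 (the "opt-out" being asked about): pure bf16 training with no mixed
# precision; params stay in bf16 both before and after this PR.
model = AutoModelForCausalLM.from_pretrained("some-org/some-model", torch_dtype=torch.bfloat16)
accelerator = Accelerator()  # mixed_precision not set
model = accelerator.prepare(model)
```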
@shagunsodhani it depends what you mean by opt out. If you do the above, then before and after this PR you would actually have:
This PR does not affect the scenario above; you would have exactly the same behavior as above. This PR only changes the following scenario, which is neither of the two that you mentioned.
Therefore this PR also does not affect the second scenario you presented,
in which case you would have the below:
Great - thank you for explaining this :)
Force-pushed from 2e88c7b to 681c697.
@muellerzr I addressed your last comment on the Plugins.
I think it's almost there, we have covered most bases, just left the FP8 item that @stas00 suggested. I see the FSDP with fp8 PR is merged: https://github.com/huggingface/accelerate/pull/2655/files. So it seems from this table, we only upcast weights to 32 when all the following conditions are met?:
Accelerate will still always wrap those outputs with the
This looks phenomenal! Great work 🤗
I left a few wording nits, but looks great to me! cc @pacman100 as well for a last review
Co-authored-by: Zach Mueller <[email protected]>
Thank you, @fabianlim, for fixing the mixed precision training with FSDP when the model is not loaded in full-precision and for improving the documentation by outlining the parallels between FSDP and DeepSpeed. Super helpful!
Left a few nits
@muellerzr I'm trying to test fp8 with FSDP, to understand the implications of the upcasting logic on FP8. I am:
But this is what I notice inside:
I have tried to do some minor tweaks, like uncommenting. I constructed the following:

# inside transformers.Trainer.create_accelerator_and_postprocess
from accelerate.utils.dataclasses import FP8RecipeKwargs, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
fsdp_plugin = FullyShardedDataParallelPlugin(
mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16)
)
# these are passed to Accelerator; it seems that `FP8RecipeKwargs` are no longer created in `Accelerator.__init__`
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
'fsdp_plugin': fsdp_plugin,
'kwargs_handlers': [
FP8RecipeKwargs(backend='te')
]
}

So it is probably the case that my understanding of FP8 is not correct, and if you are aware of a simple demo to run FP8 (Transformer Engine) with FSDP can you point me to it? I'm on nightly.
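For reference, the non-FSDP FP8 path looked roughly like the sketch below (a hedged sketch; the argument values follow the snippet above rather than a verified demo):

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# TransformerEngine-backed fp8 mixed precision, without an FSDP plugin
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(backend="te")],
)
```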
Co-authored-by: Sourab Mangrulkar <[email protected]>
@fabianlim the answer is we just currently haven't enabled FP8 mixed precision for FSDP; it's something we're looking into.
So for right now, happy to merge this in when you're ready, just let me know 🤗
@muellerzr got it.. ok, let's merge this first and then address the fp8 later!
Thanks a bunch for all of this fantastic work @fabianlim, once CI is green we'll merge and this will get in before the next release 🚀 (end of this week)
Awesome work, @fabianlim! Some belated feedback:
While both help to overcome the big model loading issue, they aren't the same functionality, since
perhaps it'd be better to say "one could use ...". It's trivial to get this unsharded checkpoint during the training with the zero_to_fp32.py script, without slowing down the training. At least this is the approach that we use.
For example, perhaps the doc should say "when needed", so as not to erroneously imply that all tensors are always pre-fetched?
## On Differences in Data Precision Handling
To discuss the how data precision is handled in both FSDP and Deepspeed, it is instructive to first give a flow of how model parameters are handled in these frameworks. Before the model / optimizer parameters are distributed across GPUs, parameter preparation is involved to first "flatten" them to one-dimensional [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch-tensor)'s. The implementation of FSDP / DeepSpeed varies in the respect of the `dtype` in which these "flattened" parameters are stored, and there are ramifications with regards to how [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim)'s allocate their `dtypes`'s. The table below outlines the processes for both frameworks; the "Local" column indicates the process occurring at a per-gpu level, therefore any memory overheads by upcasting should be understood to be amortized by the number of gpus used.
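As a concrete reference for the FSDP side of the paragraph above, the per-wrapped-module precision policy is expressed in torch roughly as follows (the dtype values are just an example):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# dtype of the flattened params used in forward/backward, the dtype used for
# gradient reduction, and the dtype of buffers; the flat-param storage dtype
# itself is what the guide's table contrasts between FSDP and DeepSpeed.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```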
- I don't know what "give a flow" means - did you mean "give an overview"?
- there should be no `'s` but `s` in `#torch-tensor)'s.` above - same with `[torch.Optimizer](https://pytorch.org/docs/stable/optim.html#module-torch.optim)'s` - same with `dtype's`
Process | Local | Framework | Details
--|--|--|--
Loading, i.e., [`AutoModel.from_pretrained(..., torch_dtype)`]
Probably would be more clear as `AutoModel.from_pretrained(..., torch_dtype=torch_dtype)`; it's ambiguous when it's passed as a positional argument, since I don't think it is even supported.
@stas00 thank you for your great suggestions! Let me see how I can quickly address them. Is it ok if I open another PR, @muellerzr, to address @stas00's comments?
Yep feel free!
@stas00 thank you again for your above comments. I have attempted to address them in #2725. In particular:
For Discussion (copying @muellerzr, @pacman100):
Hi @fabianlim, can you summarise what happens when the model is created in bf16 and mixed_precision='bf16', before and after this PR? (I found this format of explanation to be very nice: #2674 (comment))
param.data = param.data.to(torch.float32)  # upcasting
module._handle._orig_param_dtype = torch.float32  # update
@fabianlim, it looks like this introduced a breakage in FSDP integration - one of our tests now fails because of the last 2 lines, but works fine prior to this PR being merged.
output = self.model(**batch)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 843, in forward
args, kwargs = _pre_forward(
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
unshard_fn(state, handle)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
_unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 288, in _unshard
ran_pre_unshard = handle.pre_unshard()
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1252, in pre_unshard
ret = self._writeback_orig_params()
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2249, in _writeback_orig_params
self._writeback_tensor(
File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2348, in _writeback_tensor
raise RuntimeError(
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([48000]) but got torch.Size([3000, 16])
the total size (numel) is correct, but the shape is wrong
also cc: @muellerzr
the training used just a trivial tiny llama2 model, like: https://huggingface.co/stas/tiny-random-llama-2
and the config was:
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_forward_prefetch: false # can flip to true once requiring torch>=2.2.0
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_cpu_ram_efficient_loading: true
fsdp_sync_module_states: true
fsdp_use_orig_params: true
@stas00 can you share the accelerate config? This looks like something commonly seen if NO_SHARD is used.
full config:
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_forward_prefetch: false # can flip to true once requiring torch>=2.2.0
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_cpu_ram_efficient_loading: true
fsdp_sync_module_states: true
fsdp_use_orig_params: true
mixed_precision: bf16
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: {num_processes}
use_cpu: false
Ok I think it's because of `fsdp_use_orig_params` being set to true. This was a setting that I didn't test.
Is there a reason to do the upcasting so late in the game? Won't it be much simpler to do it before the model gets wrapped by FSDP?
Accelerate knows by this time everything it needs to know to do the right thing, no?
The reason is because we only wanted to cast the sharded params and not the whole model. Casting the whole model is not a good idea. DeepSpeed does the same thing and only casts the sharded params
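A very rough sketch of that approach (mirroring the two quoted lines earlier; `_handle` and `_orig_param_dtype` are private, version-dependent FSDP internals, so this is an illustration only, not the PR's exact code):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def upcast_sharded_params(model):
    # Walk only the FSDP-wrapped submodules and upcast their flat params,
    # instead of casting the whole (unwrapped) model up front.
    for module in FSDP.fsdp_modules(model):
        handle = getattr(module, "_handle", None)
        if handle is None or handle.flat_param is None:
            continue
        if handle.flat_param.dtype in (torch.float16, torch.bfloat16):
            handle.flat_param.data = handle.flat_param.data.to(torch.float32)  # upcasting
            handle._orig_param_dtype = torch.float32  # update
```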
Didn't we say that this was a convenience workaround and that the model should have been loaded in fp32 in the first place for FSDP to converge well?
Frankly, when do you not shard most of the model parameters? And even if sharding is selective, typically it's the big params that are sharded, so 2x of that would be close to all of the model's params doubled in size.
@stas00
Essentially what I am saying is that upcasting ahead will in general incur a 2X overhead in CPU memory; in most general usage patterns that overhead will be there regardless of whether we use low_cpu_mem mode or not.
BTW I'm having a bit of trouble reproducing your error on torch 2.1.2 and 2.2.0.
- I took the `learning_rate_repro.py` script in the main issue
- I ran the following command, where `dataset.json` is prepared following this, and `accelerate_stas.yaml` is the config that was shared above
- using the same `stas/tiny-random-llama-2`
accelerate launch \
--num_processes 4 \
--main_process_port $MASTER_PORT \
--config_file accelerate_stas.yaml \
learning_rate_repro.py \
--model_name stas/tiny-random-llama-2 \
--max_seq_len 4096 \
--num_train_epochs 1 \
--output_dir './results' \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 1 \
--include_tokens_per_second \
--gradient_checkpointing \
--lr_scheduler_type "linear" \
--learning_rate 1e-6 \
--logging_steps 1 \
--packing True \
--dataset_for_packing dataset.json \
--dataset_text_field output \
--bf16 \
--max_steps 100 \
--save_strategy 'no'
and the training runs.
{'loss': 8.0075, 'grad_norm': 0.9171388149261475, 'learning_rate': 9.9e-06, 'epoch': 1.0}
{'loss': 7.9684, 'grad_norm': 1.1230570077896118, 'learning_rate': 9.800000000000001e-06, 'epoch': 2.0}
{'loss': 7.8988, 'grad_norm': 1.232378602027893, 'learning_rate': 9.7e-06, 'epoch': 3.0}
{'loss': 7.8119, 'grad_norm': 1.1301443576812744, 'learning_rate': 9.600000000000001e-06, 'epoch': 4.0}
{'loss': 7.7319, 'grad_norm': 1.0580438375473022, 'learning_rate': 9.5e-06, 'epoch': 5.0}
What does this PR do?
This PR addresses the issues identified in #2624, whereby differences were found between how DeepSpeed (DS) and FSDP handle the sharded parameters during mixed precision. To address this we have:
- added logic in `accelerate.prepare` to upcast the FSDP sharded parameters, if it has been detected that the model was loaded in low precision while mixed precision is enabled.
In addition to the above, we also:
- added a concept guide outlining the parallels between FSDP and DeepSpeed.
Checklist:
- test `SHARD_GRAD_OP`
- consider the impacts of fp8 training suggested by @stas00. Update: this has to be done later.

To test and reproduce:
- take the script (`learning_rate_repro.py`) from here.
- run with the `accelerate_fspd.yaml` config and without the `--bf16` flag; this gets the `bf16` case in the below plot. The script will load the model in `bfloat16` (since `load_model_dtype` was default) and turn off mixed precision.
- run with `--bf16` (this gets `bf16-with-mp` in the plot below). This runs the `bfloat16` model in FSDP with mixed precision.
- run with the `accelerate_ds.yaml` config (here `--bf16` is irrelevant, DS will handle both cases similarly).

In the plot below, we can see that with the casting logic put in:
- when the model is loaded in `bfloat16` but mixed precision is turned on (`--bf16`), then we perform the upcasting. Hence we see that the loss curves of DS and FSDP will now be the same with the added casting logic.
- DS behaves the same regardless of whether `--bf16` had been supplied or not.

Sample of New Warnings
We have added some logic to reduce repetitive warnings:
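One simple way to implement that kind of de-duplication (a generic sketch, not necessarily the exact mechanism used in the PR):

```python
import warnings

_seen_warnings = set()

def warn_once(message: str) -> None:
    # Emit each distinct warning message only once per process,
    # instead of repeating it for every parameter or module.
    if message not in _seen_warnings:
        _seen_warnings.add(message)
        warnings.warn(message)
```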
Performance of low precision models under FSDP and DS and mixed precision, with the proposed casting fix
This was plotted when comparing FSDP (Full Shard) and DeepSpeed (Zero3)
This was plotted when comparing FSDP in various modes, namely GradOp and CPUOffload (while full-sharding)
Who can review?
Tagging @stas00, @muellerzr and @pacman100 first. We can add more reviewers later.