
Add Upcasting for FSDP in Mixed Precision. Add Concept Guide for FSDP and DeepSpeed. #2674

Merged
merged 15 commits into huggingface:main on Apr 29, 2024

Conversation

fabianlim
Contributor

@fabianlim fabianlim commented Apr 16, 2024

What does this PR do?

This PR addresses the issues identified in #2624, where differences were found between how DeepSpeed (DS) and FSDP handle the sharded parameters during mixed precision. To address this we have:

  • added logic in accelerate.prepare to upcast the FSDP sharded parameters (see the sketch below) if it is detected that:
    1. mixed precision has been activated, and
    2. the sharded parameters are in low precision.
  • added a concept guide to clarify how DS and FSDP handle mixed precision, and to note that 🤗 Accelerate will upcast FSDP sharded weights.
  • added a warning to inform the user that, if the weights were upcast, this will affect the precision in which the checkpoint is saved.

In addition to the above, we also:

  • in the above-mentioned concept guide, clarified the equivalences between DS and FSDP configs, enabling users to better transition between DS/FSDP.
  • updated the outdated accelerate launch CLI commands.
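
For illustration, a minimal sketch of the upcast step described in the first bullet above (this is not the merged implementation; it leans on private FSDP attributes only to mirror what the PR does to the already-sharded flat parameters, and mixed_precision_enabled is a stand-in flag):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def upcast_fsdp_params_if_needed(model: FSDP, mixed_precision_enabled: bool) -> None:
    # Only act when mixed precision is on AND the sharded params are in low precision.
    if not mixed_precision_enabled:
        return  # model keeps its loaded dtype; fwd/bwd/reduce all happen in that dtype
    for fsdp_module in FSDP.fsdp_modules(model):
        handle = fsdp_module._handle  # private attribute, used here only for illustration
        if handle is None:
            continue
        flat_param = handle.flat_param
        if flat_param.dtype in (torch.float16, torch.bfloat16):
            flat_param.data = flat_param.data.to(torch.float32)  # upcast the sharded weights
            handle._orig_param_dtype = torch.float32             # keep FSDP bookkeeping consistent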

Checklist:

  • test with low mem
  • test with SHARD_GRAD_OP
  • test with CPU offload
  • consider the impacts of fp8 training suggested by @stas00. Update: this has to be done later

To test and reproduce:

  • get the FSDP/DS accelerate configurations and the reproduction script (call it learning_rate_repro.py) from here.
  • run the script with the accelerate_fsdp.yaml config and without the --bf16 flag; this gets the bf16 case in the plot below. The script will load the model in bfloat16 (since load_model_dtype was default) and turn off mixed precision.
    accelerate launch \
        --num_processes 4 \
        --config_file accelerate_fsdp.yaml \
        learning_rate_repro.py  \
          --num_train_epochs 10 \
          --output_dir './results' \
          --per_device_train_batch_size 10 \
          --lr_scheduler_type "linear" \
          --learning_rate 1e-6 \
          --logging_steps 1 \
          --save_strategy 'no' 
  • run the same command adding --bf16 (this gets the bf16-with-mp case in the plot below); see the full command after this list. This runs the bfloat16 model in FSDP with mixed precision.
  • run the same benchmark with a DS accelerate_ds.yaml config (here --bf16 is irrelevant; DS will handle both cases similarly).
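
For concreteness, the bf16-with-mp run is the same command as above with --bf16 appended (all other flags unchanged):

    accelerate launch \
        --num_processes 4 \
        --config_file accelerate_fsdp.yaml \
        learning_rate_repro.py  \
          --num_train_epochs 10 \
          --output_dir './results' \
          --per_device_train_batch_size 10 \
          --lr_scheduler_type "linear" \
          --learning_rate 1e-6 \
          --logging_steps 1 \
          --bf16 \
          --save_strategy 'no'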

In the plot below, we can see that with the casting logic put in:

  • if we detect that the model is in low precision bfloat16 but mixed precision is turned on (--bf16), then we perform the upcasting. Hence the loss curves of DS and FSDP are now the same with the added casting logic.
  • if the logic were absent, we would observe that FSDP does not converge, regardless of whether --bf16 had been supplied or not.

Sample of New Warnings

We have added some logic to reduce repetitive warnings:

/data/flim/accelerate/src/accelerate/accelerator.py:1492: UserWarning: Upcasting low precision parameters in MistralForCausalLM, mixed precision is turned on: model.embed_tokens.weight,model.norm.weight,lm_head.weight.
  warnings.warn(
/data/flim/accelerate/src/accelerate/accelerator.py:1492: UserWarning: Upcasting low precision parameters in MistralDecoderLayer, mixed precision is turned on: self_attn.q_proj.weight,self_attn.k_proj.weight,self_attn.v_proj.weight,self_attn.o_proj.weight,mlp.gate_proj.weight,mlp.up_proj.weight,mlp.down_proj.weight,input_layernorm.weight,post_attention_layernorm.weight.
  warnings.warn(

Performance of low-precision models under FSDP and DS with mixed precision, with the proposed casting fix

This was plotted when comparing FSDP (Full Shard) and DeepSpeed (Zero3)
image

This was plotted when comparing FSDP in various modes, namely GradOp and CPUOffload (while full-sharding)
image

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Tagging @stas00, @muellerzr and @pacman100 first. We can add more reviewers later.

Contributor

@pacman100 pacman100 left a comment

Thank you @fabianlim for the concept guides 📄 for switching between FSDP and DS, and for adding logic to upcast the loaded model to FP32 when using AMP with a model loaded in lower precision, for proper convergence 📉! Left a few comments and suggestions.

docs/source/concept_guides/fsdp_and_deepspeed.md (outdated review comment, resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@muellerzr muellerzr left a comment

This is very good so far! A recommendation:

Let's please also show doing the same with FullyShardedDataParallelPlugin and DeepSpeedPlugin directly, rather than just with the config file, as many users don't know this can be done
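
For reference, a minimal sketch of what that might look like (assuming the current top-level accelerate exports; this is not the exact snippet that ended up in the guide):

from accelerate import Accelerator, DeepSpeedPlugin, FullyShardedDataParallelPlugin

# Option A (FSDP): roughly the programmatic equivalent of the fsdp_config block in the YAML
fsdp_plugin = FullyShardedDataParallelPlugin()  # customize fields as needed
accelerator = Accelerator(fsdp_plugin=fsdp_plugin, mixed_precision="bf16")

# Option B (DeepSpeed): roughly the programmatic equivalent of the deepspeed_config block
# ds_plugin = DeepSpeedPlugin(zero_stage=3)
# accelerator = Accelerator(deepspeed_plugin=ds_plugin, mixed_precision="bf16")

Only one of the two plugins would be used in a given run.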

docs/source/concept_guides/fsdp_and_deepspeed.md: 6 review comments (outdated, resolved)
src/accelerate/accelerator.py: 1 review comment (outdated, resolved)
@fabianlim
Contributor Author

Let's please also show doing the same with FullyShardedDataParallelPlugin and DeepSpeedPlugin directly, rather than just with the config file, as many users don't know this can be done

@muellerzr thanks for your comments, I have addressed most of them except for the one above, which requires a bit more effort. Will work on it. In the meantime, if you have any suggestions on my responses please let me know.

@shagunsodhani

added logic in accelerate.prepare to upcast the FSDP sharded parameters, if it has been detected that:

  • mixed precision has been activated, and,
  • the sharded parameters are in low precision.

Is there a way to disable that behavior, i.e. not upcast the parameters even when both conditions are met? And to confirm, the parameters would not be upcast when using just bf16 training (i.e. no mixed precision training)?

@fabianlim
Contributor Author

fabianlim commented Apr 20, 2024

@shagunsodhani before this PR, if the model is in low precision, we do not upcast, regardless of whether mixed precision is on or not.

This PR is proposing that when the model is in low precision:

  • mixed_precision == False -> no upcast
  • mixed_precision == True -> upcast (to gain parity with DS)

The argument is that there is no need to support a "do not upcast" option for the case where the model is in low precision and mixed_precision == True. Firstly, it contradicts the docs. Secondly, if the model is already in low precision, then you should perform forward, reduce, etc. all in that low precision, which is equivalent to turning mixed precision off.

@shagunsodhani

@shagunsodhani before this PR, if the model is in low precision, we do not upcast, regardless of whether mixed precision is on or not.

This PR is proposing that when the model is in low precision:

  • mixed_precision == False -> no upcast
  • mixed_precision == True -> upcast (to gain parity with DS)

Thank you for the context. Let me try rephrasing the question:

If I want to train a model using bf16 and FSDP, this is what I do:

  1. Create an Accelerator object with mixed_precision=bf16 (with relevant FSDP flags)
  2. Create model in fp32
  3. Pass the model to accelerate.prepare
  4. Train...

Following #2624, it seems DeepSpeed upcasts the parameters in the optimizer, and this PR is applying that behavior to FSDP as well.

So if a user wants to opt-out of this behaviour, they could do the following:

  1. Create an Accelerator object with mixed_precision=no (with relevant FSDP flags)
  2. Create model in bf16
  3. Pass the model to accelerate.prepare
  4. Train...

Is that correct?

@fabianlim
Contributor Author

fabianlim commented Apr 20, 2024

If I want to train a model using bf16 and FSDP, this is what I do:

Create an Accelerator object with mixed_precision=bf16 (with relevant FSDP flags)
Create model in fp32
Pass the model to accelerate.prepare
Train...

@shagunsodhani it depends on what you mean by opt out. If you do the above, then both before and after this PR you would actually have:

  • flat params in fp32
  • optimizer params in fp32
  • fwd_bwd, reduce, etc. in bf16

This PR does not affect the scenario above; you would have exactly the same behavior as before.

This PR only changes the following scenario, which is neither of the two that you mentioned:

  • create model in bf16, mixed precision = bf16

Therefore this PR also does not affect the second scenario you presented:

So if a user wants to opt out of this behaviour, they could do the following:
Create an Accelerator object with mixed_precision=no (with relevant FSDP flags)
Create model in bf16
Pass the model to accelerate.prepare
Train...

in which case you would have the below (see also the sketch after this list):

  • flat params in bf16
  • optimizer in bf16
  • fwd_bwd, reduce in bf16
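
To make the two setups concrete, a minimal sketch (the model name is a placeholder and the FSDP settings are assumed to come from the accelerate config; this is not code from the PR):

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Setup discussed above: fp32 model + mixed precision
#   -> flat params and optimizer state in fp32, fwd/bwd and reduce in bf16 (unchanged by this PR)
accelerator = Accelerator(mixed_precision="bf16")
model = AutoModelForCausalLM.from_pretrained("some-org/some-model", torch_dtype=torch.float32)
model = accelerator.prepare(model)

# "Opt-out" setup: bf16 model + no mixed precision
#   -> flat params, optimizer state, fwd/bwd and reduce all in bf16 (also unchanged by this PR)
# accelerator = Accelerator(mixed_precision="no")
# model = AutoModelForCausalLM.from_pretrained("some-org/some-model", torch_dtype=torch.bfloat16)
# model = accelerator.prepare(model)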

@shagunsodhani

Great - thank you for explaining this :)

@fabianlim
Contributor Author

fabianlim commented Apr 23, 2024

@muellerzr I addressed your last comment on the Plugins:

  • put two small notes, one for FSDP and one for DS, pointing to the equivalence between FullyShardedDataParallelPlugin / DeepSpeedPlugin and the config flags. 681c697
  • tested cpu_offload and grad_op, updated the results in the main description. See the two new curves just above "Before Submitting"
  • put one more warning that the automatic upcast may increase the precision of the saved checkpoints.

I think it's almost there; we have covered most bases, with just the FP8 item that @stas00 suggested left. I see the FSDP-with-fp8 PR is merged: https://github.com/huggingface/accelerate/pull/2655/files

So it seems from this table that we only upcast weights to fp32 when all the following conditions are met?:

  • mixed-precision == 'fp8'
  • TE enabled
  • MS-AMP disabled

@muellerzr
Collaborator

Accelerate will still always wrap those outputs with ConvertOutputsToFp32; the chart itself was taken from MS-AMP.

Collaborator

@muellerzr muellerzr left a comment

This looks phenomenal! Great work 🤗

I left a few wording nits, but looks great to me! cc @pacman100 as well for a last review

docs/source/concept_guides/fsdp_and_deepspeed.md: 4 review comments (outdated, resolved)
Contributor

@pacman100 pacman100 left a comment

Thank you, @fabianlim, for fixing the mixed precision training with FSDP when the model is not loaded in full-precision and for improving the documentation by outlining the parallels between FSDP and DeepSpeed. Super helpful!

Left a few nits

src/accelerate/accelerator.py: 1 review comment (outdated, resolved)
docs/source/concept_guides/fsdp_and_deepspeed.md: 1 review comment (outdated, resolved)
@fabianlim
Contributor Author

fabianlim commented Apr 29, 2024

@muellerzr I'm trying to test fp8 with FSDP, to understand the implications of the upcasting logic on FP8. I'm testing on accelerate/main, after the merge of https://github.com/huggingface/accelerate/pull/2655/files, but I'm having some trouble.

I am:

  • using transformer_engine
  • setting ACCELERATE_MIXED_PRECISION = fp8

But this is what I notice inside AcceleratorState

I have tried some minor tweaks, like uncommenting set_mixed_precision and setting MixedPrecision directly in FullyShardedDataParallelPlugin, but I get a KeyError in the prepare_model function when it tries to FSDP-wrap the model with the replaced te.Linear layers (this is because new_named_params will have keys like 'model.layers.21._fsdp_wrapped_module.mlp.down_proj.weight').

I constructed Accelerator args as follows:

# inside transformers.Trainer.create_accelerator_and_postprocess
from accelerate.utils.dataclasses import FP8RecipeKwargs, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
fsdp_plugin = FullyShardedDataParallelPlugin(
    mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16)
)

# these are passed to Accelerator; it seems that `FP8RecipeKwargs` is no longer created in `Accelerator.__init__`
args = {
    "deepspeed_plugin": self.args.deepspeed_plugin,
    "gradient_accumulation_plugin": gradient_accumulation_plugin,
    'fsdp_plugin': fsdp_plugin, 
    'kwargs_handlers': [
        FP8RecipeKwargs(backend='te')
    ]
}

So it is probably the case that my understanding of FP8 is not correct; if you are aware of a simple demo that runs FP8 (TransformerEngine) with FSDP, can you point me to it?

I'm on nightly transformers, accelerate, and torch 2.3.

@muellerzr
Collaborator

@fabianlim the answer is we just currently haven't enabled FP8 mixed precision for FSDP; it's something we're looking into, as TransformerEngine only recently finished getting support for that.

@muellerzr
Collaborator

So for right now, happy to merge this in when you're ready just let me know 🤗

@fabianlim
Contributor Author

@muellerzr got it.. ok let's merge this first, then address the fp8 later!

@muellerzr
Collaborator

Thanks a bunch for all of this fantastic work @fabianlim, once CI is green we'll merge and this will get in before the next release 🚀 (End of this week)

@muellerzr muellerzr merged commit 9557598 into huggingface:main Apr 29, 2024
23 checks passed
@stas00
Contributor

stas00 commented Apr 29, 2024

Awesome work, @fabianlim!

Some belated feedback:

  • I think this could be a misleading entry:

snapshot_731

while both help to overcome the big model loading issue - they aren't the same functionality - since zero.Init requires special code to init tensors (need to manually gather those or else they will be silently skipped).

  • not sure what pipeline means here, should it say - weights prefetch instead?

snapshot_732

snapshot_734

perhaps it'd be better to say - "one could use ..."

It's trivial to get this unsharded checkpoint during the training with zero_to_fp32.py script, w/o slowing down the training. At least this is the approach that we use.

  • this is not entirely true

snapshot_735

For example, stage3_param_persistence_threshold will affect the prefetch - if the tensors aren't sharded there will be no prefetch. Then you have stage3_max_reuse_distance and stage3_max_live_parameters which too will have an impact.

perhaps the doc should say "when needed"? so as not to erroneously imply that all tensors are always pre-fetched?


## On Differences in Data Precision Handling

To discuss the how data precision is handled in both FSDP and Deepspeed, it is instructive to first give a flow of how model parameters are handled in these frameworks. Before the model / optimizer parameters are distributed across GPUs, parameter preparation is involved to first "flatten" them to one-dimensional [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch-tensor)'s. The implementation of FSDP / DeepSpeed varies in the respect of the `dtype` in which these "flattened" parameters are stored, and there are ramifications with regards to how [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim)'s allocate their `dtypes`'s. The table below outlines the processes for both frameworks; the "Local" column indicates the process occurring at a per-gpu level, therefore any memory overheads by upcasting should be understood to be amortized by the number of gpus used.
Contributor

@stas00 stas00 Apr 29, 2024

  • I don't know what "give a flow" means - did you mean "give an overview"?
  • there should be no 's but s in #torch-tensor)'s. above
  • same with [torch.Optimizer](https://pytorch.org/docs/stable/optim.html#module-torch.optim)’s
  • same with dtype's


Process | Local | Framework | Details
--|--|--|--
Loading, i.e., [`AutoModel.from_pretrained(..., torch_dtype)`] |
Contributor

@stas00 stas00 Apr 29, 2024

probably would be clearer as:

`AutoModel.from_pretrained(..., torch_dtype=torch_dtype)`

it's ambiguous when it's passed as a positional argument, since I don't think it is even supported.

@fabianlim
Contributor Author

@stas00 thank you for your great suggestions! Let me see how I can quickly address them. Is it ok if I open another PR, @muellerzr, to address @stas00's comments?

@muellerzr
Collaborator

Yep feel free!

@fabianlim
Contributor Author

fabianlim commented Apr 30, 2024

@stas00 thank you again for your above comments; I have attempted to address them in #2725. In particular:

  • resolved your comments on apostrophes and positional argument
  • renamed pipeline -> weight prefetching
  • be more careful on the recommendation of --zero3_save_16bit_model: True.
  • be more careful with the statement regarding the prefetch logic of DeepSpeed.

For Discussion (copying @muellerzr, @pacman100 ):

  • regarding the equivalence of fsdp_cpu_ram_efficient_loading and zero3_init_flag: it sounds like you are saying that zero.Init will do some weight gathering for the other ranks. Am I understanding correctly?
  • If so, this would be similar to fsdp_cpu_ram_efficient_loading, which also loads the weights on a single rank and requires the other ranks to gather them.
  • to this end, accelerate launch will raise if fsdp_cpu_ram_efficient_loading is set but fsdp_sync_module_states is left unset.
  • However, if a user chooses not to pass options directly to accelerate launch but relies on the YAML config, then the user can set cpu_ram_efficient_loading: True and leave sync_module_states unset. (This is actually not a problem for the YAML path, but if the user sets the ACCELERATE_ env vars directly, e.g. in scenarios where the user is using torchrun, then this will create problems.)
  • Thus in 19cfab4, I propose to remove the raise in accelerate launch and just set them automatically in FullyShardedDataParallelPlugin.__post_init__ (see the sketch after this list).
  • Or, we could keep the raise, and have it work in tandem with the proposed automatic setting in FullyShardedDataParallelPlugin.
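
A rough sketch of what the proposal amounts to (an illustrative stand-in class, not the actual 19cfab4 diff; field names mirror the YAML keys):

import warnings
from dataclasses import dataclass

@dataclass
class _FSDPPluginSketch:
    # Stand-in for FullyShardedDataParallelPlugin, used only to illustrate the idea.
    cpu_ram_efficient_loading: bool = False
    sync_module_states: bool = False

    def __post_init__(self):
        # RAM-efficient loading materializes the weights on one rank only, so the other ranks
        # must sync module states from it or they would train on uninitialized weights.
        if self.cpu_ram_efficient_loading and not self.sync_module_states:
            warnings.warn(
                "cpu_ram_efficient_loading=True requires sync_module_states=True; enabling it automatically."
            )
            self.sync_module_states = True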

Not Done:

  • put links in FSDP and DeepSpeed docs.

@varadhbhatnagar

Hi @fabianlim

Can you summarise what happens to

flat params 
optimizer 
fwd_bkwd, reduce

when the model is created in bf16 and mixed_precision='bf16', before and after this PR? (I found this format of explanation to be very nice: #2674 (comment))

Comment on lines +1500 to +1501
param.data = param.data.to(torch.float32) # upcasting
module._handle._orig_param_dtype = torch.float32 # update
Contributor

@fabianlim, it looks like this introduced a breakage in FSDP integration - one of our tests now fails because of the last 2 lines, but works fine prior to this PR being merged.

    output = self.model(**batch)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 843, in forward
    args, kwargs = _pre_forward(
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
    unshard_fn(state, handle)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 288, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1252, in pre_unshard
    ret = self._writeback_orig_params()
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2249, in _writeback_orig_params
    self._writeback_tensor(
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2348, in _writeback_tensor
    raise RuntimeError(
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([48000]) but got torch.Size([3000, 16])

the total size (numel) is correct, but the shape is wrong

Contributor

also cc: @muellerzr

Contributor

the training step was just a trivial tiny llama2 model, like: https://huggingface.co/stas/tiny-random-llama-2

and the config was:

distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_forward_prefetch: false  # can flip to true once requiring torch>=2.2.0
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true

Contributor Author

@stas00 can you share the accelerate config? This looks like something commonly seen if NO_SHARD is used.

Contributor

full config:

compute_environment: LOCAL_MACHINE

distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_forward_prefetch: false  # can flip to true once requiring torch>=2.2.0
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true

mixed_precision: bf16
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: {num_processes}
use_cpu: false

Contributor Author

Ok I think it's because of use_orig_params being set to true. This was a setting that I didn't test.

Contributor

Is there a reason to do the upcasting so late in the game? Won't it be much simpler to do it before the model gets wrapped by FSDP?

Accelerate knows by this time everything it needs to know to do the right thing, no?

Contributor Author

@fabianlim fabianlim Jun 4, 2024

The reason is that we only wanted to cast the sharded params and not the whole model. Casting the whole model is not a good idea. DeepSpeed does the same thing and only casts the sharded params.

Contributor

@stas00 stas00 Jun 4, 2024

Didn't we say that this was a convenience workaround and that the model should have been loaded in fp32 in the first place for fsdp to converge well?

Frankly, when do you not shard most of the model parameters? And even if sharding is selective, typically it's the big params that are sharded, so 2x of those would be close to all of the model's params doubled in size.

Contributor Author

@fabianlim fabianlim Jun 7, 2024

@stas00
Essentially what I am saying is that upcasting ahead will in general incur a 2X overhead in CPU memory; in most general usage patterns that overhead will be there regardless of whether we use low_cpu_mem mode or not.
image
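
As a rough worked example of that 2X (illustrative numbers, not from the PR):

# CPU-memory cost of upcasting a 7B-parameter model before sharding (example figures)
params = 7_000_000_000
bf16_gb = params * 2 / 1e9   # ~14 GB when the checkpoint is loaded in bfloat16
fp32_gb = params * 4 / 1e9   # ~28 GB if upcast to fp32 before FSDP shards anything
print(f"bf16: {bf16_gb:.0f} GB, fp32: {fp32_gb:.0f} GB")
# Upcasting only the already-sharded flat params instead amortizes the extra memory across GPUs.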

BTW I'm having a bit of trouble reproducing your error on torch 2.1.2 and 2.2.0.

  • I took the learning_rate_repro.py script in the main issue
  • I ran the following command, where dataset.json is prepared following this
  • accelerate_stas.yaml is the config that was shared above.
  • using the same stas/tiny-random-llama-2
accelerate launch \
    --num_processes 4 \
    --main_process_port $MASTER_PORT \
    --config_file accelerate_stas.yaml \
    learning_rate_repro.py  \
      --model_name stas/tiny-random-llama-2 \
      --max_seq_len 4096 \
      --num_train_epochs 1 \
      --output_dir './results' \
      --per_device_train_batch_size 8 \
      --gradient_accumulation_steps 1 \
      --include_tokens_per_second \
      --gradient_checkpointing \
      --lr_scheduler_type "linear" \
      --learning_rate 1e-6 \
      --logging_steps 1 \
      --packing True \
      --dataset_for_packing dataset.json \
      --dataset_text_field output \
      --bf16 \
      --max_steps 100 \
      --save_strategy 'no'

and the training runs.

{'loss': 8.0075, 'grad_norm': 0.9171388149261475, 'learning_rate': 9.9e-06, 'epoch': 1.0}
{'loss': 7.9684, 'grad_norm': 1.1230570077896118, 'learning_rate': 9.800000000000001e-06, 'epoch': 2.0}
{'loss': 7.8988, 'grad_norm': 1.232378602027893, 'learning_rate': 9.7e-06, 'epoch': 3.0}
{'loss': 7.8119, 'grad_norm': 1.1301443576812744, 'learning_rate': 9.600000000000001e-06, 'epoch': 4.0}
{'loss': 7.7319, 'grad_norm': 1.0580438375473022, 'learning_rate': 9.5e-06, 'epoch': 5.0}
