FSDP + bf16 mixed precision on Llama2-7b shows weird behavior. #2127

Closed
tmabraham opened this issue Nov 6, 2023 · 15 comments

@tmabraham
Contributor

tmabraham commented Nov 6, 2023

FSDP with mixed precision shows weird behavior. For example, if in the model definition we have:

    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

this performs worse than

    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        use_cache=False,
    )

[image: wandb loss curves for the two runs]

Note that this is on 0.1% of the dataset I am using. Another weird issue is that just increasing from 0.1% to 1% of the dataset causes training to fail (the loss stays effectively constant), even for the configuration whose loss was decreasing before.

Code for reproduction: https://github.com/tmabraham/hf_fsdp/

@muellerzr
Collaborator

@tmabraham can you provide a full repro with:

  • Commands you ran
  • Versions of transformers and accelerate
  • (Other information the bug report template would have prompted for)

@tmabraham
Contributor Author

I ran sbatch start_multinode.sh, which runs the following command for multi-node training:

accelerate launch \
--num_processes=$(( 8 * $SLURM_JOB_NUM_NODES )) \
--num_machines $SLURM_JOB_NUM_NODES \
--machine_rank $THEID \
--main_process_ip $SLURM_LAUNCH_NODE_IPADDR \
--main_process_port $MASTER_PORT \
--mixed_precision=bf16  \
--config_file accelerate_config.yaml \
train.py \
--batch_size 8 \
--gradient_accumulate_every 1 \
--wandb_entity "tmabraham" \
--wandb_project "pubmed-llama-2" \
--wandb_name "pubmed-llama-2-7b-full-epoch-accelerate-test" \
--output_dir "/fsx/home-tmabraham/ckpts/pubmed-llama-2-7b/pubmed-llama-2-7b-full-epoch-accelerate-test" \
--dataset_name "tmabraham/pubmed-enrico-tokenized"

This is the output of accelerate env:

- `Accelerate` version: 0.24.0.dev0
- Platform: Linux-5.15.0-1037-aws-x86_64-with-glibc2.31
- Python version: 3.10.10
- Numpy version: 1.25.0
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1121.82 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: bf16
        - use_cpu: False
        - debug: False
        - num_processes: 128
        - machine_rank: 0
        - num_machines: 16
        - gpu_ids: all
        - main_process_ip:
        - main_process_port: 12802
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

transformers was version 4.35.0.dev0

Let me know if you need any more info...

@pacman100
Contributor

pacman100 commented Nov 8, 2023

Hello,

FSDP with mixed precision shows weird behavior. For example, if in the model definition we have:

    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

this performs worse than

    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

Above, both code snippets are the same 😅.

I observed the same behaviour a couple of days ago. Don't pass torch_dtype=torch.bfloat16, as it loads the entire model in BF16, which then hinders mixed-precision training: certain layers such as layer norms, softmax, and the output logits need to remain in FP32 for stable training.

Since you have already passed bf16 for the mixed-precision config param, that should be enough. When loading the pretrained model, just do:

model = LlamaForCausalLM.from_pretrained(MODEL_NAME, use_cache=False)
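For reference, a minimal sketch of the intended setup (assuming a standard Accelerate training script; MODEL_NAME below is an assumed checkpoint name, and the bf16 mixed-precision setting takes care of the autocast):

    from accelerate import Accelerator
    from transformers import LlamaForCausalLM

    MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # assumption: the checkpoint used in the issue

    # mixed_precision="bf16" mirrors the --mixed_precision=bf16 launch flag
    accelerator = Accelerator(mixed_precision="bf16")

    # load in FP32 (no torch_dtype); the bf16 autocast then downcasts only where it is safe,
    # keeping layer norms, softmax and the output logits in full precision
    model = LlamaForCausalLM.from_pretrained(MODEL_NAME, use_cache=False)
    model = accelerator.prepare(model)  # FSDP wrapping happens here when FSDP is configured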

@pacman100
Contributor

pacman100 commented Nov 8, 2023

@tmabraham, there are the following bugs in the code you shared:

  1. shuffle=False for the training dataloader, which leads to all sorts of unstable training.
  2. Don't pass torch_dtype=torch.bfloat16, as it loads the entire model in BF16, which then hinders mixed-precision training: certain layers such as layer norms, softmax, and the output logits need to remain in FP32 for stable training.
  3. You need to reduce the loss across GPUs every step and also account for gradient accumulation steps when logging it to wandb (see the sketch after this list):
step_loss = accelerator.reduce(loss.detach().clone(), reduction="mean").item()
accelerator.log({"loss": step_loss / GRADIENT_ACCUMULATE_EVERY}, step=completed_steps)
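A minimal sketch of fixes 1 and 3 (assuming a typical Accelerate training loop; train_dataset, accelerator, loss, GRADIENT_ACCUMULATE_EVERY and completed_steps stand in for the corresponding objects in train.py):

    from torch.utils.data import DataLoader

    # fix 1: shuffle the training data each epoch
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    train_dataloader = accelerator.prepare(train_dataloader)

    # fix 3: inside the training loop, average the loss over all ranks before logging,
    # and account for gradient accumulation when logging to wandb
    step_loss = accelerator.reduce(loss.detach().clone(), reduction="mean").item()
    accelerator.log({"loss": step_loss / GRADIENT_ACCUMULATE_EVERY}, step=completed_steps)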

Other changes:

  1. I am using my FA2 monkey-patch code.
  2. No warmup works best for this dataset and model.

Code: https://github.com/pacman100/hf_fsdp
Command:

accelerate launch --config_file accelerate_config.yaml train.py \
--batch_size 8 \
--gradient_accumulate_every 2 \
--wandb_name "pubmed-llama-2-7b-full-epoch-accelerate-test" \
--output_dir "/fsx/sourab/delete_me/pubmed-llama-2-7b-full-epoch-accelerate-test" \
--dataset_name "smangrul/pubmed_smol"

Dataset: a 1% subset of 7M samples => smangrul/pubmed_smol

output logs:

15%|██████████████▌                                                                                 | 83/547 [26:43<2:26:53, 19.00s/it]
[screenshot: wandb loss curve for this run]

@jph00

jph00 commented Nov 9, 2023

Code: https://github.com/pacman100/hf_fsdp

Unless I'm missing something, you're getting a much worse loss curve than @tmabraham here -- you've divided the loss by 2; @tmabraham had the loss down beneath 0.3 after making the same adjustment.

@pacman100
Contributor

pacman100 commented Nov 15, 2023

Unless I'm missing something, you're getting a much worse loss curve than @tmabraham here -- you've divided the loss by 2; @tmabraham had the loss down beneath 0.3 after making the same adjustment.

Hello @jph00, are you using the exact same code, command and setup? I am using 8 GPUs with 2 gradient accumulation steps.

@pacman100
Contributor

For reference, I get the exact same loss curve as above using the llama-recipes (https://github.com/facebookresearch/llama-recipes/tree/main/src/llama_recipes) as seen below:

Commands:

git clone https://github.com/pacman100/llama-recipes
cd llama-recipes
torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py --batch_size_training 8

Wandb plot:
[screenshot: wandb training loss plot]

@pacman100
Contributor

Note that this is on 0.1% of the dataset I am using. Another weird issue is that just increasing from 0.1% to 1% of the dataset causes training to fail (the loss stays effectively constant), even for the configuration whose loss was decreasing before.

Both these issues are clarified. Feel free to close the issue.

@pacman100
Contributor

I also see the same loss curve when using the Transformers FA2 integration instead of my FA2 monkey-patch (cc @younesbelkada).

The changes in hf_fsdp/train.py are as follows:

...
- replace_llama_attn_with_flash_attn()
+ #replace_llama_attn_with_flash_attn()

# Create fresh LlamaForCausalLM model
    model = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        use_cache=False,
+        use_flash_attention_2=True,
    )

[screenshot: wandb loss curve using the Transformers FA2 integration]

@jph00

jph00 commented Nov 15, 2023

Both these issues are clarified. Feel free to close the issue.

Sorry, I think one of us is missing something here @pacman100 -- the loss curve you've shown indicates the model is not training correctly. As @tmabraham showed, the loss should reduce roughly 3x from the initial batch to the 500th. Your loss flattens out after 50-100 steps and isn't showing much reduction at all, so the model doesn't seem to be training correctly AFAICT.

@pacman100
Contributor

As @tmabraham showed, the loss should reduce roughly 3x from the initial batch to the 500th.

That was on 0.1% of the data, and that too repeated multiple times, so the loss going down as the model sees the same samples again is expected. I am not repeating any data and am taking a 1% subset. Can they please run it again with the exact steps that I have done and get back with results? I have tried the Accelerate integration as well as the llama-recipes by Meta and get the same results. Also, note that the perplexity is already 2-3 as per my runs; it is hard to imagine it going down much further, given that the largest Llama model had a training perplexity of e^1.5 (≈4.48).
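As a quick sanity check on the numbers above (a minimal sketch; perplexity is just the exponential of the cross-entropy loss):

    import math

    # perplexity = exp(loss)
    print(math.exp(0.7))  # ~2.01 -> a loss around 0.7 corresponds to perplexity ~2
    print(math.exp(1.1))  # ~3.00 -> a loss around 1.1 corresponds to perplexity ~3
    print(math.exp(1.5))  # ~4.48 -> the e^1.5 training perplexity quoted for the largest Llama model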

It would help to get the exact code, command, distributed setup and the library versions.

@jph00

jph00 commented Nov 15, 2023

Can they please run it again with the exact steps that I have done and get back with results?

Got it - yup I'll work with @tmabraham to make that happen.

@pacman100
Contributor

To add to this: in the information Zach shared offline regarding this issue, there was a plot showing the loss going down from 1.45 to 1.3 over 8000 steps. So there is no clear information on what is actually being run.

@tmabraham
Contributor Author

Hello @pacman100, thank you for looking into this! I am trying the changes you have made and will report back soon.

Note that I observed issues even without Flash Attention 2.0, as you can see here where I am using BetterTransformer, which IIUC implements Flash Attention 1.0. I did try FA2 through Transformers, and a separate FA2 Llama 2 implementation from Together, all resulting in the same behavior. I will try again with these other changes, though.
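For reference, a minimal sketch of the BetterTransformer path mentioned above (assumptions: the optimum package is installed, and MODEL_NAME is the same checkpoint as before):

    from transformers import LlamaForCausalLM

    MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # assumption: same checkpoint as above

    model = LlamaForCausalLM.from_pretrained(MODEL_NAME, use_cache=False)
    # BetterTransformer (via optimum) routes attention through
    # torch.nn.functional.scaled_dot_product_attention, which can dispatch to a
    # Flash Attention v1 kernel on supported GPUs
    model = model.to_bettertransformer()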


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
