Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add util for ram efficient loading of model when using fsdp #25107

Merged
merged 16 commits into from
Aug 17, 2023

Conversation

pacman100
Copy link
Contributor

@pacman100 pacman100 commented Jul 26, 2023

What does this PR do?

  1. Fixes an issue explained in [FSDP] FSDP doesn't work (random accuracy performance) when using param_init_fn and sync_module_states=True pytorch/pytorch#105840 when using FSDP for training very large models. Should be merged after support for ram efficient loading of model with FSDP accelerate#1777

Currently, when using FSDP, the model is loaded for each of the N processes completely on CPU leading to huge CPU RAM usage. When training models like Flacon-40B with FSDP on a dgx node with 8 GPUs, it would lead to CPU RAM getting out of memory because each process is loading 160GB (40B x 4Bytes (FP32)) in CPU RAM for a total of 160*8=1280GB requirement which results in script getting killed due to out of CPU RAM.

To combat this, we load the model only on rank 0 and have it on meta device when rank!=0. Then use no-op param_init_fn along with sync_module_states=True for FSDP to properly init the weights on other ranks and broadcast the params from rank 0 to other ranks.

Usage:

No user-facing changes:

Post this PR:

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744
accelerator.process_index=0 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       requires_grad=True)
accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506
accelerator.process_index=1 model.lm_head.weight=Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)
accelerator.process_index=0 GPU Memory before entering the prepare : 0
accelerator.process_index=0 GPU Memory consumed at the end of the prepare (end-begin): 13202
accelerator.process_index=0 GPU Peak Memory consumed during the prepare (max-begin): 15458
accelerator.process_index=0 GPU Total Peak Memory consumed during the prepare (max): 15458
accelerator.process_index=0 CPU Memory before entering the prepare : 27345
accelerator.process_index=0 CPU Memory consumed at the end of the prepare (end-begin): -26394
accelerator.process_index=0 CPU Peak Memory consumed during the prepare (max-begin): 0
accelerator.process_index=0 CPU Total Peak Memory consumed during the prepare (max): 27345
FullyShardedDataParallel(
  (_fsdp_wrapped_module): RWForCausalLM(
    (transformer): RWModel(
      (word_embeddings): Embedding(65024, 4544)
      (h): ModuleList(
        (0-31): 32 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): DecoderLayer(
            (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
            (self_attention): Attention(
              (maybe_rotary): RotaryEmbedding()
              (query_key_value): Linear(in_features=4544, out_features=4672, bias=False)
              (dense): Linear(in_features=4544, out_features=4544, bias=False)
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (mlp): MLP(
              (dense_h_to_4h): Linear(in_features=4544, out_features=18176, bias=False)
              (act): GELU(approximate='none')
              (dense_4h_to_h): Linear(in_features=18176, out_features=4544, bias=False)
            )
          )
        )
      )
      (ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=4544, out_features=65024, bias=False)
  )
)
accelerator.process_index=1 GPU Memory before entering the prepare : 0
accelerator.process_index=1 GPU Memory consumed at the end of the prepare (end-begin): 13202
accelerator.process_index=1 GPU Peak Memory consumed during the prepare (max-begin): 15458
accelerator.process_index=1 GPU Total Peak Memory consumed during the prepare (max): 15458
accelerator.process_index=1 CPU Memory before entering the prepare : 945
accelerator.process_index=1 CPU Memory consumed at the end of the prepare (end-begin): 4
accelerator.process_index=1 CPU Peak Memory consumed during the prepare (max-begin): 4
accelerator.process_index=1 CPU Total Peak Memory consumed during the prepare (max): 949
accelerator.process_index=1 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       device='cuda:1', requires_grad=True)
accelerator.process_index=0 model.lm_head.weight=Parameter containing:
tensor([[-0.0179,  0.0201, -0.0273,  ..., -0.0275, -0.0396, -0.0131],
        [-0.0510, -0.0079, -0.0383,  ..., -0.0481,  0.0581,  0.0282],
        [-0.0217, -0.0216, -0.0064,  ..., -0.0508,  0.0554, -0.0013],
        ...,
        [ 0.0425,  0.0452, -0.0131,  ...,  0.0019,  0.0476,  0.0342],
        [-0.0170, -0.0085,  0.0449,  ..., -0.0074,  0.0178,  0.0043],
        [-0.0439, -0.0859, -0.0820,  ...,  0.0130,  0.0669,  0.0884]],
       device='cuda:0', requires_grad=True)

So you can see that during loading Rank 1 doesn't take any more CPU RAM. And the performance between both setups matches.

To Do:

  • Add docs in the FSDP section

@pacman100 pacman100 requested review from sgugger and removed request for sgugger July 26, 2023 10:50
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 26, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said internally, I would prefer for this to be done automatically in from_pretrained when FSDP is detected with the options that make sense. For instance we have several tests for DeepSpeed ZeRO3 in from_pretrained, one loading the state dict only on the main process.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks better, thansk!

src/transformers/trainer.py Outdated Show resolved Hide resolved
src/transformers/training_args.py Outdated Show resolved Hide resolved
src/transformers/trainer_pt_utils.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just have the unresolved comment on why the prefetch this is removed in the trainer file.

@lwmlyy
Copy link

lwmlyy commented Aug 15, 2023

@pacman100 Hi, I've tried to run Llama2 with the two PR but it seems something went wrong. Plz check, thx!

While copying the parameter named "model.layers.29.self_attn.v_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ('Cannot copy out of meta tensor; no data!\nException raised from copy_impl at ../aten/src/ATen/native/Copy.cpp:188 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f649c6c04d7 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)\nframe #1: + 0x11c32e4 (0x7f64ea8552e4 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #2: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x62 (0x7f64eb3deb32 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #3: at::ops::copy::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x7b (0x7f64ebff07db in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #4: + 0x5443145 (0x7f64eead5145 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #5: at::ops::copy::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&, bool) + 0x7b (0x7f64ebff07db in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #6: + 0x54454f4 (0x7f64eead74f4 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #7: at::ops::copy::call(at::Tensor&, at::Tensor const&, bool) + 0x15f (0x7f64ec04dadf in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)\nframe #8: + 0x4cdbc9 (0x7f65030ddbc9 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_python.so)\nframe #9: + 0x1453a3 (0x55e385cfb3a3 in /opt/conda/bin/python)\nframe #10: _PyEval_EvalFrameDefault + 0x6f3 (0x55e385ce9b13 in /opt/conda/bin/python)\nframe #11: + 0x1515df (0x55e385d075df in /opt/conda/bin/python)\nframe #12: _PyEval_EvalFrameDefault + 0x2b8f (0x55e385cebfaf in /opt/conda/bin/python)\nframe #13: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #14: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #15: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #16: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #17: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #18: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #19: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #20: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #21: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #22: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #23: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #24: _PyEval_EvalFrameDefault + 0x12ff (0x55e385cea71f in /opt/conda/bin/python)\nframe #25: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #26: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #27: + 0x150d7c (0x55e385d06d7c in /opt/conda/bin/python)\nframe #28: _PyEval_EvalFrameDefault + 0x12ff (0x55e385cea71f in /opt/conda/bin/python)\nframe #29: + 0x150d7c (0x55e385d06d7c in /opt/conda/bin/python)\nframe #30: _PyEval_EvalFrameDefault + 0x12ff (0x55e385cea71f in /opt/conda/bin/python)\nframe #31: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #32: PyObject_Call + 0xb8 (0x55e385d07a08 in /opt/conda/bin/python)\nframe #33: _PyEval_EvalFrameDefault + 0x2b8f (0x55e385cebfaf in /opt/conda/bin/python)\nframe #34: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #35: _PyEval_EvalFrameDefault + 0x12ff (0x55e385cea71f in /opt/conda/bin/python)\nframe #36: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #37: _PyEval_EvalFrameDefault + 0x304 (0x55e385ce9724 in /opt/conda/bin/python)\nframe #38: _PyFunction_Vectorcall + 0x6f (0x55e385cfa03f in /opt/conda/bin/python)\nframe #39: _PyEval_EvalFrameDefault + 0x4a35 (0x55e385cede55 in /opt/conda/bin/python)\nframe #40: + 0x1e64d2 (0x55e385d9c4d2 in /opt/conda/bin/python)\nframe #41: PyEval_EvalCode + 0x87 (0x55e385d9c417 in /opt/conda/bin/python)\nframe #42: + 0x219ed9 (0x55e385dcfed9 in /opt/conda/bin/python)\nframe #43: + 0x2147e4 (0x55e385dca7e4 in /opt/conda/bin/python)\nframe #44: + 0x98214 (0x55e385c4e214 in /opt/conda/bin/python)\nframe #45: _PyRun_SimpleFileObject + 0x1af (0x55e385dc4b1f in /opt/conda/bin/python)\nframe #46: _PyRun_AnyFileObject + 0x43 (0x55e385dc46e3 in /opt/conda/bin/python)\nframe #47: Py_RunMain + 0x39f (0x55e385dc189f in /opt/conda/bin/python)\nframe #48: Py_BytesMain + 0x39 (0x55e385d8f709 in /opt/conda/bin/python)\nframe #49: __libc_start_main + 0xf3 (0x7f6534704083 in /usr/lib/x86_64-linux-gnu/libc.so.6)\nframe #50: + 0x1d9611 (0x55e385d8f611 in /opt/conda/bin/python)\n',).

@pacman100
Copy link
Contributor Author

Hello @lwmlyy, I'm able to run 70B Llama on 32 A100 80GB GPUs with it without any issues. Can you share the config, minimal example and launch command?

@lwmlyy
Copy link

lwmlyy commented Aug 15, 2023

@pacman100 I run the code in meta-llama/llama-recipes#77 with the following command:
torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --pure_bf16 --model_name ../Llama-2-7b-hf --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder ../Llama-2-7b-hf --dist_checkpoint_folder fine-tuned

Could you also share the script for running Llama-70b as you mentioned?

@pacman100
Copy link
Contributor Author

Hello @lwmlyy, follow this: meta-llama/llama-recipes#77 (comment)

@lwmlyy
Copy link

lwmlyy commented Aug 15, 2023

@pacman100 As you mentioned, if the model is loaded with accelerate, no code change is needed. I wonder why the error shows up. Could you give some advice?

@pacman100
Copy link
Contributor Author

pacman100 commented Aug 15, 2023

@pacman100 I run the code in meta-llama/llama-recipes#77 with the following command:
torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --pure_bf16 --model_name ../Llama-2-7b-hf --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder ../Llama-2-7b-hf --dist_checkpoint_folder fine-tuned

Could you also share the script for running Llama-70b as you mentioned?

Hello, you aren't launching via accelerate launch (you are using torchrun) and as such the env variable ACCELERATE_USE_FSDP isn't enabled.

@lwmlyy
Copy link

lwmlyy commented Aug 15, 2023

@pacman100 I run the code in facebookresearch/llama-recipes#77 with the following command:
torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --pure_bf16 --model_name ../Llama-2-7b-hf --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder ../Llama-2-7b-hf --dist_checkpoint_folder fine-tuned
Could you also share the script for running Llama-70b as you mentioned?

Hello, you aren't launching via accelerate launch (you are using torchrun) and as such the env variable ACCELERATE_USE_FSDP isn't enabled.

@pacman100 Hi,I meet the same error when using the following command:
accelerate launch llama_finetuning.py --enable_fsdp --model_name ../Llama-2-7b-hf --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder ../Llama-2-7b-hf --dist_checkpoint_folder fine-tuned

It works fine with the command:
ACCELERATE_USE_FSDP=True accelerate launch llama_finetuning.py --enable_fsdp --model_name ../Llama-2-7b-hf --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder ../Llama-2-7b-hf --dist_checkpoint_folder fine-tuned

But the loss is nan.

@pacman100
Copy link
Contributor Author

Hello, you aren't using the accelerate integration of FSDP and you are mixing llama recipe implementation which doesn't use Accelerate. Please refer to the Accelerate docs on the proper way to use FSDP with Accelerate. Also, please raise a separate issue.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating!

@pacman100 pacman100 merged commit c4c0cef into main Aug 17, 2023
3 checks passed
@pacman100 pacman100 deleted the smangrul/fsdp_cpu_ram_efficient_model_loading_util branch August 17, 2023 16:23
@nabarunbaruaAIML
Copy link

Hi @pacman100 ,

I am trying to Train Llama 70B Model in FSDP, I was going through your repo https://github.com/pacman100/ram_efficient_fsdp/blob/main/train.py, code is failing when trying to import this function load_pretrained_model_only_on_rank0 getting error "ImportError: cannot import name 'load_pretrained_model_only_on_rank0' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/init.py)". Tried to check this function in the Transformer Repo but couldn't find one in the main branch.

Can you please help me, how I can execute your code.

Regards
Nabarun Barua

blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
…ace#25107)

* add util for ram efficient loading of model when using fsdp

* make fix-copies

* fixes 😅

* docs

* making it further easier to use

* rename the function

* refactor to handle fsdp ram efficiency in `from_pretrained`

* fixes

* fixes

* fixes

* update

* fixes

* revert `load_pretrained_model_only_on_rank0`

* resolve `load_from_checkpoint`
@pkaercher
Copy link

pkaercher commented Feb 6, 2024

I'm currently using transformer v.4.37.2 and accelerate v.0.26.1 and am training on one machine with 2 GPU processors. I'm seeing the Mistral 7B model being loaded onto CPU RAM x2 (once for each processor). I don't understand why since this fix was released with earlier versions transformer v.4.32.0 and accelerate v.0.22.0 and should load the model onto CPU only once, independent of the number of processors. Any insight anyone has is super appreciated!

These are the settings in my fsdp config file:
compute_environment: LOCAL_MACHINE debug: false distributed_type: FSDP downcast_bf16: 'no' fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_backward_prefetch: BACKWARD_PRE fsdp_cpu_ram_efficient_loading: true fsdp_forward_prefetch: false fsdp_offload_params: false fsdp_sharding_strategy: 1 fsdp_state_dict_type: SHARDED_STATE_DICT fsdp_sync_module_states: true fsdp_use_orig_params: true machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 num_processes: 2 rdzv_backend: static same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false

@pacman100
Copy link
Contributor Author

Hello @pkaercher,

To use this feature, you would need to use Accelerate config for FSDP along with Accelerate launcher. For more details on how to use this, please refer https://huggingface.co/docs/transformers/trainer#accelerate-and-trainer

@pkaercher
Copy link

pkaercher commented Feb 8, 2024

Hi @pacman100,
Thanks for your reply. I am using Accelerate config along with FSDP (see my config file in my post above that I created with accelerate config --config_file "fsdp_config.yaml". I am running my script with the command accelerate launch --config_file fsdp_config.yaml domain_adapt.py. Attached is my domain_adapt.py script. When I run, I see the CPU RAM go up to 65 GB for the 7B Mistral model, which is twice as much space as it should take up given 7Billion x 4Bytes = 28GB. Twice that (loading the model once for each of the 2 GPU processors I'm using) gives 56 GB, which, plus the space taken up by my environment packages and my dataset would be roughly 65 GB makes me think that accelerate is loading the Mistral model into my CPU RAM x2, which it shouldn't according to this fix.

from datasets import load_dataset
from peft import LoraConfig, get_peft_model 
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

# Define training parameters

MODEL_CHECKPOINT = 'mistralai/Mistral-7B-Instruct-v0.2'
DATASET_PATH = '/appdata/data'
# SAVE_TO_PATH = '/appdata/embedding_models'
MASK_WHOLE_WORDS = False


ARGS = {
        'lora_alpha': 16,
        'lora_dropout': 0.1,
        'lora_r': 64,
        'output_dir': '/appdata/results',
        'per_device_train_batch_size': 1,
        'per_device_eval_batch_size': 1,
        'gradient_accumulation_steps': 16, 
        'optim': "paged_adamw_32bit",
        'evaluation_strategy': 'steps', # "epoch", default is 'no'
        'save_steps': 50, # defaults to 500
        'logging_steps': 50, # defaults to 500
        'num_train_epochs': 4,  # default is 3
        'learning_rate': 1e-4,
        'max_grad_norm': 0.3, # default is 1
        'max_steps': 500,  # training will only run to this number of steps; overrides num_train_epochs
        'warmup_ratio': 0.03,
        'lr_scheduler_type': "cosine",  # default is "linear"
        }


# Define functions

def run_domain_adaptation(model_checkpoint, dataset_path, args):  # training_dataset_name, mask_whole_words
    # Import model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

    # Import and tokenize the data
    data = load_dataset(dataset_path , split='train')
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    tokenizer.mask_token = '<MASK>'


    def tokenize_function(examples, tokenizer=tokenizer):
        """examples is a dataset object"""
        result = tokenizer(examples["text"])
        return result


    tokenized_datasets = data.map(tokenize_function, batched=True, remove_columns=["text"])
    collator = DataCollatorForLanguageModeling(mlm=True, mlm_probability=0.15, tokenizer=tokenizer)

    peft_config = LoraConfig(
                            lora_alpha=args['lora_alpha'],
                            lora_dropout=args['lora_dropout'],
                            r=args['lora_r'],
                            bias="none",
                            task_type="CAUSAL_LM",
                            target_modules=[
                                "Wqkv",
                                "out_proj",
                                "up_proj",
                                "down_proj",
                            ])

    training_arguments = TrainingArguments(
                                        # output_dir=args['model_output_dir'],
                                        output_dir=args['output_dir'],
                                        per_device_train_batch_size=args['per_device_train_batch_size'],
                                        per_device_eval_batch_size=args['per_device_eval_batch_size'],
                                        gradient_accumulation_steps=args['gradient_accumulation_steps'],
                                        logging_steps=args['logging_steps'],
                                        save_strategy= 'epoch',
                                        # evaluation_strategy=args['evaluation_strategy'],
                                        num_train_epochs=args['num_train_epochs'],
                                        learning_rate=args['learning_rate'],
                                        bf16=True,
                                        # fsdp='full_shard',
                                        max_grad_norm=args['max_grad_norm'],
                                        warmup_ratio=args['warmup_ratio'],
                                        group_by_length=True,
                                        report_to='none',
                                        log_level='debug',
                                        )

    # Train
    model = get_peft_model(model, peft_config)
    trainer = Trainer(
                      model=model,
                      tokenizer=tokenizer,
                      data_collator=collator,
                      train_dataset=tokenized_datasets,
                    #   eval_dataset=tokenized_datasets['validation'],
                      args=training_arguments,
                     )

    trainer.train()

    # Save the model and tokenizer
    model_name = f"{model_checkpoint.split('/')[1]}"  # _{training_dataset_name}"
    trainer.save_model(fr"../embedding_models/{model_name}")
    tokenizer.save_pretrained(fr"../embedding_models/{model_name}")
    # trainer.save_model(SAVE_TO_PATH)

if __name__ == '__main__':
    run_domain_adaptation(model_checkpoint=MODEL_CHECKPOINT,
                          dataset_path=DATASET_PATH,
                          # training_dataset_name=TRAINING_DATASET_NAME,
                          args=ARGS)```

@pacman100
Copy link
Contributor Author

Hello @pkaercher,

Thank you for the above code, this helps. You are calling from_pretrained method before initializing the distributed process group. As such, from_pretrained has no info whether a distributed training run is in place and as such doesn't know which process is rank 0 or remaining ranks. For this to work when using Trainer, please create an instance of TrainingArguments before calling from_pretrained because TrainingArguments instance initializes the distributed process group.

Updating the docs here huggingface/accelerate#2430 with this information.

@pkaercher
Copy link

Thank you @pacman100 ! I did as you suggested and saw the max CPU RAM usage during loading of the model drop from 65.2 GB to 47.2 GB, so it looks like it's working now.

@fabianlim
Copy link
Contributor

fabianlim commented May 28, 2024

@pacman100 for quantized models the meta device creation seems to be skipped even for non-zero ranks, is my understanding correct?

new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,

  1. This means that while quantized models are typically smaller than full precision, there is no real low_cpu_mem feature for quantized models, and quantized models are always loaded in duplicity across ranks?
  2. But even for the non-zero ranks, they seem to be moved to "cpu", so wouldnt they occupy the same amount of cpu memory, as if they had been loaded from the checkpoints directly into "cpu"?
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
    for key, param in model_to_load.state_dict().items():
        if param.device == torch.device("meta"):
            set_module_tensor_to_device(
                model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
            )

@Neo9061
Copy link

Neo9061 commented Jun 25, 2024

@pacman100 for quantized models the meta device creation seems to be skipped even for non-zero ranks, is my understanding correct?

new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,

  1. This means that while quantized models are typically smaller than full precision, there is no real low_cpu_mem feature for quantized models, and quantized models are always loaded in duplicity across ranks?
  2. But even for the non-zero ranks, they seem to be moved to "cpu", so wouldnt they occupy the same amount of cpu memory, as if they had been loaded from the checkpoints directly into "cpu"?
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
    for key, param in model_to_load.state_dict().items():
        if param.device == torch.device("meta"):
            set_module_tensor_to_device(
                model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
            )

Hi @ArthurZucker @pacman100 @sgugger please help answer the question above?

It seems to be a problem when loading a large size model like 300B with 4 bit quantization. I have a similar issue documented in #31577

In particular, this line at

# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
, will quantization algorithm quantizes model weights across all the GPUs or single GPU with offloading to CPU? it seems to quantizes to single GPU and thus it has OOM for super large size of model.

On the other end, by using @philschmid 's blogpost and code, I am able to load and train two models:

  1. llama-3 70B on 4 instances of g5.24xlarge, and
  2. mixtral 8x22 on 4 instances of g5.8xlarge (failed at model merging stage).

Neither g5.12xlarge (4 GPUs with 96 GB in total, and 192 GB CPU) nor g5.16xlarge (1 GPU with 24 GB and 256 GB CPU) has enough GPU memory on single GPU to load the 4-bit quantized model, thus I suspect you are doing offloading to CPU memory rather than using single GPU - rank 0 - to load the entire quantized models. But then it does not explain why Grok-1 model - HF format is failed at loading stage with 4 bit quantization.

@amyeroberts
Copy link
Collaborator

cc @SunMarc

@Neo9061
Copy link

Neo9061 commented Jun 27, 2024

Hi @SunMarc @amyeroberts any chance you can take a look my questions and share insights? many thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants