New Version Usage Issue #24724

Closed
4 tasks
Excuses123 opened this issue Jul 10, 2023 · 28 comments · Fixed by huggingface/accelerate#1753 or #24980

Comments

@Excuses123

Excuses123 commented Jul 10, 2023

System Info

  • transformers version: 4.29.0
  • Platform: Linux-3.10.0-1160.92.1.el7.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.9
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is my code:

import os
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    DataCollatorForSeq2Seq,
)

IGNORE_INDEX = -100

# Chinese prompt templates (指令 = instruction, 输入 = input, 回答 = answer).
PROMPT_DICT = {
    "prompt_input": (
        "### 指令:\n{instruction}\n\n### 输入:\n{input}\n\n### 回答:"
    ),
    "prompt_no_input": (
        "### 指令:\n{instruction}\n\n### 回答:"
    ),
}


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "model name or path"})
    cache_dir: Optional[str] = field(default=None, metadata={"help": "model cache directory"})
    data_path: str = field(default=None, metadata={"help": "path to the training data"})
    mask_input: bool = field(default=True, metadata={"help": "whether to mask the instruction and compute the loss only on the answer"})
    model_max_length: int = field(default=512, metadata={"help": "maximum sequence length"})
    optim: str = field(default="adamw_torch", metadata={"help": "optimizer"})


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([torch.tensor(instance[key]) for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def train():
    local_rank = int(os.environ["LOCAL_RANK"])

    parser = transformers.HfArgumentParser(TrainingArguments)
    training_args, = parser.parse_args_into_dataclasses()
    if local_rank == 0:
        print(training_args)

    tokenizer = AutoTokenizer.from_pretrained(
        training_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right"
    )

    model = AutoModelForCausalLM.from_pretrained(
        training_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        # torch_dtype=torch.float16
    )

    def generate_and_tokenize(sample):
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]

        source = prompt_input.format_map(sample) if sample.get("input", "") != "" \
            else prompt_no_input.format_map(sample)
        target = f"\n{sample['output']}{tokenizer.eos_token}"
        complete = source + target

        # </s> 1 2 3 : a b </s>
        complete_tokenized = tokenizer(complete,
                                       truncation=True,
                                       max_length=training_args.model_max_length)
        # </s> 1 2 3 :
        source_tokenized = tokenizer(source,
                                     truncation=True,
                                     max_length=training_args.model_max_length)

        if training_args.mask_input:
            source_len = len(source_tokenized['input_ids'])
            complete_tokenized['labels'] = [IGNORE_INDEX] * source_len + complete_tokenized['input_ids'][source_len:]
        else:
            complete_tokenized['labels'] = complete_tokenized['input_ids'].copy()

        return complete_tokenized

    tokenized_path = os.path.join(os.path.dirname(training_args.data_path),
                                  f"{training_args.model_name_or_path.split('/')[-1]}_tokenized")
    if not os.path.exists(tokenized_path):
        logging.warning("tokenized data not existed, tokenize data...")
        data = load_dataset("json", data_files=training_args.data_path)
        train_dataset = data['train'].shuffle().map(generate_and_tokenize,
                                                    batched=False,
                                                    remove_columns=["instruction", "input", "output"])
        if local_rank == 0:
            train_dataset.save_to_disk(tokenized_path)
    else:
        logging.warning("tokenized data existed, load data...")
        train_dataset = load_from_disk(tokenized_path)
    # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                           label_pad_token_id=IGNORE_INDEX,
                                           pad_to_multiple_of=8)

    logging.warning("training...")
    trainer = Trainer(model=model,
                      tokenizer=tokenizer,
                      args=training_args,
                      train_dataset=train_dataset,
                      eval_dataset=None,
                      data_collator=data_collator)

    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
    tokenizer.save_pretrained(save_directory=training_args.output_dir)

if __name__ == '__main__':
    train()

Expected behavior

Has anyone encountered this problem? I used the same instruction fine-tuning code. It runs successfully with transformers package version 4.29.0, but when I upgrade to version 4.30.2, it fails to run and throws an OOM (Out of Memory) error. Does anyone know the reason behind this?

Below is the GPU status during my successful run.
[screenshot: GPU usage during the successful run]

@Excuses123
Author

Here's another question: in the new version of the transformers package, from_pretrained now loads the safetensors weights by default. How can I change it to load the PyTorch .bin weights instead? Is there a parameter I can specify?

@amyeroberts
Collaborator

Hi @Excuses123, thanks for raising this issue.

Without knowing the model or dataset, we're unable to reproduce and won't be able to debug this issue. Is there a minimal reproducible snippet, with a public dataset and model checkpoint, where this issue (increased memory footprint) still occurs that you could share?

To force the model not to load the safetensors weights, you can pass use_safetensors=False in the from_pretrained call.
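
For example (a minimal sketch; the checkpoint name is only a placeholder):

from transformers import AutoModelForCausalLM

# Ask from_pretrained to load the PyTorch *.bin weights instead of the safetensors weights.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-1b1",
    use_safetensors=False,
)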

@Excuses123
Author

Excuses123 commented Jul 11, 2023

@amyeroberts Thank you for your response.

I am using the model:
bigscience/bloomz-1b1

The data can be found at: https://huggingface.co/datasets/BelleGroup/train_0.5M_CN/blob/main/Belle_open_source_0.5M.json

Below is the execution script:

torchrun --nproc_per_node=4 --master_port=12345 train.py \
    --model_name_or_path bigscience/bloomz-1b1 \
    --cache_dir /workspace/pretrain_model/bloomz \
    --output_dir /workspace/finetune_model/bloomz/bloomz_1b1_sft \
    --data_path /workspace/datasets/Belle_train_0.5M_CN/Belle_open_source_0.5M.json \
    --fp16 True \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --model_max_length 512 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'BloomBlock' \
    --report_to "tensorboard"

After testing, the highest version that currently runs is 4.29.2; all later versions fail.
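
(As a temporary workaround while this is investigated, pinning the last versions that worked here is an option; a sketch, assuming pip is used:)

pip install "transformers==4.29.2" "accelerate==0.20.3"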

@Excuses123
Author

I guess it might be caused by FSDP (Fully Sharded Data Parallelism), but I'm not sure.

@amyeroberts
Collaborator

@Excuses123 Have you tried running without FSDP? Which version of accelerate are you running?

@Excuses123
Author

Excuses123 commented Jul 12, 2023

@amyeroberts I have tried it, and without FSDP, both the new and old versions of transformers throw an OOM error. My accelerate version is 0.20.3.

@amyeroberts
Collaborator

both the new and old versions of transformers throw an OOM error.

@Excuses123 Is this including versions <= 4.29.2 ?

@Excuses123
Author

@amyeroberts I have tried version 4.29.0 and it works

@amyeroberts
Collaborator

amyeroberts commented Jul 13, 2023

@Excuses123 OK, thanks for confirming.

Could you:

  • Format the code example so that all of the code is in markdown code blocks: ``` code goes here ```
  • Try on the most recent version of transformers, installing from source?
  • Share the versions of datasets being used?

@Excuses123
Author

@amyeroberts I have fixed the code formatting, and the version of my datasets is 2.11.0. My machine is currently running a task, and as soon as it is finished, I will try the latest version.

@larrylawl

Facing the same issue. Code ran smoothly with transformers==4.28.1 but OOM with transformers==4.30.2

@amyeroberts
Collaborator

@Excuses123 @larrylawl OK, thanks for the information and updates.

I'm going to cc @pacman100 and @younesbelkada who know more about training in fp16 and torchrun

@Ying1123

Ying1123 commented Jul 21, 2023

I can confirm this. It is a bug introduced recently. It can be reproduced by the Vicuna training example.
The script works well for 4.28.1 but hits OOM with 4.31.0.

With 4.31.0, the warnings are:

FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer
FSDP Warning: When using FSDP, several parameter groups will be conflated into a single one due to nested module wrapping and parameter flattening.

To fix it, I followed the guide and changed these lines in transformers/trainer.py:

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )

to:

        model = self.accelerator.prepare(model)
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.optimizer = self.accelerator.prepare(self.optimizer)

Then the warnings and OOM disappeared.

@pacman100 @younesbelkada I think my fix is a hack that only works for my case. Could you do a more complete fix in the main branch?

@pacman100
Contributor

Hello @Ying1123, Thank you for the detailed info, very helpful. Could you please try out the above PRs for accelerate and transformers and see if it fixes the OOM?

@Ying1123

Hello @Ying1123, Thank you for the detailed info, very helpful. Could you please try out the above PRs for accelerate and transformers and see if it fixes the OOM?

Thanks @pacman100, cherry-picking the PRs onto transformers v4.31.0 and accelerate v0.21.0 works for me.

@merrymercy
Contributor

@pacman100 Hi, I am still getting out-of-memory issues with the latest main.
With transformers==4.28.1, the vicuna-7b example can run on 4xA100 (40GB) without any issues.

After accelerate is used for FSDP (from v4.30 - the current main), the example hits OOM.
Before your fix, the example hits OOM immediately. After your fix, the example hits OOM after a few batches.

From these observations, I can confirm that the recent refactoring makes the memory usage higher than in the older version, but I do not know how to debug it because I am not familiar with Accelerate.
Could you do more testing and help us fix it? This blocks us from updating transformers to the latest version.

@pacman100
Contributor

Hello @merrymercy, could you post the VRAM usage with the 4.28 version?
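
One way to capture this is PyTorch's built-in peak-memory counter (a minimal sketch; printing it right after trainer.train() is just one option):

import os
import torch

# Report the peak GPU memory allocated on this rank, in GiB.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"rank {local_rank}: peak allocated GPU memory = {peak_gib:.2f} GiB")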

@Xuekai-Zhu

Xuekai-Zhu commented Jul 26, 2023

Hi @pacman100 @Ying1123, I am hitting the same issue (OOM). I tried transformers 4.31.0 and 4.30.0 with accelerate 0.21.0, and none of them worked.
I am fine-tuning LLaMA 7B on 2 x A6000 48GB.
With transformers=4.31.0 and accelerate=0.22.0.dev0 (latest main), the warnings are:

FutureWarning: using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead
FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer.
FSDP Warning: When using FSDP, several parameter groups will be conflated into a single one due to nested module wrapping and parameter flattening.

My FSDP arguments are:

    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \

@Xuekai-Zhu

Xuekai-Zhu commented Jul 26, 2023

@pacman100 @Ying1123 I also found that adding an fsdp_config.json makes the following warning disappear:

FutureWarning: using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead

And the hack above makes these disappear:

FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer.
FSDP Warning: When using FSDP, several parameter groups will be conflated into a single one due to nested module wrapping and parameter flattening.

But all of these still hit OOM.
My fsdp_config.json is:

{
    "fsdp_auto_wrap_policy": "FULL_SHARD",
    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer"
}

I think there should be a better way to fix this.
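
For reference, a JSON file like this is usually handed to the Trainer through the --fsdp_config argument; a minimal sketch reusing the train.py launch style from earlier in the thread (the model name and output path are placeholders, and the remaining training arguments stay as before):

# Enable FSDP and point the Trainer at the JSON config file.
torchrun --nproc_per_node=2 train.py \
    --model_name_or_path bigscience/bloomz-1b1 \
    --output_dir ./output \
    --fsdp "full_shard auto_wrap" \
    --fsdp_config fsdp_config.json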

@pacman100
Contributor

I see the same memory usage across versions for the following example:

cd transformers

export TASK_NAME=mrpc

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py \
    --model_name_or_path bert-base-cased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 16 \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --output_dir /tmp/$TASK_NAME/ \
    --overwrite_output_dir \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap BertLayer \
    --bf16

version 4.28.1: 5.4 GB VRAM
latest main branch: 4.8 GB VRAM

Please provide a minimal example that I can run directly without having to spend time getting it to work.

@Xuekai-Zhu

You mean transformers = the latest main branch and accelerate = 0.21.0?

@pacman100
Contributor

The main branch of both Accelerate and Transformers.

@alanxmay

alanxmay commented Aug 3, 2023

Using the main branch of both Accelerate and Transformers works for me.

@Zhuqln

Zhuqln commented Aug 18, 2023

@Xuekai-Zhu did you fix the problem? I hit the same OOM on 2 x A6000 with the main branch of both libraries.

@JACKHAHA363

JACKHAHA363 commented Aug 31, 2023

I confirm that @Ying1123's hack does not work for me. I have 4 A100 cards, with transformers==4.31.0 and accelerate==0.21.0.

@Zhuqln

Zhuqln commented Aug 31, 2023

Following the comment below, downgrading to transformers==4.28.1 worked for me:

@pacman100 Hi, I am still getting out-of-memory issues with the latest main. With transformer==4.28.1, the vicuna-7b example can run on 4xA100 (40GB) without any issues.

After accelerate is used for FSDP (from v4.30 - the current main), the example hits OOM. Before your fix, the example hits OOM immediately. After your fix, the example hits OOM after a few batches.

From these observations, I can confirm that the recent refactoring makes the memory usage higher than the older version but I do not know how to debug because I am not familiar with Accelerate. Could you do more testing and help us fix it? This blocks us from updating transformers to the latest version.

@sdanyaani

I tried all the solutions and am still getting OOM on an A100 80GB.

@ArthurZucker
Collaborator

If you still have an issue, I suggest you open a new issue, share a reproducer and a traceback, and ping @pacman100; otherwise there is no way we can help you 😓
