Skip to content

Commit

Permalink
support QLoRA
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jun 3, 2023
1 parent 1bd13d7 commit 3b9eee8
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 12 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,30 @@

## Changelog

[23/06/03] Now we support quantized training and inference (aka QLoRA). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)

[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.

## Supported Models

- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)

## Supported Training Approaches

- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- Full-parameter training
- Partial-parameter training
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- Full-parameter training
- Partial-parameter training
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)

## Provided Datasets

Expand Down Expand Up @@ -209,6 +214,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--predict_with_generate
```

We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.

### CLI Demo

```bash
Expand Down
44 changes: 34 additions & 10 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments
Seq2SeqTrainingArguments,
BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
Expand Down Expand Up @@ -167,12 +168,27 @@ def load_pretrained(

# Quantization configurations (using bitsandbytes library).
if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
#require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
#require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
#require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
config_kwargs["load_in_8bit"] = True
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
else:
raise NotImplementedError
is_mergeable = False
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

Expand All @@ -183,7 +199,7 @@ def load_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=torch.float16, # the model weights are float16 type
torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
low_cpu_mem_usage=True,
**config_kwargs
)
Expand Down Expand Up @@ -237,13 +253,13 @@ def prepare_args(

# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")

if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
raise ValueError("Please enable `predict_with_generate` to save model predictions.")

if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
Expand All @@ -257,6 +273,14 @@ def prepare_args(

training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

if model_args.quantization_bit is not None:
if training_args.fp16:
finetuning_args.compute_dtype = torch.float16
elif training_args.bf16:
finetuning_args.compute_dtype = torch.bfloat16
else:
finetuning_args.compute_dtype = torch.float32

# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
Expand Down
5 changes: 5 additions & 0 deletions src/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import json
import torch
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field

Expand Down Expand Up @@ -207,6 +208,10 @@ class FinetuningArguments:
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)

def __post_init__(self):
if isinstance(self.lora_target, str):
Expand Down

0 comments on commit 3b9eee8

Please sign in to comment.