add example and start docs
pacman100 committed Mar 12, 2024
1 parent 0c248e9 commit e54c97d
Showing 8 changed files with 190 additions and 35 deletions.
9 changes: 7 additions & 2 deletions docs/source/accelerate/deepspeed.md
@@ -16,13 +16,15 @@ Below is a table that summarizes the compatibility between PEFT's LoRA, [`bitsan
|---|---|
| Zero-1 | 🟢 |
| Zero-2 | 🟢 |
| Zero-3 | 🔴 |
| Zero-3 | 🟢 |

For DeepSpeed Stage 3 + QLoRA, please refer to the section []() below.

To confirm these observations, we ran the SFT (Supervised Fine-tuning) [official example scripts](https://github.com/huggingface/trl/tree/main/examples) of the [Transformers Reinforcement Learning (TRL) library](https://github.com/huggingface/trl) using QLoRA + PEFT with the accelerate configs available [here](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs). We ran these experiments on 2x NVIDIA T4 GPUs.
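
To make the setup concrete, the QLoRA + PEFT preparation those experiments rely on boils down to roughly the sketch below. This is illustrative rather than the exact TRL script: the base model id is a placeholder, and the LoRA hyperparameters simply mirror the example launch scripts added in this commit.

```python
# Minimal QLoRA + PEFT sketch (illustrative; not the exact TRL training script).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantization via bitsandbytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Placeholder model id for illustration.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters on top of the frozen, quantized base model.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```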

Note DeepSpeed-Zero3 and `bitsandbytes` are currently **not** compatible.

# Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple machines and multiple nodes
# Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple devices and multiple nodes

This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/sft/train.py) for performing SFT. You'll configure the script to do SFT (supervised fine-tuning) of the Llama-70B model with LoRA and ZeRO-3 on 8x H100 80GB GPUs on a single machine. You can configure it to scale to multiple machines by changing the accelerate config.
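
As a rough sketch of the accelerate config this launch relies on (field values are assumptions for illustration; the config file shipped with the example is authoritative), ZeRO-3 is selected through the `deepspeed_config` block, and scaling out mostly means raising `num_machines` and `num_processes`:

```yaml
# Sketch of the relevant accelerate config fields for LoRA + ZeRO-3 on 8 GPUs of one machine.
# Values are illustrative; adjust num_machines/num_processes to scale to more machines.
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
mixed_precision: bf16
num_machines: 1    # increase for multi-node training
num_processes: 8   # total number of GPUs across all machines
```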

@@ -171,6 +173,9 @@ In the above example, the memory consumed per GPU is 64 GB (80%) as seen in the
## More resources
You can also refer to the blog post [Falcon 180B Finetuning using 🤗 PEFT and DeepSpeed](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) to learn how to finetune the Falcon 180B model on 16 A100 GPUs across 2 machines.
# Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on a single GPU
# Use PEFT and DeepSpeed with ZeRO3 and CPU Offloading for finetuning large models on a single GPU
This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_ds_zero3_offload.py). You'll configure the script to train a large model for conditional generation with ZeRO-3 and CPU Offload.
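
In essence, the script pairs an ordinary LoRA setup for a seq2seq model with an accelerate/DeepSpeed config that enables ZeRO-3 plus CPU offload. A hedged sketch follows (the model id and LoRA hyperparameters are placeholders, not the script's exact values):

```python
# Sketch of LoRA for conditional generation; the memory savings (ZeRO-3 sharding,
# CPU offload of parameters/optimizer) come from the accelerate/DeepSpeed config,
# not from this code. Model id and hyperparameters are placeholders.
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")

peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```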

4 changes: 2 additions & 2 deletions examples/sft/README.md
@@ -23,10 +23,10 @@ Note:
1. At present, `use_reentrant` needs to be `False` when using gradient checkpointing with multi-GPU QLoRA, otherwise it will lead to errors. However, setting it to `False` leads to huge GPU memory consumption (see the sketch below).
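
A minimal sketch of what that maps to on the `transformers` side (argument values are illustrative, not the example's exact settings):

```python
# Sketch: multi-GPU QLoRA currently needs the non-reentrant gradient
# checkpointing implementation (use_reentrant=False). Values are illustrative.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama-sft-qlora",  # placeholder output path
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
)
```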

## Multi-GPU SFT with LoRA and DeepSpeed
When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).
When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with DeepSpeed, refer to the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).


## Multi-GPU SFT with LoRA and FSDP
When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).
When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with FSDP, refer to the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).


22 changes: 22 additions & 0 deletions examples/sft/configs/deepspeed_config_z3_qlora.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
25 changes: 25 additions & 0 deletions examples/sft/configs/fsdp_config_qlora.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
41 changes: 41 additions & 0 deletions examples/sft/run_peft_qlora_deepspeed_stage3.sh
@@ -0,0 +1,41 @@
accelerate launch --config_file "configs/deepspeed_config_z3_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-dsz3" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16"
42 changes: 42 additions & 0 deletions examples/sft/run_peft_qlora_fsdp.sh
@@ -0,0 +1,42 @@
accelerate launch --config_file "configs/fsdp_config_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-fsdp" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
53 changes: 35 additions & 18 deletions examples/sft/train.py
@@ -16,7 +16,9 @@ class ModelArguments:
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    chat_template_format: Optional[str] = field(
        default="none",
@@ -29,7 +31,9 @@ class ModelArguments:
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
        metadata={
            "help": "comma separated list of target modules to apply LoRA layers to"
        },
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
@@ -39,6 +43,10 @@ class ModelArguments:
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
@@ -79,15 +87,21 @@ class DataTrainingArguments:
        default=False,
        metadata={"help": "Use packing dataset creating."},
    )
    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
    dataset_text_field: str = field(
        default="text", metadata={"help": "Dataset field to use as input text."}
    )
    max_seq_length: Optional[int] = field(default=512)
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
        metadata={
            "help": "If True, appends `eos_token_id` at the end of each sample being packed."
        },
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
        metadata={
            "help": "If True, tokenizers adds special tokens to each sample being packed."
        },
    )
    splits: Optional[str] = field(
        default="train,test",
@@ -100,13 +114,19 @@ def main(model_args, data_args, training_args):
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)
    model, peft_config, tokenizer = create_and_prepare_model(
        model_args, data_args, training_args
    )

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    training_args.gradient_checkpointing = (
        training_args.gradient_checkpointing and not model_args.use_unsloth
    )
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}
        training_args.gradient_checkpointing_kwargs = {
            "use_reentrant": model_args.use_reentrant
        }

    # datasets
    train_dataset, eval_dataset = create_datasets(
@@ -133,14 +153,7 @@ def main(model_args, data_args, training_args):
        max_seq_length=data_args.max_seq_length,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        # handle PEFT+FSDP case
        trainer.model.print_trainable_parameters()
        if getattr(trainer.accelerator.state, "fsdp_plugin", None):
            from peft.utils.other import fsdp_auto_wrap_policy

            fsdp_plugin = trainer.accelerator.state.fsdp_plugin
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)
    trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
@@ -155,11 +168,15 @@ def main(model_args, data_args, training_args):


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
29 changes: 16 additions & 13 deletions examples/sft/utils.py
@@ -64,7 +64,9 @@ def preprocess(samples):
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")
            raise ValueError(
                f"Split type {split} not recognized as one of test or train."
            )

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
@@ -75,7 +77,9 @@ def preprocess(samples):

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    print(
        f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
    )
    print(f"A sample of train dataset: {train_data[0]}")

    return train_data, valid_data
@@ -84,8 +88,8 @@ def preprocess(samples):
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel
    device_map = None
    bnb_config = None
    quant_storage_stype = None

    if (
        torch.distributed.is_available()
@@ -97,30 +101,27 @@

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_stype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_stype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print(
                    "Your GPU supports bfloat16, you can accelerate training with the argument --bf16"
                )
                print("=" * 80)
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    if args.use_4bit_quantization or args.use_8bit_quantization:
        device_map = (
            int(os.environ.get("LOCAL_RANK", -1))
            if torch.distributed.is_available() and torch.distributed.is_initialized()
            else "auto"
        ) # {"": 0}

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
@@ -133,9 +134,9 @@
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            device_map=device_map,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            torch_dtype=quant_storage_stype or torch.float32,
        )

    peft_config = None
@@ -174,7 +175,9 @@
        # make embedding resizing configurable?
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth: