run_fsdp_qlora.py
import os
import random
from dataclasses import dataclass, field

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer
from trl.commands.cli_utils import TrlParser

# Comment in if you want to use the Llama 3 instruct template, but make sure to add modules_to_save
# LLAMA_3_CHAT_TEMPLATE="{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

# Anthropic/Vicuna-like template that works without special tokens
LLAMA_3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ message['content'] }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '\n\nHuman: ' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '\n\nAssistant: ' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '\n\nAssistant: ' }}"
    "{% endif %}"
)
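
# For orientation, a system/user/assistant conversation rendered with the template
# above looks roughly like this (eos_token is whatever the tokenizer defines, e.g.
# "<|end_of_text|>" for the Llama 3 base model; the message contents are made up):
#
#   You are a helpful assistant.
#
#   Human: What is FSDP?<|end_of_text|>
#
#   Assistant: Fully Sharded Data Parallel shards model parameters across GPUs ...<|end_of_text|>
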
# ACCELERATE_USE_FSDP=1 FSDP_CPU_RAM_EFFICIENT_LOADING=1 torchrun --nproc_per_node=4 ./scripts/run_fsdp_qlora.py --config llama_3_70b_fsdp_qlora.yaml
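# ACCELERATE_USE_FSDP=1 routes the HF Trainer through Accelerate's FSDP plugin, and
# FSDP_CPU_RAM_EFFICIENT_LOADING=1 loads the checkpoint on rank 0 only (the other
# ranks start from meta tensors), so host RAM usage does not scale with the number of processes.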

@dataclass
class ScriptArguments:
    dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset"},
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )
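
# The yaml passed via --config is parsed by TrlParser into ScriptArguments plus the
# regular transformers TrainingArguments. A minimal illustrative sketch (the values
# below are assumptions for a 4-GPU run, not the repository's actual
# llama_3_70b_fsdp_qlora.yaml):
#
#   model_id: "meta-llama/Meta-Llama-3-70B"
#   dataset_path: "."
#   max_seq_length: 3072
#   output_dir: "./llama-3-70b-qlora"
#   num_train_epochs: 1
#   per_device_train_batch_size: 1
#   gradient_checkpointing: true
#   bf16: true
#   learning_rate: 0.0002
#   fsdp: "full_shard auto_wrap offload"
#   fsdp_config:
#     backward_prefetch: "backward_pre"
#     cpu_ram_efficient_loading: true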


def training_function(script_args, training_args):
    ################
    # Dataset
    ################
    train_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "train_dataset.json"),
        split="train",
    )
    test_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "test_dataset.json"),
        split="train",
    )
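    # Both json files are expected to hold a "messages" column in role/content chat
    # format, e.g. [{"role": "user", "content": "..."}, ...], since template_dataset
    # below renders them with tokenizer.apply_chat_template.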

    ################
    # Model & Tokenizer
    ################

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE

    # template dataset
    def template_dataset(examples):
        return {"text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)}

    train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
    test_dataset = test_dataset.map(template_dataset, remove_columns=["messages"])

    # print random sample
    with training_args.main_process_first(
        desc="Log a few random samples from the processed training set"
    ):
        for index in random.sample(range(len(train_dataset)), 2):
            print(train_dataset[index]["text"])

    # Model
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )
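    # bnb_4bit_quant_storage matches the bf16 dtype used everywhere else so that all
    # parameters FSDP sees share one dtype and the 4-bit weights can be flat-sharded
    # like ordinary tensors; this is the detail that makes FSDP + QLoRA possible.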

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        quantization_config=quantization_config,
        attn_implementation="sdpa",  # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,  # this is needed for gradient checkpointing
    )

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ################
    # PEFT
    ################

    # LoRA config based on the QLoRA paper & Sebastian Raschka's experiments
    peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=16,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        # modules_to_save=["lm_head", "embed_tokens"],  # add if you want to use the Llama 3 instruct template
    )
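    # target_modules="all-linear" tells PEFT to attach LoRA adapters to every linear
    # layer in the transformer blocks (the output head is excluded automatically), so
    # no hand-written list of module names is needed.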

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
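    # packing=True concatenates several templated samples into fixed-size blocks of
    # max_seq_length instead of padding each one; the eos_token the chat template
    # appends to every turn already separates samples, so append_concat_token stays off.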

    if trainer.accelerator.is_main_process:
        trainer.model.print_trainable_parameters()

    ##########################
    # Train model
    ##########################
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    ##########################
    # SAVE MODEL FOR SAGEMAKER
    ##########################
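    # Switching the FSDP plugin to FULL_STATE_DICT gathers the sharded weights onto
    # rank 0, so the checkpoint written by save_model() below can be loaded later
    # without an FSDP process group.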
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_and_config()

    # set use_reentrant to False: the non-reentrant checkpointing implementation is
    # the one PyTorch recommends and the one that plays well with FSDP-wrapped PEFT models
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

    # set seed
    set_seed(training_args.seed)

    # launch training
    training_function(script_args, training_args)