-
Notifications
You must be signed in to change notification settings - Fork 73
/
mistral_7B_rm.py
284 lines (242 loc) · 9.89 KB
/
mistral_7B_rm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
########################
# This script is modified from the TRL package https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py
# This script is designed for reward modeling with the Mistral model, which must be handled carefully because it does not have an official pad token.
# If you have any questions, feel free to send me an email via [email protected]
########################
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
# import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
# from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from transformers.utils import PaddingStrategy
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    Options controlling the reward-model training run.

    Good values depend on how many GPUs are available, what their capacity and
    features are, and how large the model to train is.
    """

    # ---- distributed execution ----
    local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
    deepspeed: Optional[str] = field(
        default=None,
        metadata={"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."},
    )

    # ---- batching and optimisation hyper-parameters ----
    per_device_train_batch_size: Optional[int] = 1
    per_device_eval_batch_size: Optional[int] = 1
    gradient_accumulation_steps: Optional[int] = 64
    learning_rate: Optional[float] = 5e-6
    weight_decay: Optional[float] = 0.001

    # ---- model and numeric precision ----
    model_name: Optional[str] = field(
        default="mistralai/Mistral-7B-Instruct-v0.2",
        metadata={"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."},
    )
    bf16: Optional[bool] = field(
        default=True,
        metadata={"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."},
    )
    num_train_epochs: Optional[int] = field(
        default=1, metadata={"help": "The number of training epochs for the reward model."}
    )

    # ---- data locations and output directory ----
    train_set_path: Optional[str] = field(
        default="hendrydong/preference_700K",
        metadata={"help": "The dir of the subset of the training data to use"},
    )
    eval_set_path: Optional[str] = field(
        default="hendrydong/preference_700K",
        metadata={"help": "The dir of the subset of the eval data to use"},
    )
    output_path: Optional[str] = field(
        default="./bt_models/mistral_rm", metadata={"help": "The dir for output model"}
    )

    # ---- memory / optimizer / schedule ----
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "Enables gradient checkpointing."}
    )
    # Other optimizers that were tried here: "adamw_hf", "adamw_torch_fused".
    optim: Optional[str] = field(
        default="paged_adamw_32bit", metadata={"help": "The optimizer to use."}
    )
    lr_scheduler_type: Optional[str] = field(
        default="cosine", metadata={"help": "The lr scheduler"}
    )

    # ---- sequence length and checkpoint / evaluation cadence ----
    max_length: Optional[int] = 4096
    save_every_steps: Optional[int] = field(
        default=999999, metadata={"help": "Save the model every x steps"}
    )
    eval_every_steps: Optional[int] = field(
        default=999999, metadata={"help": "Eval the model every x steps"}
    )
# Parse the command-line flags into a single ScriptArguments instance.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
# Load the tokenizer that belongs to the base model (the model itself is loaded further below).
tokenizer_name = script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast = False)
# Adjusted according to the base model.
# Models that don't have an official pad token need one before batches can be padded.
# Disabled alternative: reuse the EOS token as the pad token.
#tokenizer.pad_token = tokenizer.eos_token
#tokenizer.pad_token_id = tokenizer.eos_token_id
# Register a dedicated '[PAD]' token instead; this grows the vocabulary, which is why
# the model's embeddings are resized after the model is loaded.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Log which side the tokenizer pads on (left at its default; see the disabled override below).
print(tokenizer.padding_side)
# Over-long inputs lose tokens from the left, i.e. the end of the text is kept.
tokenizer.truncation_side = "left"
# Length limit applied when tokenizing with truncation=True.
tokenizer.model_max_length = script_args.max_length
# tokenizer.padding_side = "right"
###
# Dataset locations and the output directory, taken from the parsed arguments.
train_path = script_args.train_set_path
eval_path = script_args.eval_set_path
output_name = script_args.output_path
def build_dataset(tokenizer, train_path, eval_path, eval_size=500, num_proc=8):
    """
    Load the preference dataset and tokenize every chosen/rejected pair.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``bos_token``
            and ``__call__``.
        train_path: Dataset name or path handed to ``load_dataset``; its
            ``train`` split is used.
        eval_path: Currently NOT used. It is kept so existing callers keep
            working; the eval set is carved out of the training data instead.
        eval_size: Upper bound on the number of eval examples (default 500).
        num_proc: Number of worker processes used by ``Dataset.map``.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple. Each row carries the added
        columns ``positive``, ``negative``, ``input_ids_j``,
        ``attention_mask_j``, ``input_ids_k`` and ``attention_mask_k``.
    """

    def tokenize(sample):
        # Render both conversations as plain text. The BOS string is stripped
        # from the rendered text -- presumably because the tokenizer call below
        # adds its own BOS token (TODO confirm for the tokenizer in use).
        sample['positive'] = tokenizer.apply_chat_template(
            sample['chosen'], tokenize=False, add_generation_prompt=False
        ).replace(tokenizer.bos_token, "")
        sample['negative'] = tokenizer.apply_chat_template(
            sample['rejected'], tokenize=False, add_generation_prompt=False
        ).replace(tokenizer.bos_token, "")

        tokenized_pos = tokenizer(sample['positive'], truncation=True)
        tokenized_neg = tokenizer(sample['negative'], truncation=True)

        # "_j" columns hold the chosen response, "_k" columns the rejected one.
        sample["input_ids_j"] = tokenized_pos["input_ids"]
        sample["attention_mask_j"] = tokenized_pos["attention_mask"]
        sample["input_ids_k"] = tokenized_neg["input_ids"]
        sample["attention_mask_k"] = tokenized_neg["attention_mask"]
        return sample

    ds = load_dataset(train_path, split="train").shuffle(seed=42)
    #ds = ds.select(range(2000))
    train_dataset = ds.map(tokenize, num_proc=num_proc)

    # NOTE: the eval set is the first rows of the (shuffled) training set, so
    # it overlaps with the training data. Loading `eval_path` directly, as
    # below, would yield untokenized rows.
    #eval_dataset = load_dataset(eval_path, split="train").shuffle(seed=42).select(range(500))
    # Clamp the subset size so that small datasets do not make `select` fail
    # with an out-of-range index.
    n_eval = min(eval_size, len(train_dataset))
    eval_dataset = train_dataset.select(range(n_eval))
    return train_dataset, eval_dataset
# Build the tokenized train/eval datasets and report their sizes.
train_dataset, eval_dataset = build_dataset(tokenizer, train_path, eval_path)
print("Training set: ", len(train_dataset), " Eval set: ", len(eval_dataset))
# Configure the training run from the parsed script arguments.
training_args = TrainingArguments(
output_dir=output_name,
learning_rate=script_args.learning_rate,
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
num_train_epochs=script_args.num_train_epochs,
weight_decay=script_args.weight_decay,
evaluation_strategy="steps",
eval_steps=script_args.eval_every_steps,
save_strategy="steps",
save_steps=script_args.save_every_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
deepspeed=script_args.deepspeed,
local_rank=script_args.local_rank,
# Keep the pairwise columns (input_ids_j / input_ids_k, ...) that the custom data collator reads.
remove_unused_columns=False,
# The datasets carry no label column; the loss is computed from the paired rewards instead.
label_names=[],
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
warmup_ratio=0.03,
report_to='wandb'
)
# Load the base model with a single-output classification head; that one output is used as the reward.
model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True,
)
# The generation cache is switched off whenever gradient checkpointing is enabled.
model.config.use_cache = not script_args.gradient_checkpointing
# Tell the model which token id is padding (the '[PAD]' token added to the tokenizer above).
model.config.pad_token_id = tokenizer.pad_token_id
# The tokenizer gained a token, so the embedding matrix must grow to match its size.
model.resize_token_embeddings(len(tokenizer))
# NOTE(review): num_proc and original_columns are not referenced anywhere else in the visible source.
num_proc = 24 # Can adjust to be higher if you have more processors.
original_columns = train_dataset.column_names
# Custom collator: every (chosen "j", rejected "k") pair becomes two adjacent rows of one padded batch.
@dataclass
class RewardDataCollatorWithPadding:
    """
    Pad pairwise preference examples into a single interleaved batch.

    For each input feature the "j" sequence is emitted first and the "k"
    sequence directly after it, so row ``2*i`` and row ``2*i + 1`` of the
    output belong to the same pair.
    """

    # Object whose ``pad`` method performs the actual padding.
    tokenizer: AutoTokenizer
    # The remaining options are forwarded unchanged to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate ``features`` into one padded batch.

        Args:
            features: Examples with the keys ``input_ids_j``,
                ``attention_mask_j``, ``input_ids_k`` and ``attention_mask_k``.

        Returns:
            A dict holding the padded ``input_ids`` and ``attention_mask``
            plus the constant flag ``return_loss=True``.
        """
        # Flatten the pairs: chosen row first, rejected row second.
        interleaved = [
            row
            for pair in features
            for row in (
                {
                    "input_ids": pair["input_ids_j"],
                    "attention_mask": pair["attention_mask_j"],
                },
                {
                    "input_ids": pair["input_ids_k"],
                    "attention_mask": pair["attention_mask_k"],
                },
            )
        ]
        padded = self.tokenizer.pad(
            interleaved,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        # Only these two tensors are passed on, together with the loss flag.
        return {
            "input_ids": padded["input_ids"],
            "attention_mask": padded["attention_mask"],
            "return_loss": True,
        }
# Evaluation metric: pairwise ranking accuracy.
def compute_metrics(eval_pred):
    """
    Compute how often the first score of a pair exceeds the second.

    Args:
        eval_pred: Object whose ``predictions`` attribute holds the scores of
            the first (preferred) items at index 0 and the scores of the
            second items at index 1.

    Returns:
        A dict with the single key ``accuracy``.
    """
    preferred = eval_pred.predictions[0]
    other = eval_pred.predictions[1]
    # The first sample of each pair is taken as the ground-truth preferred
    # one; the comparison is strict, so a tie counts as a miss.
    n_correct = np.sum(preferred > other)
    n_total = len(preferred)
    return {'accuracy': n_correct / n_total}
class RewardTrainer(Trainer):
    """
    Trainer that optimises a pairwise ranking loss over reward scores.

    Batches are expected to be interleaved: even rows hold the "j" (chosen)
    sequences and the following odd rows hold the matching "k" (rejected)
    sequences.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute ``-logsigmoid(reward_j - reward_k)`` averaged over all pairs.

        Args:
            model: Model called with ``input_ids`` and ``attention_mask``; the
                first element of its output is used as the reward scores.
            inputs: Batch dict providing ``input_ids`` and ``attention_mask``.
            return_outputs: When true, also return the per-pair rewards.
            num_items_in_batch: Accepted only so that Trainer versions which
                pass this argument can call the override; it is not used.

        Returns:
            The scalar loss, or ``(loss, {"rewards_j": ..., "rewards_k": ...})``
            when ``return_outputs`` is true.

        Raises:
            IndexError: If the batch has an odd number of rows, so the last
                "j" row has no "k" partner.
        """
        rewards = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )[0]
        bsz = rewards.size(0)
        # Build the row indices on the same device as the rewards.
        jidx = torch.arange(0, bsz, 2, device=rewards.device)
        kidx = jidx + 1
        rewards_j = rewards[jidx]
        rewards_k = rewards[kidx]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
# Train the model, woohoo.
# Wire the model, training arguments, datasets, metric and pairwise collator into the custom trainer.
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
data_collator=RewardDataCollatorWithPadding(
tokenizer=tokenizer, max_length=script_args.max_length),
)
trainer.train()
print("Saving last checkpoint of the model")
# Save through the trainer rather than calling model.save_pretrained directly (disabled below).
#model.save_pretrained(output_name + "/last_checkpoint")
trainer.save_model(output_name + "/last_checkpoint")
# Store the tokenizer next to the weights so the added '[PAD]' token is saved with the checkpoint.
tokenizer.save_pretrained(output_name + "/last_checkpoint")