-
Notifications
You must be signed in to change notification settings - Fork 0
/
kto.py
186 lines (157 loc) · 9.6 KB
/
kto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Baseline Comparison:
### KTO ###
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusdsz2_m7.yml kto.py --model_name_or_path HuggingFaceH4/mistral-7b-sft-beta --output_dir /data/niklas/m7-1ep-kto-v3 --report_to "wandb" --per_device_train_batch_size 4 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check False --num_train_epochs 1
### DPO ###
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusdsz2_m7.yml dpo.py --model_name_or_path HuggingFaceH4/mistral-7b-sft-beta --output_dir /data/niklas/m7-1ep-dpo-v3 --report_to "wandb" --per_device_train_batch_size 4 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --logging_steps 1 --bf16 --sanity_check False --num_train_epochs 1
### GRIT KTO ###
# M7
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusdsz2_m7.yml kto.py --model_name_or_path GritLM/GritLM-7B --output_dir /data/niklas/GritLM-7B-KTO --report_to "wandb" --per_device_train_batch_size 4 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
# M8x7
# ZeRO2
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusdsz2_m7.yml kto.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-KTO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 4 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
# ZeRO2 - 2 machines
WANDB_PROJECT=gritkto accelerate launch --config_file=config_16gpusdsz2_m7_re.yml kto.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-KTO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
WANDB_PROJECT=gritkto accelerate launch --config_file=config_16gpusdsz2_m7_re.yml --machine_rank=1 kto.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-KTO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
# ZeRO3
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusds_m7.yml kto.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-KTO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
# DPO
# Z2
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusdsz2_m7.yml dpo.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-DPO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
# Z3
WANDB_PROJECT=gritkto accelerate launch --config_file=config_8gpusds_m7.yml dpo.py --model_name_or_path GritLM/GritLM-8x7B --output_dir /data/niklas/GritLM-8x7B-DPO --report_to "wandb" --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --logging_steps 1 --bf16 --sanity_check True --num_train_epochs 1
"""
from dataclasses import dataclass, field
from typing import Optional
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """Script-specific command-line options for KTO training.

    Parsed by ``HfArgumentParser`` together with ``KTOConfig`` and
    ``ModelConfig`` in the ``__main__`` section of this script.
    """

    # Debug switch: when enabled, each loaded dataset split is truncated to at
    # most its first 1000 source rows. NOTE: the default is True, so a full
    # training run must pass ``--sanity_check False`` explicitly.
    sanity_check: Optional[bool] = field(
        metadata={"help": "only train on 1000 samples"},
        default=True,
    )
def extract_anthropic_prompt(prompt_and_response):
    r"""Return the prompt portion of an Anthropic HH transcript.

    The prompt is everything up to and including the *last* occurrence of
    ``\n\nAssistant:``; whatever follows it is the final assistant response.

    Args:
        prompt_and_response: A full HH transcript, e.g. the ``"chosen"``
            field of an ``Anthropic/hh-rlhf`` row.

    Returns:
        The transcript prefix ending with ``\n\nAssistant:``.

    Raises:
        ValueError: If the transcript contains no ``\n\nAssistant:`` marker.
    """
    search_term = "\n\nAssistant:"
    # rfind: multi-turn transcripts contain several Assistant turns; only the
    # last one introduces the response being labelled.
    search_term_idx = prompt_and_response.rfind(search_term)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # otherwise a missing marker would silently slice with index -1.
    if search_term_idx == -1:
        raise ValueError(f"Prompt and response does not contain '{search_term}'")
    return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(
    split: str,
    sanity_check: bool = False,
    silent: bool = False,
    cache_dir: Optional[str] = None,
) -> Dataset:
    r"""Load Anthropic Helpful-Harmless and flatten it into unpaired KTO examples.

    Each source row yields two examples that share one prompt: the ``chosen``
    response labelled ``True`` and the ``rejected`` response labelled
    ``False``. The returned dataset has the columns::

        {
            'prompt': List[str],
            'completion': List[str],
            'label': List[bool],
        }

    The raw HH prompt (``\n\nHuman: <prompt>\n\nAssistant:``, possibly
    multi-turn) is extracted from the ``chosen`` transcript and then wrapped
    in the ``<|user|>\n{prompt}\n<|assistant|>\n`` template before storage.

    Args:
        split: Dataset split to load (e.g. ``"train"`` or ``"test"``).
        sanity_check: If True, keep only the first 1000 source rows.
        silent: Unused; kept for signature compatibility.
        cache_dir: Optional Hugging Face datasets cache directory.

    Returns:
        A ``Dataset`` with ``prompt`` / ``completion`` / ``label`` columns.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    flat_data = {"prompt": [], "completion": [], "label": []}
    for sample in dataset:
        prompt = extract_anthropic_prompt(sample["chosen"])
        # NOTE(review): the raw prompt already carries "\n\nHuman:" /
        # "\n\nAssistant:" markers, so wrapping it again yields both marker
        # styles. Kept as-is to preserve the training data; confirm intended.
        formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
        # The rejected transcript is assumed to share the chosen transcript's
        # prompt prefix, so the same offset strips the prompt from both.
        for response_key, label in (("chosen", True), ("rejected", False)):
            flat_data["prompt"].append(formatted_prompt)
            flat_data["completion"].append(sample[response_key][len(prompt) :])
            flat_data["label"].append(label)
    # from_dict is a classmethod; call it on the class rather than an instance.
    return Dataset.from_dict(flat_data)
def get_ultrabin(
    split: str,
    sanity_check: bool = False,
    silent: bool = False,
    cache_dir: Optional[str] = None,
) -> Dataset:
    r"""Load UltraFeedback-binarized and flatten it into unpaired KTO examples.

    For each source row, the reply stored at index 1 of ``chosen`` becomes an
    example labelled ``True`` and the one at index 1 of ``rejected`` becomes
    an example labelled ``False``. Replies that are empty or whitespace-only
    are skipped, so a row yields zero, one, or two examples. Prompts are
    wrapped in the ``<|user|>\n{prompt}\n<|assistant|>\n`` template.

    Args:
        split: Dataset split to load (e.g. ``"train_prefs"`` or ``"test_prefs"``).
        sanity_check: If True, keep only the first 1000 source rows (the
            filter above is applied afterwards).
        silent: Unused; kept for signature compatibility.
        cache_dir: Optional Hugging Face datasets cache directory.

    Returns:
        A ``Dataset`` with ``prompt`` / ``completion`` / ``label`` columns.
    """
    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    flat_data = {"prompt": [], "completion": [], "label": []}
    for sample in dataset:
        formatted_prompt = f"<|user|>\n{sample['prompt']}\n<|assistant|>\n"
        for response_key, label in (("chosen", True), ("rejected", False)):
            # Assumes each conversation is [user turn, assistant turn], so
            # index 1 is the assistant reply — TODO confirm for the dataset.
            completion = sample[response_key][1]["content"]
            if not completion.strip():
                # Empty replies carry no signal for KTO; skip them.
                continue
            flat_data["prompt"].append(formatted_prompt)
            flat_data["completion"].append(completion)
            flat_data["label"].append(label)
    # from_dict is a classmethod; call it on the class rather than an instance.
    return Dataset.from_dict(flat_data)
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

    # 1. Load the policy model plus a second, identically loaded copy that is
    #    handed to KTOTrainer as the reference model.
    model_kwargs = dict(trust_remote_code=True, attn_implementation="sdpa", torch_dtype="auto")
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the UltraFeedback-binarized training split.
    #    Anthropic HH alternative: get_hh("train", sanity_check=script_args.sanity_check)
    train_dataset = get_ultrabin("train_prefs", sanity_check=script_args.sanity_check)
    # 3. Load the evaluation split.
    #    Anthropic HH alternative: get_hh("test", sanity_check=script_args.sanity_check)
    eval_dataset = get_ultrabin("test_prefs", sanity_check=script_args.sanity_check)

    # 4. Initialize the KTO trainer.
    trainer = KTOTrainer(
        model,
        model_ref,
        args=kto_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Disabled workaround, kept for reference (manually prepares the models with
    # accelerate before training):
    # https://github.com/huggingface/trl/issues/1147#issuecomment-1896206757
    #
    # prepared_model = trainer._wrap_model(
    #     trainer.model, training=True, dataloader=None
    # )
    # if hasattr(trainer.lr_scheduler, "step"):
    #     prepared_model, trainer.optimizer = trainer.accelerator.prepare(
    #         prepared_model, trainer.optimizer
    #     )
    # else:
    #     (
    #         prepared_model,
    #         trainer.optimizer,
    #         trainer.lr_scheduler,
    #     ) = trainer.accelerator.prepare(
    #         prepared_model, trainer.optimizer, trainer.lr_scheduler
    #     )
    # trainer.model_wrapped = prepared_model
    # if trainer.is_fsdp_enabled:
    #     trainer.model = prepared_model
    # if trainer.ref_model is not None:
    #     trainer.ref_model = trainer.accelerator.prepare_model(trainer.ref_model)
    # trainer.accelerator.prepare_model = lambda model, *args, **kwargs: model  # Monkey-patch prepare_model a no-op , since we have manually prepared the models

    # 5. Train.
    trainer.train()