gen_hf2.py
#!/usr/bin/env python
import json
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
)
from vllm import LLM, SamplingParams

@dataclass
class ScriptArguments:
    """
    The arguments for the generation script.
    """

    model_name_or_path: Optional[str] = field(
        default="your model",
        metadata={"help": "the location of the SFT model name or path"},
    )
    dataset_name_or_path: Optional[str] = field(
        default="RLHFlow/test_generation_2k",
        metadata={"help": "the location of the dataset name or path"},
    )
    local_index: Optional[int] = field(
        default=999,
        metadata={"help": "the local index of the agent"},
    )
    output_dir: Optional[str] = field(
        default="",
        metadata={"help": "the location of the output file"},
    )
    my_world_size: Optional[int] = field(
        default=4,
        metadata={"help": "the total number of the agents"},
    )
    K: Optional[int] = field(
        default=8,
        metadata={"help": "the number of generations per prompt"},
    )
    max_input_length: Optional[int] = field(
        default=10000,
        metadata={"help": "the maximum length of the input tokens"},
    )
    max_new_tokens: Optional[int] = field(
        default=2048,
        metadata={"help": "the maximum number of new tokens to generate"},
    )
    seed: Optional[int] = field(
        default=42,
        metadata={"help": "the random seed"},
    )
    temperature: Optional[float] = field(
        default=0.7,
        metadata={"help": "the sampling temperature"},
    )
    use_beam_search: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to use beam search"},
    )
    dataset_key: Optional[str] = field(
        default="context_messages",
        metadata={"help": "the dataset column that holds the chat messages"},
    )
    eos_ids: List[int] = field(
        default_factory=lambda: [],
        metadata={"help": "the ids of the end of sentence tokens"},
    )
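
# Note: local_index is expected to lie in [0, my_world_size); the default of
# 999 is a placeholder and will select an out-of-range shard below unless it
# is overridden on the command line.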

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

model_path = script_args.model_name_or_path
print("model_path", model_path)

# Set the random seeds for reproducibility.
seed = script_args.seed
torch.manual_seed(seed)
np.random.seed(seed)

llm = LLM(
    model=model_path,
    tokenizer=model_path,
    dtype="bfloat16",
    max_model_len=script_args.max_input_length,
    load_format="auto",
    seed=seed,  # use the configured seed instead of a hardcoded 42
)
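
# Note: vLLM's max_model_len bounds the total sequence length (prompt plus
# generated tokens), so max_input_length here effectively caps the full
# context window rather than the prompt alone.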

tokenizer = AutoTokenizer.from_pretrained(model_path)

sampling_params = SamplingParams(
    temperature=script_args.temperature,
    top_p=1.0,
    max_tokens=script_args.max_new_tokens,
    n=script_args.K,
    stop_token_ids=[tokenizer.eos_token_id] + script_args.eos_ids,
    # stop=["<|user|>"],
)

ds = load_dataset(script_args.dataset_name_or_path, split="train")
# Render each list of chat messages into a single prompt string with the
# model's chat template, ending with the generation prompt.
ds = ds.map(
    lambda x: {
        "prompt": tokenizer.apply_chat_template(
            x[script_args.dataset_key], tokenize=False, add_generation_prompt=True
        )
    }
)

# Split the dataset evenly across the agents; this process keeps only the
# shard that belongs to local_index.
data_size = len(ds["prompt"])
one_num_share = int(data_size / script_args.my_world_size)
ds = ds.select(
    np.arange(
        script_args.local_index * one_num_share,
        (script_args.local_index + 1) * one_num_share,
    )
)
print([script_args.local_index * one_num_share, (script_args.local_index + 1) * one_num_share])
print(ds, script_args.dataset_name_or_path)
print(ds[0])
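
# For example, with 2,000 prompts and my_world_size=4, each share is 500 rows
# and agent 2 keeps rows [1000, 1500). If the dataset size is not divisible
# by my_world_size, the remainder rows at the end are dropped by this scheme.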

prompts = ds["prompt"]
outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)

# Pair each original chat context with the K sampled responses.
gathered_data = []
for i, output in enumerate(outputs):
    tmp_data = {
        "prompt": ds[i][script_args.dataset_key],
        "responses": [out.text for out in output.outputs],
    }
    gathered_data.append(tmp_data)

print("Collected", len(gathered_data), "samples")

# Write the shard as JSON Lines: one record per line, one file per agent.
with open(script_args.output_dir + str(script_args.local_index) + ".json", "w", encoding="utf8") as f:
    for sample in gathered_data:
        json.dump(sample, f, ensure_ascii=False)
        f.write("\n")
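
# A possible follow-up step (not part of this script): once every agent has
# written its shard, the per-agent files can be merged back into one JSONL
# file. A minimal sketch, assuming the naming convention used above
# (output_dir + str(index) + ".json") and a known number of agents:
#
#   merged = []
#   for index in range(script_args.my_world_size):
#       with open(script_args.output_dir + str(index) + ".json", encoding="utf8") as shard:
#           merged.extend(json.loads(line) for line in shard)
#   with open(script_args.output_dir + "merged.json", "w", encoding="utf8") as out:
#       for sample in merged:
#           json.dump(sample, out, ensure_ascii=False)
#           out.write("\n")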