-
Notifications
You must be signed in to change notification settings - Fork 56
/
prompt.py
154 lines (143 loc) · 7.64 KB
/
prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json
from config import args
# Rendering template for every supported demo file, keyed by the exact
# ``args.demo_path`` value.  Placeholders are keys of each demo entry, plus
# ``{trigger}`` for ``args.direct_answer_trigger_for_fewshot``.
_DEMO_TEMPLATES = {
    # SVAMP / GSM8K style demos: the configurable trigger introduces the answer.
    **dict.fromkeys(
        (
            'demos/svamp.json',
            'demos/svamp_6.json',
            'demos/gsm8k_6.json',
            'demos/svamp_8_6.json',
            'demos/svamp_4.json',
            'demos/svamp_2.json',
            'demos/auto_svamp_prompt/svamp_2prompt.json',
            'demos/auto_svamp_prompt/gsm8k_2prompt.json',
        ),
        "{question} {rationale} {trigger} {pred_ans}.\n\n",
    ),
    # Multiple-choice demos: choices follow the question, rationale starts a new line.
    # NOTE: no separator between rationale and trigger -- kept exactly as the
    # original aqua format.
    'demos/aqua.json': "{question} {answer_choice}\n{rationale}{trigger} {pred_ans}.\n\n",
    'demos/commonsenseqa.json': "{question} {answer_choice}\n{rationale} So the answer is {pred_ans}.\n\n",
    # Demos that use a fixed answer phrase instead of the configurable trigger.
    'demos/strategyqa.json': "{question} {rationale} So the answer is {pred_ans}.\n\n",
    'demos/coin_flip.json': "{question} {rationale} So the answer is {pred_ans}.\n\n",
    'demos/last_letters.json': "{question} {rationale} The answer is {pred_ans}.\n\n",
}


def create_demo_text():
    """Build the few-shot demonstration text for ``args.demo_path``.

    Loads the JSON file at ``args.demo_path`` and takes the list stored under
    its ``"demo"`` key. Each entry is rendered with the template registered for
    that path in ``_DEMO_TEMPLATES``, and the results are concatenated. Every
    rendered demo ends with a period followed by a blank line.

    Entries must provide ``question``, ``rationale`` and ``pred_ans``. The
    multiple-choice files (aqua, commonsenseqa) additionally need
    ``answer_choice``.

    Returns:
        str: the concatenated demonstrations.

    Raises:
        ValueError: if ``args.demo_path`` is not a supported demo file. This
            previously surfaced as an ``UnboundLocalError``.
        OSError / json.JSONDecodeError / KeyError: when the demo file is
            missing, malformed, or lacks a required key.
    """
    demo_path = args.demo_path
    template = _DEMO_TEMPLATES.get(demo_path)
    # Reject unknown paths before touching the filesystem.
    if template is None:
        raise ValueError('unsupported demo_path: {}'.format(demo_path))

    with open(demo_path, encoding="utf-8") as f:
        demos = json.load(f)["demo"]

    # Only read the trigger when the template uses it, so datasets with a
    # fixed answer phrase do not need that config attribute.
    extra = {}
    if "{trigger}" in template:
        extra["trigger"] = args.direct_answer_trigger_for_fewshot

    return "".join(template.format_map({**demo, **extra}) for demo in demos)
# Directory (with trailing slash) that get_prompt() prefixes to few-shot demo
# file names.
Few_Shot_Demo_Folder = 'few_shot_demos/'

# Prompt sentences returned by get_prompt(), which resolves them by name as
# ``prompt_<prompt_id>``; the numeric suffix is therefore the public prompt id.
# The wording is model input: keep every string byte-identical, quirks included,
# so results stay reproducible.

# Plain step-by-step trigger.
prompt_101 = "Let's think step by step."

# Plan-then-solve.
prompt_201 = (
    "Let's first understand the problem and devise a plan to solve the problem. "
    "Then, let's carry out the plan to solve the problem step by step."
)

# Plan-then-solve with variable extraction and calculation hints.
prompt_301 = (
    "Let's first understand the problem, extract relevant variables and their corresponding numerals, "
    "and devise a plan. Then, let's carry out the plan, calculate intermediate variables (pay attention to "
    "correct numeral calculation and commonsense), solve the problem step by step, and show the answer."
)

prompt_302 = (
    "Let's first understand the problem, extract relevant variables and their corresponding numerals, "
    "and devise a complete plan. Then, let's carry out the plan, calculate intermediate variables "
    "(pay attention to correct numerical calculation and commonsense), "
    "solve the problem step by step, and show the answer."
)

# Short plan-and-solve trigger.
prompt_303 = "Let's devise a plan and solve the problem step by step."

# Coin-flip oriented variant.  NOTE: the final clause uses a typographic
# apostrophe (U+2019) in "coin’s turning state"; kept verbatim.
prompt_304 = (
    "Let's first understand the problem and devise a complete plan. "
    "Then, let's carry out the plan and reason problem step by step. Every step answer the subquestion, "
    "\"does the person flip and what is the coin's current state?\". According to the coin's last state, "
    "give the final answer (pay attention to every flip and the coin’s turning state)."
)

prompt_305 = (
    "Let's first understand the problem, extract relevant variables and their corresponding numerals, "
    "and make a complete plan. Then, let's carry out the plan, calculate intermediate variables (pay "
    "attention to correct numerical calculation and commonsense), "
    "solve the problem step by step, and show the answer."
)

# Variant emphasising commonsense and logical coherence.
prompt_306 = (
    "Let's first prepare relevant information and make a plan. Then, let's answer the question step by step "
    "(pay attention to commonsense and logical coherence)."
)

# NOTE: "make and devise" is the original wording; kept verbatim.
prompt_307 = (
    "Let's first understand the problem, extract relevant variables and their corresponding numerals, "
    "and make and devise a complete plan. Then, let's carry out the plan, calculate intermediate variables "
    "(pay attention to correct numerical calculation and commonsense), "
    "solve the problem step by step, and show the answer."
)
def _lookup_prompt(prompt_id):
    """Return the module-level ``prompt_<prompt_id>`` string.

    Raises:
        NameError: if no prompt constant with that id is defined in this module.
    """
    # Explicit lookup instead of eval() on a string built from a config value.
    try:
        return globals()['prompt_{}'.format(prompt_id)]
    except KeyError:
        raise NameError('can\'t find prompt_id: {}'.format(prompt_id)) from None


def get_prompt():
    """Return ``(demos, prompt)`` for the configured learning type.

    * ``zero_shot``: ``demos`` is ``None`` and ``prompt`` is ``prompt_<args.prompt_id>``.
    * ``few_shot``: ``demos`` is the ``"demo"`` list of
      ``Few_Shot_Demo_Folder + '<args.domain>_prompt_<args.prompt_id>.json'``
      joined with newlines. When ``args.dataset`` is ``aqua`` (case-insensitive),
      the ``..._choices.json`` file is read instead.

    Raises:
        NameError: if ``prompt_<args.prompt_id>`` is not defined.
        FileNotFoundError: if the few-shot demo file does not exist.
        ValueError: for an unsupported ``args.learning_type``.
    """
    if args.learning_type == 'zero_shot':
        return None, _lookup_prompt(args.prompt_id)

    if args.learning_type == 'few_shot':
        # aqua demos include the answer choices, stored in a separate file.
        suffix = '_choices' if args.dataset.lower() in ['aqua'] else ''
        demo_file = Few_Shot_Demo_Folder + f'{args.domain}_prompt_{args.prompt_id}{suffix}.json'
        try:
            with open(demo_file, 'r', encoding='utf-8') as f:
                demos_list = json.load(f)['demo']
        except FileNotFoundError as err:
            raise FileNotFoundError('can\'t find the demo file: {}'.format(demo_file)) from err
        return '\n'.join(demos_list), _lookup_prompt(args.prompt_id)

    raise ValueError('not support learning_type: {}'.format(args.learning_type))
def construct_input(prompt, text):
    """Format *text* and *prompt* as a single ``Q:`` / ``A:`` exchange.

    The question follows ``"Q:"`` directly, with no space inserted. The prompt
    follows ``"A: "`` on the next line.
    """
    pieces = ("Q:", text, "\nA: ", prompt)
    return "".join(pieces)