-
Notifications
You must be signed in to change notification settings - Fork 25
/
cli.py
151 lines (120 loc) · 6.54 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import multiprocessing
import random
from swiftsage.agents import SwiftSage
from swiftsage.utils.commons import api_configs, setup_logging
from pkg_resources import resource_filename
# Module-level logger, created once at import time by the project's
# setup_logging() helper and used by run_test() below.
logger = setup_logging()
def run_test(swiftsage, problem, max_iterations=5, reward_threshold=8):
    """Solve a single problem and log the final reasoning and solution.

    Args:
        swiftsage: Object exposing ``solve(problem, max_iterations,
            reward_threshold)`` whose result unpacks into three items.
        problem: Problem statement; logged and forwarded to ``solve``.
        max_iterations: Forwarded to ``solve`` unchanged.
        reward_threshold: Forwarded to ``solve`` unchanged.

    Returns:
        None. The outcome is only written to the module-level ``logger``.
    """
    separator = "=" * 50

    logger.info(f"Testing problem: {problem}")

    # solve() yields three items; the third (the messages) is not used here.
    outcome = swiftsage.solve(problem, max_iterations, reward_threshold)
    final_reasoning, final_solution, _messages = outcome

    logger.info(f"Final reasoning:\n{final_reasoning}")
    logger.info(f"Final solution:\n{final_solution}")
    logger.info(separator)
def parse_args(argv=None):
    """Parse the command-line options of this CLI.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is exactly
            what the previous zero-argument version did; passing a list
            allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace holding ``problem``, ``api_provider``, the three
        ``*_model_id`` values, ``prompt_template_dir``, the ``use_retrieval``
        and ``start_with_sage`` flags, ``max_iterations``,
        ``reward_threshold`` and a ``*_temperature`` / ``*_top_p`` pair for
        each of the swift, feedback and sage models.
    """
    parser = argparse.ArgumentParser()

    # Problem to solve; main() falls back to a built-in sample when omitted.
    parser.add_argument("-p", "--problem", type=str)

    # Provider and model selection.
    parser.add_argument(
        "--api_provider",
        default="Together",
        choices=["Together", "SambaNova"],
        type=str,
    )
    parser.add_argument(
        "--swift_model_id",
        default="meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
        type=str,
    )
    parser.add_argument(
        "--feedback_model_id",
        default="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        type=str,
    )
    parser.add_argument(
        "--sage_model_id",
        default="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        type=str,
    )

    # Default to the prompt templates shipped inside the swiftsage package.
    default_template_dir = resource_filename('swiftsage', 'prompt_templates')
    parser.add_argument("--prompt_template_dir", default=default_template_dir, type=str)

    parser.add_argument("--use_retrieval", action="store_true")
    parser.add_argument("--start_with_sage", action="store_true")
    parser.add_argument("--max_iterations", default=5, type=int)
    parser.add_argument("--reward_threshold", default=8, type=int)

    # Sampling parameters, one temperature/top-p pair per model.
    parser.add_argument(
        "--swift_temperature", default=0.5, type=float,
        help="Temperature for the Swift model",
    )
    parser.add_argument(
        "--swift_top_p", default=0.9, type=float,
        help="Top-p sampling for the Swift model",
    )
    parser.add_argument(
        "--feedback_temperature", default=0.5, type=float,
        help="Temperature for the Feedback model",
    )
    parser.add_argument(
        "--feedback_top_p", default=0.9, type=float,
        help="Top-p sampling for the Feedback model",
    )
    parser.add_argument(
        "--sage_temperature", default=0.5, type=float,
        help="Temperature for the Sage model",
    )
    parser.add_argument(
        "--sage_top_p", default=0.9, type=float,
        help="Top-p sampling for the Sage model",
    )

    # argv=None keeps the historical behaviour of reading sys.argv[1:].
    args = parser.parse_args(argv)

    # Disabled: automatic rewriting of the Together-style model ids when the
    # SambaNova provider is selected. Kept for reference.
    # if args.api_provider == "SambaNova":
    #     args.swift_model_id = args.swift_model_id.split("/")[-1][:-len("Turbo")]
    #     args.feedback_model_id = args.feedback_model_id.split("/")[-1][:-len("Turbo")]
    #     args.sage_model_id = args.sage_model_id.split("/")[-1][:-len("Turbo")]

    return args
# Built-in sample problems; main() picks one at random when --problem is not
# given. The trailing "# n" comments are the positional indices of the entries.
TEST_PROBLEMS = (
    "Solve the equation: 2x + 5 = 13",  # 0
    "If h(x)=x-4 and g(h(x))=x^2-8x+10, find g(x)? show the formula for g(x)",  # 1
    "Solve the equation: 6y + 5 = 29",  # 2
    "Who lives longer, Lowell Sherman or Jonathan Kaplan?",  # 3
    "9.9 or 9.11 -- which is bigger?",  # 4
    "How can you solve the quadratic equation 3x^2 + 7.15x + 4 = 0 using the quadratic formula?",  # 5
    "Explain why sound waves cannot travel in a vacuum?",  # 6
    "How many grams of hydrogen (H) are present in 23.5 grams of water (H2O)?",  # 7
    "What is the distance between the points (2, 3) and (5, 8)?",  # 8
    "Why can the Hubble telescope capture clear images of distant stars and galaxies, but not a detailed image of Pluto?",  # 9
    "A rectangular band formation is a formation with $m$ band members in each of $r$ rows, where $m$ and $r$ are integers. A particular band has less than 100 band members. The director arranges them in a rectangular formation and finds that he has two members left over. If he increases the number of members in each row by 1 and reduces the number of rows by 2, there are exactly enough places in the new formation for each band member. What is the largest number of members the band could have?",  # 10
    # Raw string: the LaTeX sequences \%, \$ and \! are not valid Python
    # escapes and trigger a SyntaxWarning in a normal literal on Python 3.12+.
    r"Tim wants to invest some money in a bank which compounds quarterly with an annual interest rate of $7\%$. To the nearest dollar, how much money should he invest if he wants a total of $\$60,\!000$ at the end of $5$ years?",  # 11
    (
        "In an SR latch built from NOR gates, which condition is not allowed\n"
        "Options:\n"
        '[ "S=0, R=2", "S=2, R=2", "S=1, R=1", "S=1, R=-1", "S=1, R=2", "S=0, R=0", "S=2, R=0", "S=1, R=0", "S=2, R=1", "S=0, R=1" ]\n'
        "Which one is the correct answer?"
    ),  # 12
    # ... add other problems here ...
    'How many letter r are there in the word "strawberry"?',  # 13
)


def _build_llm_config(model_id, api_config, temperature, top_p, max_tokens=2048):
    """Assemble the settings dict that SwiftSage receives for one model."""
    return {
        "model_id": model_id,
        "api_config": api_config,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
    }


def main():
    """Entry point: parse the CLI options, build SwiftSage and solve one problem.

    The problem comes from ``--problem``; when that option is missing or
    empty, a random entry of ``TEST_PROBLEMS`` is chosen and printed.

    Raises:
        RuntimeError: from ``multiprocessing.set_start_method`` if a start
            method has already been fixed in this process (for example when
            ``main()`` is called a second time).
    """
    args = parse_args()

    # Processes created through multiprocessing from here on use 'spawn'.
    multiprocessing.set_start_method('spawn')

    # All three models talk to the same provider, so look its config up once.
    # To target SambaNova, pass --api_provider SambaNova together with
    # SambaNova model ids (e.g. Meta-Llama-3.1-8B-Instruct) on the command line.
    api_config = api_configs[args.api_provider]
    swift_config = _build_llm_config(
        args.swift_model_id, api_config, args.swift_temperature, args.swift_top_p
    )
    feedback_config = _build_llm_config(
        args.feedback_model_id, api_config, args.feedback_temperature, args.feedback_top_p
    )
    sage_config = _build_llm_config(
        args.sage_model_id, api_config, args.sage_temperature, args.sage_top_p
    )

    # specify the path to the prompt templates
    prompt_template_dir = args.prompt_template_dir

    # TODO: for retrieval augmentation (not implemented yet now); the dataset
    # of example problems and their embeddings are left empty until then.
    dataset = []
    embeddings = []

    # Note the positional order: the sage config comes before the feedback one.
    s2 = SwiftSage(
        dataset,
        embeddings,
        prompt_template_dir,
        swift_config,
        sage_config,
        feedback_config,
        use_retrieval=args.use_retrieval,
        start_with_sage=args.start_with_sage,
    )

    if args.problem:
        problem = args.problem
    else:
        problem = random.choice(TEST_PROBLEMS)
        print(f"Problem: {problem}")

    run_test(s2, problem, args.max_iterations, args.reward_threshold)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()