forked from PatWalters/TS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ts_main.py
executable file
·120 lines (103 loc) · 4.07 KB
/
ts_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
import importlib
import json
import sys
from datetime import timedelta
from timeit import default_timer as timer
import pandas as pd
from thompson_sampling import ThompsonSampler
from ts_logger import get_logger
def read_input(json_filename: str) -> dict:
    """
    Read input parameters from a json file and instantiate the evaluator.

    :param json_filename: input json file
    :return: a dictionary with the input parameters, plus an "evaluator_class"
        entry holding the evaluator instance built by parse_input_dict
    """
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(json_filename, 'r', encoding="utf-8") as ifs:
        input_data = json.load(ifs)
    # Reuse parse_input_dict instead of duplicating the evaluator construction,
    # so the two code paths cannot drift apart.
    parse_input_dict(input_data)
    return input_data
def parse_input_dict(input_data: dict) -> None:
    """
    Build the evaluator named in the input parameters and store it back in the dict.

    The class called ``input_data["evaluator_class_name"]`` is looked up in the
    ``evaluators`` module and constructed with ``input_data["evaluator_arg"]``.
    The resulting instance is stored under ``input_data["evaluator_class"]``.
    The dictionary is modified in place and nothing is returned.

    :param input_data: input parameters, e.g. as loaded from the JSON config
    """
    evaluators_module = importlib.import_module("evaluators")
    evaluator_cls = getattr(evaluators_module, input_data["evaluator_class_name"])
    input_data['evaluator_class'] = evaluator_cls(input_data["evaluator_arg"])
def run_ts(input_dict: dict, hide_progress: bool = False) -> pd.DataFrame:
    """
    Perform a Thompson sampling run

    :param input_dict: dictionary with input parameters. Required keys:
        "evaluator_class" (an evaluator instance, see parse_input_dict),
        "reaction_smarts", "num_ts_iterations", "reagent_file_list",
        "num_warmup_trials" and "ts_mode". Optional keys: "results_filename"
        and "log_filename".
    :param hide_progress: hide the progress bar (passed to
        ThompsonSampler.set_hide_progress). When True, also skip printing the
        top 10 results.
    :return: a DataFrame with the columns "score", "SMILES" and "Name", built
        from the list returned by ThompsonSampler.search
    """
    evaluator = input_dict["evaluator_class"]
    reaction_smarts = input_dict["reaction_smarts"]
    num_ts_iterations = input_dict["num_ts_iterations"]
    reagent_file_list = input_dict["reagent_file_list"]
    num_warmup_trials = input_dict["num_warmup_trials"]
    result_filename = input_dict.get("results_filename")
    ts_mode = input_dict["ts_mode"]
    log_filename = input_dict.get("log_filename")
    logger = get_logger(__name__, filename=log_filename)
    ts = ThompsonSampler(mode=ts_mode)
    ts.set_hide_progress(hide_progress)
    ts.set_evaluator(evaluator)
    ts.read_reagents(reagent_file_list=reagent_file_list, num_to_select=None)
    ts.set_reaction(reaction_smarts)
    # run the warm-up phase to generate an initial set of scores for each reagent
    ts.warm_up(num_warmup_trials=num_warmup_trials)
    # run the search with TS
    out_list = ts.search(num_cycles=num_ts_iterations)
    # The evaluator's counter is used as the number of evaluations performed.
    total_evaluations = evaluator.counter
    percent_searched = total_evaluations / ts.get_num_prods() * 100
    logger.info(f"{total_evaluations} evaluations | {percent_searched:.3f}% of total")
    # write the results to disk
    out_df = pd.DataFrame(out_list, columns=["score", "SMILES", "Name"])
    if result_filename is not None:
        out_df.to_csv(result_filename, index=False)
        logger.info(f"Saved results to: {result_filename}")
    if not hide_progress:
        # Show the 10 best unique molecules: highest scores first when
        # maximizing, lowest first otherwise.
        ascending = ts_mode != "maximize"
        top_hits = out_df.sort_values("score", ascending=ascending).drop_duplicates(subset="SMILES").head(10)
        print(top_hits)
    return out_df
def run_10_cycles():
    """ A testing function for the paper.

    Run Thompson sampling 10 times with the JSON config named by the first
    command-line argument. Each run's results are written to
    ts_result_000.csv ... ts_result_009.csv.

    :return: None
    """
    json_file_name = sys.argv[1]
    input_dict = read_input(json_file_name)
    for i in range(10):
        # run_ts reports evaluator.counter as the evaluation count, so reusing
        # one evaluator would make later runs report cumulative counts (unless
        # something outside this file resets it -- not visible here). Build a
        # fresh evaluator for every run after the first.
        if i > 0:
            parse_input_dict(input_dict)
        input_dict['results_filename'] = f"ts_result_{i:03d}.csv"
        run_ts(input_dict, hide_progress=False)
def compare_iterations():
    """ A testing function for the paper.

    Run Thompson sampling with 2K, 5K, 10K, 50K and 100K iterations, using the
    JSON config named by the first command-line argument. Each run's results
    are written to iteration_test_<N>K.csv.

    :return: None
    """
    json_file_name = sys.argv[1]
    input_dict = read_input(json_file_name)
    for run_idx, i in enumerate((2, 5, 10, 50, 100)):
        # run_ts reports evaluator.counter as the evaluation count. Use a fresh
        # evaluator per run so counts are not cumulative across runs (unless
        # something outside this file resets it -- not visible here).
        if run_idx > 0:
            parse_input_dict(input_dict)
        num_ts_iterations = i * 1000
        input_dict["num_ts_iterations"] = num_ts_iterations
        input_dict["results_filename"] = f"iteration_test_{i}K.csv"
        run_ts(input_dict)
def main():
    """Run one Thompson sampling search from the JSON config named on the
    command line and print the elapsed wall-clock time.

    Exits with a usage message when no config file argument is given.
    """
    # Without this check a missing argument crashes with an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} input.json")
    start = timer()
    json_filename = sys.argv[1]
    input_dict = read_input(json_filename)
    run_ts(input_dict)
    end = timer()
    print("Elapsed time", timedelta(seconds=end - start))


if __name__ == "__main__":
    main()