train_hpo.py
import argparse
from cords.utils.config_utils import load_config_data
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.suggest.skopt import SkOptSearch
from ray.tune.suggest.dragonfly import DragonflySearch
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.suggest.nevergrad import NevergradSearch
from ray.tune.suggest.optuna import OptunaSearch
from ray.tune.suggest.zoopt import ZOOptSearch
from ray.tune.suggest.sigopt import SigOptSearch
from ray.tune.suggest.hebo import HEBOSearch
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.schedulers import HyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray import tune
from train_sl import TrainClassifier


class HyperParamTuning:
    def __init__(self, config_file_data, train_config_data):
        self.cfg = config_file_data
        self.train_class = TrainClassifier(train_config_data)
        self.train_class.cfg.train_args.print_every = 1
        self.search_algo = self.get_search_algo(self.cfg.search_algo, self.cfg.space, self.cfg.metric, self.cfg.mode)
        self.scheduler = self.get_scheduler(self.cfg.scheduler, self.cfg.metric, self.cfg.mode)
        # save the subset-selection method; used in the log directory name
        self.subset_method = self.train_class.cfg.dss_args.type

    def param_tune(self, config):
        # update the sampled hyperparameters in the training config dict
        new_config = self.update_parameters(self.train_class.cfg, config)
        self.train_class.cfg = new_config
        # turn on reporting to ray every time
        self.train_class.cfg.report_tune = True
        self.train_class.train()

    def start_eval(self):
        # When no search algorithm is configured, the search space must be passed to
        # tune.run directly via `config`; otherwise the search algorithm already holds
        # the space (see get_search_algo).
        if self.search_algo is None:
            analysis = tune.run(
                self.param_tune,
                num_samples=self.cfg.num_evals,
                config=self.cfg.space,
                search_alg=self.search_algo,
                scheduler=self.scheduler,
                resources_per_trial=self.cfg.resources,
                local_dir=self.cfg.log_dir + self.subset_method + '/',
                log_to_file=True,
                name=self.cfg.name,
                resume=self.cfg.resume)
        else:
            analysis = tune.run(
                self.param_tune,
                num_samples=self.cfg.num_evals,
                search_alg=self.search_algo,
                scheduler=self.scheduler,
                resources_per_trial=self.cfg.resources,
                local_dir=self.cfg.log_dir + self.subset_method + '/',
                log_to_file=True,
                name=self.cfg.name,
                resume=self.cfg.resume)
        best_config = analysis.get_best_config(metric=self.cfg.metric, mode=self.cfg.mode)
        print("Best Config: ", best_config)
        if self.cfg['final_train']:
            self.final_train(best_config)

    def get_search_algo(self, method, space, metric, mode):
        # HyperOptSearch
        if method == "hyperopt" or method == "TPE":
            search = HyperOptSearch(space, metric=metric, mode=mode)
        # BayesOptSearch
        elif method == "bayesopt" or method == "BO":
            search = BayesOptSearch(space, metric=metric, mode=mode)
        # SkOptSearch
        elif method == "skopt" or method == "SKBO":
            search = SkOptSearch(space, metric=metric, mode=mode)
        # DragonflySearch
        elif method == "dragonfly" or method == "SBO":
            search = DragonflySearch(space, metric=metric, mode=mode)
        # AxSearch
        elif method == "ax" or method == "BBO":
            search = AxSearch(space, metric=metric, mode=mode)
        # TuneBOHB
        elif method == "tunebohb" or method == "BOHB":
            search = TuneBOHB(space, metric=metric, mode=mode)
        # NevergradSearch
        elif method == "nevergrad" or method == "GFO":
            search = NevergradSearch(space, metric=metric, mode=mode)
        # OptunaSearch
        elif method == "optuna" or method == "OSA":
            search = OptunaSearch(space, metric=metric, mode=mode)
        # ZOOptSearch
        elif method == "zoopt" or method == "ZOO":
            search = ZOOptSearch(space, metric=metric, mode=mode)
        # SigOptSearch
        elif method == "sigopt":
            search = SigOptSearch(space, metric=metric, mode=mode)
        # HEBOSearch
        elif method == "hebo" or method == "HEBO":
            search = HEBOSearch(space, metric=metric, mode=mode)
        else:
            search = None
        return search

    def get_scheduler(self, method, metric, mode):
        if method == "ASHA" or method == "asha":
            scheduler = AsyncHyperBandScheduler(metric=metric, mode=mode,
                                                max_t=self.train_class.cfg.train_args.num_epochs)
        elif method == "hyperband" or method == "HB":
            scheduler = HyperBandScheduler(metric=metric, mode=mode,
                                           max_t=self.train_class.cfg.train_args.num_epochs)
        elif method == "BOHB":
            scheduler = HyperBandForBOHB(metric=metric, mode=mode)
        else:
            scheduler = None
        return scheduler

    def final_train(self, best_params):
        # retrain with the optimized hyperparameters, either on the full dataset
        # or with GradMatchPB subset selection
        new_config = self.update_parameters(self.train_class.cfg, best_params)
        if self.cfg.final_train_type in ['Full', 'full']:
            new_config.dss_args.type = 'Full'
        elif self.cfg.final_train_type in ['GradMatchPB', 'gmpb']:
            new_config.dss_args.type = 'GradMatchPB'
            new_config.dss_args.fraction = 0.3
            new_config.dss_args.select_every = 5
            new_config.dss_args.lam = 0
            new_config.dss_args.selection_type = 'PerBatch'
            new_config.dss_args.v1 = True
            new_config.dss_args.valid = False
            new_config.dss_args.eps = 1e-100
            new_config.dss_args.linear_layer = True
            new_config.dss_args.kappa = 0.5
        else:
            print('Unknown final_train_type in Hyperparameter tuning class. Exiting...')
            exit(1)
        self.train_class.cfg = new_config
        self.train_class.train()

    def update_parameters(self, config, new_config):
        # a generic function that copies the tuned values into the training config
        if 'learning_rate' in new_config:
            config.optimizer.lr = new_config['learning_rate']
        if 'learning_rate1' in new_config:
            config.optimizer.lr1 = new_config['learning_rate1']
        if 'learning_rate2' in new_config:
            config.optimizer.lr2 = new_config['learning_rate2']
        if 'learning_rate3' in new_config:
            config.optimizer.lr3 = new_config['learning_rate3']
        if 'optimizer' in new_config:
            config.optimizer.type = new_config['optimizer']
        if 'nesterov' in new_config:
            config.optimizer.nesterov = new_config['nesterov']
        if 'scheduler' in new_config:
            config.scheduler.type = new_config['scheduler']
        if 'gamma' in new_config:
            config.scheduler.gamma = new_config['gamma']
        if 'epochs' in new_config:
            config.train_args.num_epochs = new_config['epochs']
        if 'trn_batch_size' in new_config:
            config.dataloader.batch_size = new_config['trn_batch_size']
        if 'hidden_size' in new_config:
            config.model.hidden_size = new_config['hidden_size']
        if 'num_layers' in new_config:
            config.model.num_layers = new_config['num_layers']
        return config
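

# A minimal sketch (illustrative only) of a Ray Tune search space whose keys match
# the parameters handled by update_parameters() above. The name _EXAMPLE_SPACE and
# the sampling ranges/choices are assumptions for demonstration; in practice the
# space is supplied by the tuning config as cfg.space.
_EXAMPLE_SPACE = {
    'learning_rate': tune.loguniform(1e-4, 1e-1),
    'optimizer': tune.choice(['sgd', 'adam']),
    'nesterov': tune.choice([True, False]),
    'gamma': tune.uniform(0.05, 0.5),
    'trn_batch_size': tune.choice([32, 64, 128]),
}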


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--config_file", default="configs/config_hyper_param_tuning.py")
    args = argparser.parse_args()
    hpo_config = load_config_data(args.config_file)
    # HyperParamTuning requires both the tuning config and the training config.
    # Assumption: the tuning config stores the training-config path under a
    # `subset_config` key; adapt this lookup to the actual key in your config.
    train_config = load_config_data(hpo_config.subset_config)
    hyperparam_tuning = HyperParamTuning(hpo_config, train_config)
    hyperparam_tuning.start_eval()
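
# A minimal sketch (illustrative only) of the tuning config this script expects,
# inferred from the fields read above (search_algo, space, metric, mode, scheduler,
# num_evals, resources, log_dir, name, resume, final_train, final_train_type).
# Every value, the metric name, and the subset_config key are assumptions rather
# than shipped defaults; the real dict lives in the file passed via --config_file
# and is returned by load_config_data(). Typical invocation (assumed path):
#   python train_hpo.py --config_file configs/config_hyper_param_tuning.py
#
#   config = dict(
#       subset_config="configs/config_gradmatchpb_cifar10.py",  # training config path (assumed key)
#       search_algo="TPE",               # any name handled by get_search_algo()
#       scheduler="ASHA",                # any name handled by get_scheduler()
#       metric="mean_accuracy",          # metric reported to Ray Tune during training (assumed name)
#       mode="max",
#       num_evals=27,                    # number of trials sampled by tune.run()
#       space=_EXAMPLE_SPACE,            # e.g. the search-space sketch defined above
#       resources={'cpu': 4, 'gpu': 1},  # resources_per_trial for each Ray Tune trial
#       log_dir="RayLogs/",
#       name="hpo_cifar10",
#       resume=False,
#       final_train=True,                # retrain with the best config after the search
#       final_train_type="GradMatchPB",  # 'Full' or 'GradMatchPB' (see final_train())
#   )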