forked from wbenbihi/hourglasstensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_launcher.py
48 lines (40 loc) · 1.99 KB
/
train_launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
TRAIN LAUNCHER
"""
import configparser
from hourglass_tiny import HourglassModel
from datagen import DataGenerator
def process_config(conf_file, sections=('DataSetHG', 'Network', 'Train', 'Validation', 'Saver')):
    """Read an INI-style config file into a single flat parameter dict.

    Every option of each recognised section is merged into one dict. Sections
    whose name is not in *sections* are ignored. Sections are visited in the
    order they appear in the file, so if two recognised sections define the
    same option, the one appearing later in the file wins.

    Args:
        conf_file: Path to the configuration file. ``ConfigParser.read``
            silently skips files it cannot open, so a missing or unreadable
            file yields an empty dict rather than an error.
        sections: Names of the sections whose options are collected. The
            default is the set of sections used by the training launcher.

    Returns:
        dict mapping each option name (lower-cased by ``configparser``'s
        default ``optionxform``) to the Python value obtained by evaluating
        the raw option string.
    """
    config = configparser.ConfigParser()
    config.read(conf_file)

    # Build the lookup once instead of repeating one copy-pasted branch per
    # section name.
    wanted = set(sections)
    params = {}
    for section in config.sections():
        if section not in wanted:
            continue
        for option in config.options(section):
            # SECURITY NOTE(review): eval() executes arbitrary code taken from
            # the config file. This is acceptable only for a trusted, locally
            # authored file. If the file holds plain literals only (strings,
            # numbers, lists, True/False/None), ast.literal_eval would be a
            # safe replacement.
            params[option] = eval(config.get(section, option))
    return params
if __name__ == '__main__':
    # Stage 1 - configuration. 'config.cfg' is a relative path, so it is
    # resolved against the current working directory; launch the script from
    # the folder that holds the config file.
    print('--Parsing Config File')
    params = process_config('config.cfg')

    # Stage 2 - data pipeline. The three private DataGenerator helpers are
    # invoked in this order. Their names suggest they build, shuffle and split
    # the training table; confirm in datagen.py.
    print('--Creating Dataset')
    dataset = DataGenerator(
        params['joint_list'],
        params['img_directory'],
        params['training_txt_file'],
        remove_joints=params['remove_joints'],
    )
    dataset._create_train_table()
    dataset._randomize()
    dataset._create_sets()

    # Stage 3 - network. The keyword names are HourglassModel's parameters;
    # the values are looked up by config option name.
    hourglass_kwargs = dict(
        nFeat=params['nfeats'],
        nStack=params['nstacks'],
        nModules=params['nmodules'],
        nLow=params['nlow'],
        outputDim=params['num_joints'],
        batch_size=params['batch_size'],
        attention=params['mcam'],
        training=True,
        drop_rate=params['dropout_rate'],
        lear_rate=params['learning_rate'],
        decay=params['learning_rate_decay'],
        decay_step=params['decay_step'],
        dataset=dataset,
        name=params['name'],
        logdir_train=params['log_dir_train'],
        logdir_test=params['log_dir_test'],
        tiny=params['tiny'],
        w_loss=params['weighted_loss'],
        joints=params['joint_list'],
        modif=False,
    )
    model = HourglassModel(**hourglass_kwargs)
    model.generate_model()

    # Stage 4 - training. dataset=None is passed explicitly here. Presumably
    # the model then falls back to the dataset it was constructed with;
    # TODO confirm in hourglass_tiny.HourglassModel.training_init.
    model.training_init(
        nEpochs=params['nepochs'],
        epochSize=params['epoch_size'],
        saveStep=params['saver_step'],
        dataset=None,
    )