-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
91 lines (76 loc) · 2.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from genericpath import isfile
import json
import os
if __name__ == '__main__':
    # When run as a script, make the packages under src/ (e.g. `utils`, used by
    # `from utils import util` below) importable. The path is anchored to this
    # file's directory, matching how `from src...` already resolves via the
    # script directory, so launching from another CWD does not break imports.
    import sys
    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src'))
from src.model.model import MMGNet
from src.utils.config import Config
from utils import util
import torch
import argparse
def main():
    """Run the script: load the config, build the model, then train and/or evaluate.

    In ``eval`` mode the best checkpoint is loaded and validated, then the
    function returns. In any other mode, training resumes from a previously
    saved model when one can be loaded (a failed load is reported and training
    starts from scratch). The model is validated once training finishes.
    """
    config = load_config()
    # Explicitly keep CUDA kernel launches asynchronous (the non-debug setting).
    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
    util.set_random_seed(config.SEED)
    if config.VERBOSE:
        print(config)

    model = MMGNet(config)

    # Persist the run's config once, under <PATH>/config/<model_name>/<exp>/config.json.
    save_path = os.path.join(config.PATH, 'config', model.model_name, model.exp)
    os.makedirs(save_path, exist_ok=True)
    save_path = os.path.join(save_path, 'config.json')
    # A plain string is stored before dumping, presumably so the config stays
    # JSON-serialisable (a torch.device is not); it is replaced just below.
    config.DEVICE = 'cuda'
    if not os.path.exists(save_path):
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(config, f)

    # init device
    if torch.cuda.is_available() and len(config.GPU) > 0:
        config.DEVICE = torch.device("cuda")
    else:
        config.DEVICE = torch.device("cpu")

    # just for test
    if config.MODE == 'eval':
        print('start validation...')
        model.load(best=True)
        model.validation()
        # Return rather than calling the `site`-provided exit() helper, which is
        # not guaranteed to exist; main() is the script's final statement.
        return

    # NOTE(review): load_config() also accepts MODE 'trace', which has no branch
    # here and therefore falls through to training — confirm this is intended.
    try:
        model.load()
    except Exception as err:
        # Best-effort resume: report why loading failed and train from scratch,
        # but let KeyboardInterrupt/SystemExit propagate (a bare except hid them).
        print('unable to load previous model.', err)
    print('\nstart training...\n')
    model.train()

    # we test the best model in the end
    # NOTE(review): the eval branch passes best=True, while this call uses
    # load()'s default — confirm that the default actually selects the best model.
    model.config.EVAL = True
    print('start validation...')
    model.load()
    model.validation()
def load_config():
    r"""Parse command-line arguments and load the model config.

    Reads ``--config`` (a path resolved against the current working
    directory), ``--loadbest``, ``--mode`` (required) and ``--exp``. It loads
    the file with ``Config`` and copies the CLI values onto it.

    If the config has no ``NAME`` entry, one is derived from the file name:
    the extension and a leading ``config_`` prefix are dropped, and the
    characters ``!@#$`` are removed (``config_example.json`` -> ``example``).

    Returns:
        The loaded ``Config`` with ``LOADBEST``, ``MODE`` and ``exp`` set.

    Raises:
        RuntimeError: if the config file does not exist.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # ArgumentDefaultsHelpFormatter appends the real default to each help text,
    # so the help strings must not state (possibly wrong) defaults themselves.
    parser.add_argument('--config', type=str, default='config_example.json',
                        help='configuration file path (absolute, or relative to the current working directory)')
    parser.add_argument('--loadbest', type=int, default=0, choices=[0, 1],
                        help='1: load best model or 0: load checkpoints. Only works in non training mode.')
    parser.add_argument('--mode', type=str, choices=['train', 'trace', 'eval'],
                        help='mode. can be [train,trace,eval]', required=True)
    parser.add_argument('--exp', type=str, help='experiment identifier (stored as config.exp)')
    args = parser.parse_args()

    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        raise RuntimeError('Target config file does not exist. {}'.format(config_path))

    # load config file
    config = Config(config_path)
    if 'NAME' not in config:
        # Strip the extension and only a genuine 'config_' prefix; slicing off a
        # fixed 7 characters would mangle names that lack the prefix.
        prefix = 'config_'
        stem = os.path.splitext(os.path.basename(args.config))[0]
        if stem.startswith(prefix):
            stem = stem[len(prefix):]
        name = stem.translate(str.maketrans('', '', '!@#$'))
        if name:
            config['NAME'] = name

    config.LOADBEST = args.loadbest
    config.MODE = args.mode
    config.exp = args.exp
    return config
if __name__ == '__main__':
    # Launch the pipeline only when executed directly, never on import.
    main()