-
Notifications
You must be signed in to change notification settings - Fork 12
/
train.py
39 lines (27 loc) · 1.06 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse, os
from importlib import import_module
import torch
from src.util.config_parse import ConfigParser
from src.trainer import get_trainer_class
def main():
    """Parse command-line options, build the configuration, and run training.

    Command-line options:
        -s / --session_name: session name; falls back to the ``--config``
            value when omitted.
        -c / --config: config file path (required).
        -r / --resume: flag, false unless given.
        -g / --gpu: string, default ``None``.
        --thread: integer, default 4.

    Exits through ``argparse``'s error path (usage message on stderr,
    exit status 2) when ``--config`` is not supplied.
    """
    # parsing configuration
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--session_name', default=None, type=str)
    parser.add_argument('-c', '--config', default=None, type=str)
    parser.add_argument('-r', '--resume', action='store_true')
    parser.add_argument('-g', '--gpu', default=None, type=str)
    parser.add_argument('--thread', default=4, type=int)
    args = parser.parse_args()

    # Explicit check rather than `assert`: assertions are stripped when the
    # interpreter runs with -O, which would let a missing config slip through.
    if args.config is None:
        parser.error('config file path is needed')
    if args.session_name is None:
        args.session_name = args.config  # set session name to config file name

    cfg = ConfigParser(args)

    # device setting: only override the environment when the config
    # carries a gpu value
    if cfg['gpu'] is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg['gpu']

    # initialize trainer: look up the trainer class named by the config
    # and construct it with the whole config
    trainer = get_trainer_class(cfg['trainer'])(cfg)

    # train
    trainer.train()
# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()