-
Notifications
You must be signed in to change notification settings - Fork 11
/
Train.py
41 lines (35 loc) · 1.28 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import parses.parses_train_gf as parses_gf
import parses.parses_train_rd as parses_rd
import parses.parses_train_rm as parses_rm
import parses.parses_train_et as parses_et
from train.trainer import name2trainer
import argparse
# Components that can be trained; each has its own config parser and trainer.
COMPONENTS = ('GF', 'RD', 'RM', 'ET')


def build_parser():
    """Build the command-line parser that selects which component to train.

    Returns:
        argparse.ArgumentParser with a single ``--component`` option restricted
        to ``COMPONENTS``. An unknown value makes argparse print a usage error
        and exit with status 2, instead of silently exiting with success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--component',
        default='GF',
        type=str,
        choices=COMPONENTS,
        help='GF/RD/RM/ET for indicating which component is trained')
    return parser


def enforce_matcher_batch_size(cfg):
    """Force unit batch sizes for the RM (matcher) component.

    The matcher must train and validate with batch size 1. Any other value in
    ``cfg.batch_size`` or ``cfg.batch_size_val`` is overridden in place, with a
    notice printed.

    Args:
        cfg: config object exposing ``batch_size`` and ``batch_size_val``.

    Returns:
        The same ``cfg`` object, for convenience.
    """
    if cfg.batch_size != 1 or cfg.batch_size_val != 1:
        print('The batch size for matcher training/validation should be 1. We will use batch_size = 1 and batch_size_val = 1 in the following.')
        cfg.batch_size = 1
        cfg.batch_size_val = 1
    return cfg


def main(argv=None):
    """Parse ``--component``, load its config, and run the matching trainer.

    Args:
        argv: optional argument list for ``--component`` parsing. When None,
            ``sys.argv[1:]`` is used. NOTE(review): each ``get_config()`` is
            called without arguments and presumably reads ``sys.argv`` itself,
            so ``argv`` does not reach it — confirm.
    """
    # parse_known_args: leave flags this parser does not define (presumably
    # meant for the component's get_config) alone instead of erroring out.
    args, _ = build_parser().parse_known_args(argv)

    # component -> (config-parsing module, key into the trainer registry).
    # Built here rather than at module level so the project modules are only
    # touched when training actually runs.
    dispatch = {
        'GF': (parses_gf, 'trainer_gf'),
        'RD': (parses_rd, 'trainer_rd'),
        'RM': (parses_rm, 'trainer_rm'),
        'ET': (parses_et, 'trainer_et'),
    }
    config_module, trainer_key = dispatch[args.component]

    cfg, _ = config_module.get_config()
    if args.component == 'RM':
        enforce_matcher_batch_size(cfg)

    trainer = name2trainer[trainer_key](cfg)
    trainer.run()


if __name__ == '__main__':
    main()