-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmain.py
71 lines (54 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import sys
# os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=4
# os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=4
# os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=6
# os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=4
# os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=6
import warnings
warnings.filterwarnings('ignore')
import json
import yaml
import argparse
from src.train import *
from src.dataset import *
from src.utils import *
from config import *
def _run_training(cfg):
    """Build the 'train'/'valid' loaders from ``cfg`` and run ``Trainer.train()``."""
    print('--- Train Phase ---')
    train_dataset = TrainDataset(cfg, 'train')
    train_loader = DataLoader(train_dataset, batch_size=cfg.train.batch_size,
                              shuffle=True, num_workers=cfg.num_worker)
    val_dataset = TrainDataset(cfg, 'valid')
    val_loader = DataLoader(val_dataset, batch_size=cfg.train.batch_size,
                            shuffle=False, num_workers=cfg.num_worker)
    data_loader = {'train': train_loader, 'valid': val_loader}
    trainer = Trainer(data_loader, cfg)
    trainer.train()


def _run_test(cfg, reseed=False):
    """Evaluate on the 'test' split with ``Tester``, optionally re-seeding first."""
    print('--- Test Phase ---')
    if reseed:
        # Re-seed after training so the test phase starts from the same seed
        # as a test-only run (assumes seed_init resets the RNGs it manages).
        seed_init(seed=cfg.seed)
    tester = Tester(cfg)
    tester.test(set_type='test')


def main(cfg, action=None):
    """Run the experiment described by ``cfg``.

    When ``action == 'train'``, train on the 'train' split (with a 'valid'
    loader for validation) and then evaluate on the 'test' split. Any other
    action runs only the 'test' split evaluation.

    Args:
        cfg: Experiment configuration. This function reads ``cfg.seed``,
            ``cfg.num_worker``, ``cfg.train.batch_size`` and ``cfg.logging``.
        action: ``'train'`` for train-then-test; anything else for test only.
            When ``None`` (the default, kept for backward compatibility) the
            value is taken from the module-level ``args`` that the
            ``__main__`` block creates. Code that imports this module should
            pass ``action`` explicitly, because ``args`` does not exist then.
    """
    if action is None:
        # Legacy path: the action used to be read only from the CLI namespace.
        action = args.action

    seed_init(seed=cfg.seed)
    if action == 'train':
        _run_training(cfg)
        _run_test(cfg, reseed=True)
        if cfg.logging:
            # `neptune` is not imported in this file's visible imports;
            # presumably it comes from one of the star imports.
            neptune.stop()
    else:
        _run_test(cfg)
def parse_bool_flag(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'). Such a flag can therefore never be
    switched off from the command line. This converter accepts the common
    spellings, case-insensitively. Real bools pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


def build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    # nargs='?' makes the declared default ('train') take effect. A plain
    # positional argument is always required, and argparse never uses its default.
    parser.add_argument('action', type=str, nargs='?', default='train', help='Action')  # train / test
    parser.add_argument('--config', default='./config/base.yaml', help='config yaml file')
    parser.add_argument('--num_worker', type=int, default=0, help='Num workers')
    parser.add_argument('--seed', type=int, default=1234, help='seed number')
    parser.add_argument('--device', type=str, default='cuda:0', help='Cuda device')
    parser.add_argument('--logging', type=parse_bool_flag, default=False, help='Logging option')
    parser.add_argument('--resume', type=parse_bool_flag, default=False, help='Resume option')
    parser.add_argument('--checkpoint', type=str, default='./checkpoints', help='Results save path')
    parser.add_argument('--model_name', type=str, default='model-best.pth', help='Best model name')
    parser.add_argument('--n_uttr', type=int, default=1, help='Number of target utterances')  # default:1 for a fair comparison
    return parser


if __name__ == "__main__":
    # `parser` and `args` stay module-level names: main() reads `args.action`
    # when it is called without an explicit action.
    parser = build_arg_parser()
    args = parser.parse_args()
    cfg = Config(args.config)
    cfg = set_experiment(args, cfg)  # merge arg and cfg, make directories
    print(cfg)
    main(cfg)