-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
65 lines (55 loc) · 2.6 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
import argparse
import os
from parser.cmds import Train
from parser.config import Config
import torch
import random
import numpy as np
import logging
import warnings
warnings.filterwarnings("ignore")
if __name__ == '__main__':
    # Script entry point: build the command-line interface, configure logging
    # and the torch runtime from the parsed options, merge those options into
    # the config file's settings, then run the selected subcommand.
    parser = argparse.ArgumentParser(
        description='Create the Biaffine Parser model.'
    )
    subparsers = parser.add_subparsers(title='Commands', dest='mode')
    # Maps each subcommand name to the callable object that implements it.
    subcommands = {
        'train': Train()
    }
    for name, subcommand in subcommands.items():
        # Each subcommand registers its own sub-parser; the options shared by
        # every subcommand are attached to that sub-parser below.
        subparser = subcommand.add_subparser(name, subparsers)
        subparser.add_argument('--conf', '-c', default='config.ini',
                               help='path to config file')
        subparser.add_argument('--file', '-f', default='exp/ptb',
                               help='path to saved files')
        subparser.add_argument('--preprocess', '-p', action='store_true',
                               help='whether to preprocess the data first')
        subparser.add_argument('--device', '-d', default='-1',
                               help='ID of GPU to use')
        subparser.add_argument('--seed', '-s', default=1, type=int,
                               help='seed for generating random numbers')
        subparser.add_argument('--threads', '-t', default=16, type=int,
                               help='max num of threads')
        subparser.add_argument('--tree', action='store_true',
                               help='whether to ensure well-formedness')
        subparser.add_argument('--feat', default='tag',
                               choices=['tag', 'char', 'bert'],
                               help='choices of additional features')
    args = parser.parse_args()

    # Subparsers are optional by default in Python 3, so invoking the script
    # without a subcommand leaves `mode` as None and none of the options
    # above defined. Exit with a usage error instead of an AttributeError.
    if args.mode is None:
        parser.error(
            f"no command given (choose from: {', '.join(subcommands)})"
        )

    # NOTE(review): no `--output` option is registered in this file; it is
    # presumably added by the subcommand's add_subparser - TODO confirm.
    # When the attribute is absent, filename=None makes basicConfig fall back
    # to its default stream handler rather than crashing.
    logging.basicConfig(filename=getattr(args, 'output', None),
                        filemode='w',
                        format='%(asctime)s %(levelname)-8s %(message)s',
                        level=logging.INFO,
                        datefmt='%Y-%m-%d %H:%M:%S')

    logging.info(f"Set the max num of threads to {args.threads}")
    logging.info(f"Set the seed for generating random numbers to {args.seed}")
    logging.info(f"Set the device with ID {args.device} visible")
    torch.set_num_threads(args.threads)
    # Seed every random number generator imported by this script, not only
    # torch's, so that the single --seed option covers all of them.
    random.seed(args.seed)
    # The legacy numpy seeding API only accepts values in [0, 2**32 - 1].
    np.random.seed(args.seed % 2 ** 32)
    torch.manual_seed(args.seed)
    # Must be set before torch.cuda.is_available() is queried below so that
    # only the requested device is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    # Derived paths for the saved fields and model under the --file directory.
    args.fields = os.path.join(args.file, 'fields')
    args.model = os.path.join(args.file, 'model')
    # From here on `device` holds the torch device type rather than a GPU ID.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logging.info("Override the default configs with parsed arguments")
    # Parsed command-line arguments are merged into the settings loaded from
    # the config file; the result replaces `args`.
    args = Config(args.conf).update(vars(args))
    logging.info(args)

    logging.info(f"Run the subcommand in mode {args.mode}")
    cmd = subcommands[args.mode]
    cmd(args)