forked from microsoft/StemGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
72 lines (67 loc) · 3.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
from datetime import datetime
from models.handler import train, test
import argparse
import pandas as pd
def _str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string, so ``--train False`` used to mean "train". This
    converter accepts the usual spellings instead.

    Args:
        value: The raw option string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')


# Command-line configuration for this train/evaluate script.
parser = argparse.ArgumentParser()
parser.add_argument('--train', type=_str2bool, default=True,
                    help='run training before evaluation (true/false)')
parser.add_argument('--evaluate', type=_str2bool, default=True,
                    help='run evaluation on the test split (true/false)')
# Other datasets this script has been pointed at: PeMS07, ECG_data.
parser.add_argument('--dataset', type=str, default='mar2020',
                    help='name of the CSV file (without .csv) under dataset/')
# A window_size of 12 was previously used as the default here.
parser.add_argument('--window_size', type=int, default=10)
parser.add_argument('--horizon', type=int, default=3)
# The three *_length values are relative weights, normalised into split ratios.
parser.add_argument('--train_length', type=float, default=7,
                    help='relative size of the training split')
parser.add_argument('--valid_length', type=float, default=2,
                    help='relative size of the validation split')
parser.add_argument('--test_length', type=float, default=1,
                    help='relative size of the test split')
parser.add_argument('--epoch', type=int, default=50)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--multi_layer', type=int, default=5)
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--validate_freq', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--norm_method', type=str, default='z_score')
parser.add_argument('--optimizer', type=str, default='RMSProp')
parser.add_argument('--early_stop', type=_str2bool, default=False)
parser.add_argument('--exponential_decay_step', type=int, default=5)
parser.add_argument('--decay_rate', type=float, default=0.5)
parser.add_argument('--dropout_rate', type=float, default=0.5)
# Was type=int, which rejected any fractional rate given on the command line.
parser.add_argument('--leakyrelu_rate', type=float, default=0.2)
args = parser.parse_args()
print(f'Training configs: {args}')

# The split below divides by the sum of the three lengths; reject a zero or
# negative total up front instead of failing later with ZeroDivisionError.
total_length = args.train_length + args.valid_length + args.test_length
if total_length <= 0:
    parser.error('train_length + valid_length + test_length must be positive, '
                 f'got {total_length}')

# Input is read from dataset/<dataset>.csv relative to the working directory
# (an alternative location, ce290_data/<dataset>.csv, was also used).
data_file = os.path.join('dataset', args.dataset + '.csv')
result_train_file = os.path.join('output', args.dataset, 'train')
result_test_file = os.path.join('output', args.dataset, 'test')
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(result_train_file, exist_ok=True)
os.makedirs(result_test_file, exist_ok=True)
data = pd.read_csv(data_file).to_numpy()

# Split rows into three contiguous slices in file order (presumably time
# order; no shuffling), sized proportionally to the *_length arguments.
train_ratio = args.train_length / total_length
valid_ratio = args.valid_length / total_length
test_ratio = 1 - train_ratio - valid_ratio  # informational; test takes the remainder
train_end = int(train_ratio * len(data))
valid_end = int((train_ratio + valid_ratio) * len(data))
train_data = data[:train_end]
valid_data = data[train_end:valid_end]
test_data = data[valid_end:]
# Seed torch's global RNG so torch-side randomness (e.g. weight init inside
# `train`, presumably) is repeatable between runs.
torch.manual_seed(0)


def _minutes_since(start_timestamp):
    """Return the wall-clock minutes elapsed since ``start_timestamp``.

    ``start_timestamp`` is a POSIX timestamp as produced by
    ``datetime.now().timestamp()``.
    """
    return (datetime.now().timestamp() - start_timestamp) / 60


if __name__ == '__main__':
    if args.train:
        # Ctrl-C ends training early but still falls through to evaluation,
        # presumably using whatever model `train` saved before the interrupt.
        try:
            train_started = datetime.now().timestamp()
            _, normalize_statistic = train(train_data, valid_data, args, result_train_file)
            print(f'Training took {_minutes_since(train_started)} minutes')
        except KeyboardInterrupt:
            print('-' * 99)
            print('Exiting from training early')
    if args.evaluate:
        eval_started = datetime.now().timestamp()
        # Evaluation reads from the training output dir and writes to the test dir.
        test(test_data, args, result_train_file, result_test_file)
        print(f'Evaluation took {_minutes_since(eval_started)} minutes')
    print('done')