-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
61 lines (48 loc) · 2.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import random
import time

import torch
import torch.optim as optim

from model import Model
from train import train_epoch
from options import get_options
from util import *
def run(opts):
    """Train a ``Model`` on the DSP/resource datasets described by ``opts``.

    Seeds the RNGs, loads and featurises the training/validation data, builds
    the model, optimizer and learning-rate scheduler, then runs
    ``opts.epochs`` epochs of ``train_epoch``. When ``opts.save_model`` is
    set, the model with the best validation reward seen after epoch
    ``opts.save_model_epoch_threshold`` is saved (whole model, via
    ``torch.save``) to ``<opts.model_dir>/<timestamp>/best_model.pt``.

    Args:
        opts: Options namespace (as produced by ``get_options()``). As a side
            effect, ``opts.device`` is set to the selected ``torch.device``.
    """
    # Seed both torch and Python's ``random`` so runs are reproducible.
    torch.manual_seed(opts.seed)
    random.seed(opts.seed)

    # Select the compute device; stored on ``opts`` so downstream code can read it.
    opts.device = torch.device(f'cuda:{opts.gpu_id}' if opts.use_cuda else 'cpu')

    # Load raw graphs and the shared resource set, then build samples from them.
    train_graphs = load_graphs(dirname=opts.train_dsp_dataset_dir)
    valid_graphs = load_graphs(dirname=opts.valid_dsp_dataset_dir)
    resources = load_resources(opts.communicate_costs, dirname=opts.res_dataset_dir)
    train_data = build_samples(train_graphs, resources, opts)
    valid_data = build_samples(valid_graphs, resources, opts)
    # build_feature's return value is ignored, so it presumably modifies the
    # samples in place — TODO confirm against util.
    build_feature(train_data, is_train=True)
    build_feature(valid_data, is_train=False)
    # Only the training set is augmented; validation data is left as built.
    train_data = data_augment(train_data, opts.train_batch_size)

    model = Model(opts.op_dim,
                  opts.slot_dim,
                  opts.edge_dim,
                  opts.embed_dim,
                  opts.dsp_conv_iter,
                  opts.res_conv_iter,
                  opts.dsp_gcn_aggr,
                  opts.res_gcn_aggr,
                  opts.gcn_act,
                  opts.rnn_type,
                  opts.tanh_clip).to(opts.device)
    optimizer = optim.Adam(model.parameters(), lr=opts.lr)
    # LambdaLR scales the base lr by lr_decay ** epoch (exponential decay).
    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: opts.lr_decay ** epoch)

    best_model_path = None
    if opts.save_model:
        # One timestamped sub-directory per run. makedirs also creates
        # opts.model_dir when it is missing, but (without exist_ok) still
        # raises FileExistsError rather than letting a second run started in
        # the same second overwrite this run's checkpoint.
        run_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        save_dir = os.path.join(opts.model_dir, run_name)
        os.makedirs(save_dir)
        # Derive the checkpoint path from opts.model_dir (previously a
        # hard-coded 'model/' prefix that ignored the configured directory).
        best_model_path = os.path.join(save_dir, 'best_model.pt')

    # Start at -inf so the first eligible epoch is always saved, even when
    # rewards are negative.
    best_avg_reward = float('-inf')
    for epoch in range(1, opts.epochs + 1):
        # Return value is treated as the validation average reward (higher is better).
        valid_avg_reward = train_epoch(train_data, valid_data, model, optimizer, lr_scheduler, epoch, opts)
        if (opts.save_model
                and epoch > opts.save_model_epoch_threshold
                and valid_avg_reward > best_avg_reward):
            best_avg_reward = valid_avg_reward
            torch.save(model, best_model_path)
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then launch training.
    cli_opts = get_options()
    run(cli_opts)