-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
129 lines (110 loc) · 4.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import argparse
import torch
import pickle
import torch.utils.data as utils
import torch.optim as optim
import time
import numpy as np
from graph import Graph
from model import GNet
from pool import FeaturePooling
from metrics import loss_function
from data import CustomDatasetFolder
# Args
# Command-line options for the training run. Help strings render the actual
# default via argparse's %(default)s so the text cannot drift from the value.
parser = argparse.ArgumentParser(description='Pixel2Mesh training script')
parser.add_argument('--data', type=str, default=None, metavar='D',
                    help="folder where data is located.")
parser.add_argument('--epochs', type=int, default=100, metavar='E',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=3e-5, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--log_step', type=int, default=100, metavar='LS',
                    help='how many batches to wait before logging training '
                         'status (default: %(default)s)')
parser.add_argument('--saving_step', type=int, default=1000, metavar='S',
                    help='how many batches to wait before saving model '
                         '(default: %(default)s)')
parser.add_argument('--experiment', type=str, default='./model/', metavar='DIR',
                    help='folder where model and optimizer are saved.')
parser.add_argument('--load_model', type=str, default=None, metavar='M',
                    help='model file to load to continue training.')
parser.add_argument('--load_optimizer', type=str, default=None, metavar='O',
                    help='optimizer file to load to continue training.')
args = parser.parse_args()
# Cuda
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Model
nIms = 5  # number of images GNet is constructed for
model_gcn = GNet(nIms)
if args.load_model is not None:  # Continue training
    # Load straight into the model without keeping the loaded dict around:
    # load_state_dict copies the values into the parameters, so a module-level
    # reference would only hold a redundant copy of every weight (on the GPU
    # when map_location is cuda) for the whole run.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model_gcn.load_state_dict(torch.load(args.load_model, map_location=device))

# Optimizer
optimizer = optim.Adam(model_gcn.parameters(), lr=args.lr)
if args.load_optimizer is not None:
    # Same reasoning as above: do not keep the loaded optimizer state alive.
    optimizer.load_state_dict(torch.load(args.load_optimizer, map_location=device))
model_gcn.train()
# Param: hyper-parameters taken from the command line.
nb_epochs, log_step, saving_step = args.epochs, args.log_step, args.saving_step
curr_loss = 0  # running sum of losses between two log lines

# Graph, built from the pickled description under ./ellipsoid/
graph = Graph("./ellipsoid/init_info.pickle")

# Data Loader: one sample per batch, reshuffled every epoch.
folder = CustomDatasetFolder(args.data, extensions=["dat"])
train_loader = torch.utils.data.DataLoader(
    folder,
    batch_size=1,
    shuffle=True,
)
# To GPU
if not use_cuda:
    print('Using CPU', flush=True)
else:
    print('Using GPU', flush=True)
    model_gcn.cuda()
    # Optimizer state tensors are separate from the parameters, so moving the
    # model does not move them; copy every tensor entry to the GPU as well.
    for per_param in optimizer.state.values():
        moved = {key: val.cuda()
                 for key, val in per_param.items()
                 if torch.is_tensor(val)}
        per_param.update(moved)
print("nb trainable param", model_gcn.get_nb_trainable_params(), flush=True)
# Train
# Runs nb_epochs passes over train_loader. Every log_step batches it prints
# the mean loss of the last log_step batches; every saving_step batches it
# writes model/optimizer state dicts into args.experiment, tagged with both
# the epoch and the in-epoch batch index so later epochs do not overwrite
# earlier checkpoints.
if args.experiment:
    os.makedirs(args.experiment, exist_ok=True)
for epoch in range(1, nb_epochs + 1):
    # Reset per epoch so the first logged average of an epoch does not also
    # include losses left over from the previous epoch's partial log window.
    curr_loss = 0.0
    for n, batch in enumerate(train_loader):
        ims, gt_points_list, gt_normals_list = batch
        # Swap the first two axes so indexing dim 0 selects one image at a
        # time (presumably batch-major -> view-major; confirm against
        # CustomDatasetFolder). Tensor.permute stays in torch instead of
        # round-tripping through NumPy as np.transpose did.
        ims = ims.permute(1, 0, 2, 3, 4)
        gt_points_list = gt_points_list.permute(1, 0, 2, 3)
        gt_normals_list = gt_normals_list.permute(1, 0, 2, 3)
        if use_cuda:
            ims = ims.cuda()
            gt_points_list = gt_points_list.cuda()
            gt_normals_list = gt_normals_list.cuda()
        # Forward
        graph.reset()
        optimizer.zero_grad()
        # One FeaturePooling per input image; nIms is what GNet was built with.
        pools = [FeaturePooling(ims[i]) for i in range(nIms)]
        pred_points = model_gcn(graph, pools)
        # Loss
        loss = loss_function(pred_points, gt_points_list[0].squeeze(),
                             gt_normals_list[0].squeeze(), graph)
        # Backward
        loss.backward()
        optimizer.step()
        # Accumulate a detached value: adding the live loss tensor would keep
        # every step's autograd graph alive until the next reset.
        curr_loss += loss.detach()
        # Log
        if (n + 1) % log_step == 0:
            print("Epoch", epoch, flush=True)
            print("Batch", n + 1, flush=True)
            print(" Loss:", curr_loss.item() / log_step, flush=True)
            curr_loss = 0.0
        # Save
        if (n + 1) % saving_step == 0:
            tag = "{}_{}".format(epoch, n + 1)
            model_file = os.path.join(args.experiment, "model_" + tag + ".pth")
            optimizer_file = os.path.join(args.experiment,
                                          "optimizer_" + tag + ".pth")
            torch.save(model_gcn.state_dict(), model_file)
            torch.save(optimizer.state_dict(), optimizer_file)
            print("Saved model to " + model_file, flush=True)