train.py
import torch
import visdom
from tqdm import trange


class Train(object):
    def __init__(self, model, data_loader, optimizer, criterion, lr, wd, batch_size, vis):
        super(Train, self).__init__()
        self.model = model
        self.data_loader = data_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.lr = lr
        self.wd = wd
        self.bs = batch_size
        self.vis = None
        if vis:
            self.vis = visdom.Visdom()
            self.loss_window = self.vis.line(X=torch.zeros((1,)).cpu(),
                                             Y=torch.zeros((1,)).cpu(),
                                             opts=dict(xlabel='minibatches',
                                                       ylabel='Loss',
                                                       title='Training Loss',
                                                       legend=['Loss']))
        self.iterations = 0

    def forward(self):
        self.model.train()
        # TODO adjust learning rate
        total_loss = 0
        pbar = trange(len(self.data_loader.dataset), desc='Training ')
        for batch_idx, (x, yt) in enumerate(self.data_loader):
            # `async` became a reserved keyword in Python 3.7; PyTorch renamed
            # the argument to `non_blocking` in 0.4
            x = x.cuda(non_blocking=True)
            yt = yt.cuda(non_blocking=True)

            # compute output (torch.autograd.Variable is deprecated; tensors
            # are passed to the model directly)
            y = self.model(x)
            loss = self.criterion(y, yt)

            # record loss
            total_loss += loss.item()

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if batch_idx % 10 == 0:
                # Update tqdm bar every 10 minibatches, clamping the final
                # update to the dataset size
                if (batch_idx * self.bs + 10 * len(x)) <= len(self.data_loader.dataset):
                    pbar.update(10 * len(x))
                else:
                    pbar.update(len(self.data_loader.dataset) - int(batch_idx * self.bs))
                # Append the current loss to the visdom plot
                if self.vis:
                    self.vis.line(
                        X=torch.ones((1,)).cpu() * self.iterations,
                        Y=loss.detach().cpu().view(1),
                        win=self.loss_window,
                        update='append')
            self.iterations += 1
        pbar.close()
        # average per-sample loss (each minibatch contributes loss * batch size)
        return total_loss * self.bs / len(self.data_loader.dataset)
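

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): the linear model,
# random tensors, and hyperparameter values below are illustrative
# placeholders. It assumes a CUDA-capable machine, since forward() moves each
# batch to the GPU, and a running visdom server if vis=True.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Linear(10, 2).cuda()                 # placeholder model
    dataset = TensorDataset(torch.randn(100, 10),   # placeholder data
                            torch.randint(0, 2, (100,)))
    loader = DataLoader(dataset, batch_size=10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

    trainer = Train(model, loader, optimizer, nn.CrossEntropyLoss(),
                    lr=0.1, wd=1e-4, batch_size=10, vis=False)
    for epoch in range(3):
        avg_loss = trainer.forward()
        print('epoch {}: avg loss {:.4f}'.format(epoch, avg_loss))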