-
Notifications
You must be signed in to change notification settings - Fork 212
/
utils.py
105 lines (87 loc) · 3.09 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
from datetime import datetime
from pathlib import Path
import random
import numpy as np
import torch
import tqdm
def cuda(x):
    """Return ``x`` moved to the current CUDA device, or ``x`` itself without CUDA.

    ``non_blocking=True`` lets the host-to-device copy run asynchronously with
    respect to the host when ``x`` lives in pinned memory; otherwise it has no
    effect. (The old spelling ``async=True`` is a SyntaxError on Python >= 3.7,
    where ``async`` is a reserved keyword.)

    Args:
        x: object exposing a ``.cuda(non_blocking=...)`` method, e.g. a tensor.

    Returns:
        The CUDA copy of ``x`` when ``torch.cuda.is_available()``, else ``x``.
    """
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
def write_event(log, step, **data):
    """Append one JSON-encoded event line to the text stream ``log``.

    The keyword arguments form the event payload. Two extra keys are added:
    ``step`` and ``dt``, where ``dt`` is the local wall-clock time from
    ``datetime.now()`` in ISO format. Keys are serialized in sorted order,
    the record is terminated by a newline, and the stream is then flushed.
    """
    event = dict(data, step=step, dt=datetime.now().isoformat())
    payload = json.dumps(event, sort_keys=True)
    log.write(payload)
    log.write('\n')
    log.flush()
def check_crop_size(image_height, image_width):
    """Report whether a crop size is divisible by 32 in both dimensions.

    Args:
        image_height: height of the crop.
        image_width: width of the crop.

    Returns:
        True if both height and width divisible by 32 and False otherwise.
    """
    multiple = 32
    height_ok = image_height % multiple == 0
    # `and` short-circuits: the width is only checked when the height passes.
    return height_ok and image_width % multiple == 0
def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None,
          num_classes=None):
    """Train ``model`` epoch by epoch, checkpointing and logging as it goes.

    If ``<args.root>/model_<fold>.pt`` exists, training resumes from the epoch
    and step stored there. JSON events are appended to
    ``<args.root>/train_<fold>.log`` via ``write_event``.

    Args:
        args: namespace providing ``lr``, ``n_epochs``, ``root`` and ``batch_size``.
        model: module being trained; its ``state_dict`` is what gets checkpointed.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor (``.backward()`` and ``.item()`` are called on it).
        train_loader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        valid_loader: forwarded to ``validation``.
        validation: callable ``validation(model, criterion, valid_loader, num_classes)``
            returning a dict of metrics; it must contain ``'valid_loss'``.
        init_optimizer: callable that builds the optimizer from the learning rate.
        n_epochs: last epoch to run (inclusive). Falls back to ``args.n_epochs``
            when falsy.
        fold: suffix used in the checkpoint and log file names.
        num_classes: forwarded to ``validation``.

    After each completed epoch, a checkpoint pointing at the next epoch is
    written before validation runs. On Ctrl+C, the current epoch is
    checkpointed, so it is re-run on resume, and the function returns early.
    The log file is closed on every exit path.
    """
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    optimizer = init_optimizer(lr)

    root = Path(args.root)
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    if model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # Reads `step` from the enclosing scope at call time, so the checkpoint
        # records the step count reached so far; `ep` is the epoch to resume at.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    report_each = 10
    valid_losses = []
    # `with` guarantees the log is closed on normal completion, on the Ctrl+C
    # early return below, and on any other exception (it previously leaked).
    with root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8') as log:
        for epoch in range(epoch, n_epochs + 1):
            model.train()
            # Reseed Python's `random` from OS randomness (or system time).
            random.seed()
            # Progress is counted in samples; the total assumes full batches.
            tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
            tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
            losses = []
            try:
                # Defined up front so the end-of-epoch event works even when
                # train_loader yields no batches.
                mean_loss = 0
                for i, (inputs, targets) in enumerate(train_loader):
                    inputs = cuda(inputs)
                    # Targets only feed the loss; keep their transfer out of autograd.
                    with torch.no_grad():
                        targets = cuda(targets)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    optimizer.zero_grad()
                    batch_size = inputs.size(0)
                    loss.backward()
                    optimizer.step()
                    step += 1
                    tq.update(batch_size)
                    losses.append(loss.item())
                    # Running mean over the last `report_each` batch losses.
                    mean_loss = np.mean(losses[-report_each:])
                    tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                    if i and i % report_each == 0:
                        write_event(log, step, loss=mean_loss)
                write_event(log, step, loss=mean_loss)
                tq.close()
                # Checkpoint before validation so a finished epoch is not re-run.
                save(epoch + 1)
                valid_metrics = validation(model, criterion, valid_loader, num_classes)
                write_event(log, step, **valid_metrics)
                valid_loss = valid_metrics['valid_loss']
                valid_losses.append(valid_loss)
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                save(epoch)
                print('done.')
                return