-
-
Notifications
You must be signed in to change notification settings - Fork 8
/
train_stage1.py
113 lines (88 loc) · 4.43 KB
/
train_stage1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
######################################
# Kaihua Tang
######################################
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
import utils.general_utils as utils
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *
from utils.test_loader import test_loader
class train_stage1():
    """Stage 1 trainer: builds a backbone and a classifier from ``config`` and
    trains them jointly on the train split.

    Attributes set by ``__init__``:
        config, logger   -- the objects passed in, kept for later use.
        model            -- backbone wrapped in ``nn.DataParallel`` and moved to CUDA.
        classifier       -- classifier wrapped in ``nn.DataParallel`` and moved to CUDA.
        optimizer        -- result of ``create_optimizer(model, classifier, logger, config)``.
        scheduler        -- result of ``create_scheduler(optimizer, logger, config)``.
        eval             -- truthy when validation should run after every epoch.
        training_opt     -- ``config['training_opt']`` (``num_epochs`` and ``loss`` are read here).
        checkpoint       -- ``Checkpoint(config)`` used to save after each epoch.
        train_loader     -- loader returned by ``get_loader`` for the ``'train'`` split.
        loss_fc          -- loss callable returned by ``create_loss``.
        testing          -- validation runner; only created when ``eval`` is truthy.
    """

    def __init__(self, args, config, logger, eval=False):
        """Construct networks, optimizer, scheduler, data loader and loss.

        Args:
            args: accepted for interface compatibility; not read by this class.
            config: nested configuration mapping. The keys read here are
                ``networks``, ``classifiers``, ``training_opt`` and ``dataset``.
            logger: project logger; must provide ``info`` (and ``info_iter``,
                which ``run`` uses).
            eval: when truthy, a validation runner is built and ``run`` calls
                it after every epoch.
        """
        # NOTE: the parameter name ``eval`` shadows the builtin. It is kept
        # unchanged so existing keyword callers (``eval=True``) keep working.

        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']

        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']

        # Each definition file is loaded through utils.source_import and must
        # expose a ``create_model(**params)`` factory.
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']
        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Train for ``training_opt['num_epochs']`` epochs.

        After every epoch: optionally validate, save a checkpoint (with the
        validation accuracy, or ``0.0`` when validation is disabled) and step
        the LR scheduler. Once all epochs are done, the checkpoint helper is
        asked to save the best-model path.
        """
        # Start Training
        self.logger.info('=====> Start Stage 1 Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            self._train_one_epoch(epoch)

            # evaluation on validation set
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler (once per epoch, after that epoch's optimizer steps)
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)

    def _train_one_epoch(self, epoch):
        """Iterate once over ``self.train_loader``: optimise, log, dump gradients."""
        total_batch = len(self.train_loader)

        # The loader yields 4-tuples; only the first two items are used here.
        for step, (inputs, labels, _, _) in enumerate(self.train_loader):
            iter_info_print = self._train_step(inputs, labels)

            # log information
            logger_opt = self.config['logger_opt']
            self.logger.info_iter(epoch, step, total_batch, iter_info_print, logger_opt['print_iter'])

            # Gradients are always printed for the very first batch of the run;
            # afterwards only every 1000 steps, and only if 'print_grad' is set.
            # Parentheses spell out the precedence the bare `or ... and ...` had.
            first_batch = (epoch == 0) and (step == 0)
            if first_batch or (logger_opt['print_grad'] and step % 1000 == 0):
                utils.print_grad(self.classifier.named_parameters())
                utils.print_grad(self.model.named_parameters())

    def _train_step(self, inputs, labels):
        """Run one optimisation step and return the values to log for it.

        Returns:
            dict with the loss under the configured loss name
            (``training_opt['loss']``) and under ``'Loss'``, the batch accuracy
            under ``'Accuracy'`` and the first param group's learning rate
            under ``'Poke LR'``.
        """
        self.optimizer.zero_grad()

        # additional inputs
        inputs, labels = inputs.cuda(), labels.cuda()
        add_inputs = {}

        features = self.model(inputs)
        predictions = self.classifier(features, add_inputs)

        # calculate loss
        loss = self.loss_fc(predictions, labels)
        loss_name = self.training_opt['loss']
        # Read the scalar once: every .item() call copies the value to the
        # host, and the loss tensor does not change between the two places it
        # is reported.
        loss_value = loss.sum().item()

        loss.backward()
        self.optimizer.step()

        # calculate accuracy: fraction of rows whose arg-max over dim 1 equals the label
        accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

        # Key order matches the original dict-then-update construction; if the
        # configured loss name is itself 'Loss' the two entries simply coincide.
        return {
            loss_name: loss_value,
            'Accuracy': accuracy.item(),
            'Loss': loss_value,
            'Poke LR': float(self.optimizer.param_groups[0]['lr']),
        }