train.py
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from utils import load_obj


class Train:
    """Trainer class for the CNN.

    It is also responsible for loading/saving the model checkpoints
    from/to experiments/experiment_name/checkpoint_dir.
    """

    def __init__(self, sess, model, data, summarizer):
        self.sess = sess
        self.model = model
        self.args = self.model.args
        self.saver = tf.train.Saver(max_to_keep=self.args.max_to_keep,
                                    keep_checkpoint_every_n_hours=10,
                                    save_relative_paths=True)
        # Data and summarizer references
        self.data = data
        self.summarizer = summarizer
        # Initializing the model
        self.init = None
        self.__init_model()
        # Loading the model checkpoint if it exists
        self.__load_imagenet_weights()
        self.__load_model()

    ############################################################################################################
    # Model related methods
    def __init_model(self):
        print("Initializing the model...")
        self.init = tf.group(tf.global_variables_initializer())
        self.sess.run(self.init)
        print("Model initialized\n\n")

    def save_model(self):
        """Save a model checkpoint."""
        print("Saving a checkpoint")
        self.saver.save(self.sess, self.args.checkpoint_dir, self.model.global_step_tensor)
        print("Checkpoint Saved\n\n")

    def __load_model(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)
            print("Checkpoint loaded\n\n")
        else:
            print("First time to train!\n\n")

    def __load_imagenet_weights(self):
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        try:
            print("Loading ImageNet pretrained weights...")
            pretrained_weights = load_obj(self.args.pretrained_path)
            run_list = []
            for variable in variables:
                for key, value in pretrained_weights.items():
                    # Matching on key + ':' selects the variable itself and skips the extra
                    # slot variables created by adaptive optimizers (e.g. Adam moments)
                    if key + ":" in variable.name:
                        run_list.append(tf.assign(variable, value))
            self.sess.run(run_list)
            print("Weights loaded\n\n")
        except (IOError, OSError):
            # load_obj fails with an I/O error when no pretrained weights file is present
            print("No pretrained ImageNet weights exist. Skipping...\n\n")

    ############################################################################################################
    # Train and Test methods
    def train(self):
        for cur_epoch in range(self.model.global_epoch_tensor.eval(self.sess) + 1, self.args.num_epochs + 1, 1):
            # Initialize tqdm
            num_iterations = self.args.train_data_size // self.args.batch_size
            tqdm_batch = tqdm(self.data.generate_batch(type='train'), total=num_iterations,
                              desc="Epoch-" + str(cur_epoch) + "-")
            # Initialize the current iteration counter
            cur_iteration = 0
            # Initialize classification accuracy and loss lists
            loss_list = []
            acc_list = []
            # Loop over the batches of this epoch
            for X_batch, y_batch in tqdm_batch:
                # Get the current iteration for summarizing it
                cur_step = self.model.global_step_tensor.eval(self.sess)
                # Feed these variables to the network
                feed_dict = {self.model.X: X_batch,
                             self.model.y: y_batch,
                             self.model.is_training: True}
                # Run the optimizer and fetch the loss, accuracy and merged summaries
                _, loss, acc, summaries_merged = self.sess.run(
                    [self.model.train_op, self.model.loss, self.model.accuracy, self.model.summaries_merged],
                    feed_dict=feed_dict)
                # Append loss and accuracy
                loss_list += [loss]
                acc_list += [acc]
                # Update the global step
                self.model.global_step_assign_op.eval(session=self.sess,
                                                      feed_dict={self.model.global_step_input: cur_step + 1})
                self.summarizer.add_summary(cur_step, summaries_merged=summaries_merged)
                if cur_iteration >= num_iterations - 1:
                    avg_loss = np.mean(loss_list)
                    avg_acc = np.mean(acc_list)
                    # Summarize the epoch averages
                    summaries_dict = dict()
                    summaries_dict['loss'] = avg_loss
                    summaries_dict['acc'] = avg_acc
                    self.summarizer.add_summary(cur_step, summaries_dict=summaries_dict)
                    # Update the current epoch tensor
                    self.model.global_epoch_assign_op.eval(session=self.sess,
                                                           feed_dict={self.model.global_epoch_input: cur_epoch + 1})
                    # Print in console
                    tqdm_batch.close()
                    print("Epoch-" + str(cur_epoch) + " | loss: " + str(avg_loss) +
                          " - acc: " + str(avg_acc)[:7])
                    # Break the loop to finalize this epoch
                    break
                # Update the current iteration
                cur_iteration += 1
            # Save the current checkpoint
            if cur_epoch % self.args.save_model_every == 0 and cur_epoch != 0:
                self.save_model()
            # Test the model on validation data
            if cur_epoch % self.args.test_every == 0:
                self.test('val')

    def test(self, test_type='val'):
        num_iterations = self.args.test_data_size // self.args.batch_size
        tqdm_batch = tqdm(self.data.generate_batch(type=test_type), total=num_iterations,
                          desc='Testing')
        # Initialize classification accuracy and loss lists
        loss_list = []
        acc_list = []
        cur_iteration = 0
        for X_batch, y_batch in tqdm_batch:
            # Feed these variables to the network
            feed_dict = {self.model.X: X_batch,
                         self.model.y: y_batch,
                         self.model.is_training: False}
            # Run the forward pass only (no training op)
            loss, acc = self.sess.run(
                [self.model.loss, self.model.accuracy],
                feed_dict=feed_dict)
            # Append loss and accuracy
            loss_list += [loss]
            acc_list += [acc]
            if cur_iteration >= num_iterations - 1:
                avg_loss = np.mean(loss_list)
                avg_acc = np.mean(acc_list)
                print('Test results | test_loss: ' + str(avg_loss) + ' - test_acc: ' + str(avg_acc)[:7])
                break
            cur_iteration += 1
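

# ----------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only). This shows how the Train class above is typically wired up; the
# Model, DataLoader and Summarizer names, their constructors, and the `args` object below are
# assumptions rather than part of this file, so substitute the project's actual classes and config.
#
#   import tensorflow as tf
#   from model import Model                  # assumed: exposes X, y, is_training, loss, accuracy,
#   from data_loader import DataLoader       #          train_op, summaries_merged and the step/epoch tensors
#   from summarizer import Summarizer        # assumed: provides add_summary(step, ...)
#
#   sess = tf.Session()
#   model = Model(args)                      # args must carry batch_size, num_epochs, checkpoint_dir, ...
#   data = DataLoader(args)                  # must provide generate_batch(type='train' / 'val' / 'test')
#   summarizer = Summarizer(sess, args)
#   trainer = Train(sess, model, data, summarizer)
#   trainer.train()                          # trains, saving checkpoints and running validation periodically
#   trainer.test('test')                     # standalone evaluation on a chosen split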