"""
Build the original ResNet with cifar10 data
By JasonLJX
"""
import os
import re
import sys
import tarfile
from six.moves import urllib
import tensorflow as tf
import cifar10_input
FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', './cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")
# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.01 # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.05 # Initial learning rate.
BN_EPSILON = 0.001       # Epsilon added to the variance in batch normalization for numerical stability.
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
# We have two input pipelines: inputs() and distorted_inputs().
# distorted_inputs() gives us randomly cropped and distorted images for training.
def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.
    Used for training.
    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
        labels: Labels. 1D tensor of [batch_size] size.
    Raises:
        ValueError: If no data_dir.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)
    return images, labels
def inputs(eval_data):
    """Construct input for CIFAR evaluation using the Reader ops. This is used
    for evaluation.
    Args:
        eval_data: bool, indicating if one should use the train or eval data set.
    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
        labels: Labels. 1D tensor of [batch_size] size.
    Raises:
        ValueError: If no data_dir.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.inputs(eval_data=eval_data,
                                          data_dir=data_dir,
                                          batch_size=FLAGS.batch_size)
    return images, labels
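# Usage sketch (illustrative, assuming the CIFAR-10 binaries have been downloaded):
# the training graph reads distorted batches, the eval graph reads clean ones.
#
#     train_images, train_labels = distorted_inputs()
#     eval_images, eval_labels = inputs(eval_data=True)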
def _activation_summary(x):
    """Helper to create summaries for activations.
    Creates a summary that provides a histogram of activations.
    Creates a summary that measures the sparsity of activations.
    Args:
        x: Tensor
    Returns:
        nothing
    """
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on tensorboard.
    tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
    tf.summary.histogram(tensor_name + '/activations', x)
    tf.summary.scalar(tensor_name + '/sparsity',
                      tf.nn.zero_fraction(x))
def create_variables(name, shape, initializer=tf.contrib.layers.xavier_initializer()):
    '''
    :param name: A string. The name of the new variable
    :param shape: A list of dimensions
    :param initializer: Uses Xavier as the default.
    :return: The created variable
    '''
    new_variables = tf.get_variable(name, shape=shape, initializer=initializer)
    return new_variables
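# Note: loss() below sums everything in the 'losses' collection, so weight decay
# could be wired in here. A minimal sketch (the weight_decay parameter is a
# hypothetical extension, not part of this file):
#
#     def create_variables(name, shape, initializer=..., weight_decay=None):
#         var = tf.get_variable(name, shape=shape, initializer=initializer)
#         if weight_decay is not None:
#             tf.add_to_collection('losses',
#                                  tf.multiply(tf.nn.l2_loss(var), weight_decay))
#         return var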
def batch_normalization_layer(input_layer, dimension):
    '''
    Helper function to do batch normalization
    :param input_layer: 4D tensor
    :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor
    :return: the 4D tensor after being normalized
    '''
    mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2])
    beta = tf.get_variable('beta', dimension, tf.float32,
                           initializer=tf.constant_initializer(0.0, tf.float32))
    gamma = tf.get_variable('gamma', dimension, tf.float32,
                            initializer=tf.constant_initializer(1.0, tf.float32))
    bn_layer = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, BN_EPSILON)
    return bn_layer
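# For reference, tf.nn.batch_normalization computes, per channel:
#     y = gamma * (x - mean) / sqrt(variance + BN_EPSILON) + beta
# The moments here are always computed from the current batch; a production
# implementation would typically also track moving averages of mean/variance
# for use at inference time.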
def conv_bn_relu_layer(input_layer, filter_shape, stride):
    '''
    A helper function to conv, batch normalize and relu the input tensor sequentially
    :param input_layer: 4D tensor
    :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
    :param stride: stride size for conv
    :return: 4D tensor. Y = Relu(batch_normalize(conv(X)))
    '''
    out_channel = filter_shape[-1]
    filter = create_variables(name='conv', shape=filter_shape)
    conv_layer = tf.nn.conv2d(input_layer, filter, strides=[1, stride, stride, 1], padding='SAME')
    bn_layer = batch_normalization_layer(conv_layer, out_channel)
    output = tf.nn.relu(bn_layer)
    return output
def bn_relu_conv_layer(input_layer, filter_shape, stride):
    '''
    A helper function to batch normalize, relu and conv the input layer sequentially
    :param input_layer: 4D tensor
    :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
    :param stride: stride size for conv
    :return: 4D tensor. Y = conv(Relu(batch_normalize(X)))
    '''
    in_channel = input_layer.get_shape().as_list()[-1]
    bn_layer = batch_normalization_layer(input_layer, in_channel)
    relu_layer = tf.nn.relu(bn_layer)
    filter = create_variables(name='conv', shape=filter_shape)
    conv_layer = tf.nn.conv2d(relu_layer, filter, strides=[1, stride, stride, 1], padding='SAME')
    return conv_layer
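# The two helpers above differ only in ordering: conv_bn_relu_layer is the classic
# post-activation form (conv -> BN -> ReLU), used once for the stem, while
# bn_relu_conv_layer is the pre-activation form (BN -> ReLU -> conv) used inside
# the residual blocks, following He et al., "Identity Mappings in Deep Residual
# Networks".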
def output_layer(input_layer, num_labels):
    '''
    :param input_layer: 2D tensor
    :param num_labels: int. How many output labels in total? (10 for cifar10 and 100 for cifar100)
    :return: output layer Y = WX + B
    '''
    input_dim = input_layer.get_shape().as_list()[-1]
    fc_w = create_variables(name='fc_weights', shape=[input_dim, num_labels],
                            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    fc_b = create_variables(name='fc_bias', shape=[num_labels], initializer=tf.zeros_initializer())
    fc_h = tf.matmul(input_layer, fc_w) + fc_b
    return fc_h
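# Shape example: with the global-pooled features built in inference() below,
# input_layer is [batch_size, 64] and output_layer(input_layer, 10) yields
# [batch_size, 10] logits.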
def residual_block(input_layer, output_channel, first_block=False):
    '''
    Defines a residual block in ResNet
    :param input_layer: 4D tensor
    :param output_channel: int. return_tensor.get_shape().as_list()[-1] = output_channel
    :param first_block: if this is the first residual block of the whole network
    :return: 4D tensor.
    '''
    input_channel = input_layer.get_shape().as_list()[-1]
    # When it's time to "shrink" the image size, we use stride = 2
    if input_channel * 2 == output_channel:
        increase_dim = True
        stride = 2
    elif input_channel == output_channel:
        increase_dim = False
        stride = 1
    else:
        raise ValueError('Output and input channels do not match in residual block!')
    # The first conv layer of the first residual block does not need to be normalized and relu-ed.
    with tf.variable_scope('conv1_in_block'):
        if first_block:
            filter = create_variables(name='conv', shape=[3, 3, input_channel, output_channel])
            conv1 = tf.nn.conv2d(input_layer, filter=filter, strides=[1, 1, 1, 1], padding='SAME')
        else:
            conv1 = bn_relu_conv_layer(input_layer, [3, 3, input_channel, output_channel], stride)
    with tf.variable_scope('conv2_in_block'):
        conv2 = bn_relu_conv_layer(conv1, [3, 3, output_channel, output_channel], 1)
    # When the channels of the input layer and conv2 do not match, we average-pool the
    # input spatially and zero-pad its depth so the shortcut can be added to conv2.
    if increase_dim is True:
        pooled_input = tf.nn.avg_pool(input_layer, ksize=[1, 2, 2, 1],
                                      strides=[1, 2, 2, 1], padding='VALID')
        padded_input = tf.pad(pooled_input, [[0, 0], [0, 0], [0, 0], [input_channel // 2,
                                                                      input_channel // 2]])
    else:
        padded_input = input_layer
    output = conv2 + padded_input
    return output
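# Shape example (hypothetical tensors): for a 32x32x16 input and output_channel=32,
# the block uses stride 2, so conv2 comes out 16x16x32; the shortcut is average-pooled
# to 16x16x16 and zero-padded by input_channel // 2 = 8 channels on each side, giving
# 16x16x32 before the add.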
def inference(input_tensor_batch, n, reuse=False):
    '''
    The main function that defines the ResNet. total layers = 1 + 2n + 2n + 2n + 1 = 6n + 2
    :param input_tensor_batch: 4D tensor
    :param n: num_residual_blocks
    :param reuse: To build the train graph, reuse=False. To build the validation graph and
        share weights with the train graph, reuse=True
    :return: last layer in the network. Not softmax-ed
    '''
    layers = []
    with tf.variable_scope('conv0', reuse=reuse):
        conv0 = conv_bn_relu_layer(input_tensor_batch, [3, 3, 3, 16], 1)
        _activation_summary(conv0)
        layers.append(conv0)
    for i in range(n):
        with tf.variable_scope('conv1_%d' % i, reuse=reuse):
            if i == 0:
                conv1 = residual_block(layers[-1], 16, first_block=True)
            else:
                conv1 = residual_block(layers[-1], 16)
            _activation_summary(conv1)
            layers.append(conv1)
    for i in range(n):
        with tf.variable_scope('conv2_%d' % i, reuse=reuse):
            conv2 = residual_block(layers[-1], 32)
            _activation_summary(conv2)
            layers.append(conv2)
    for i in range(n):
        with tf.variable_scope('conv3_%d' % i, reuse=reuse):
            conv3 = residual_block(layers[-1], 64)
            layers.append(conv3)
    assert conv3.get_shape().as_list()[1:] == [8, 8, 64]
    with tf.variable_scope('fc', reuse=reuse):
        in_channel = layers[-1].get_shape().as_list()[-1]
        bn_layer = batch_normalization_layer(layers[-1], in_channel)
        relu_layer = tf.nn.relu(bn_layer)
        global_pool = tf.reduce_mean(relu_layer, [1, 2])
        assert global_pool.get_shape().as_list()[-1:] == [64]
        output = output_layer(global_pool, 10)
        layers.append(output)
    return output
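# Usage sketch: n controls depth via 6n + 2, so n=3 gives the 20-layer ResNet and
# n=5 the 32-layer one (assuming 32x32 inputs, which the asserts above require):
#
#     images, labels = distorted_inputs()
#     logits = inference(images, n=3, reuse=False)           # train graph
#     val_logits = inference(val_images, n=3, reuse=True)    # shares weights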
def loss(logits, labels):
    """
    Compute the total loss: the averaged cross entropy plus any weight-decay (L2)
    terms that have been added to the 'losses' collection.
    :param logits: logits from inference()
    :param labels: labels from distorted_inputs() or inputs(). 1-D tensor of shape [batch_size]
    :return: loss tensor of type float
    """
    # Calculate the average cross entropy loss across the batch.
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)
    # The total loss is the cross entropy loss plus all of the weight decay terms
    # (L2 loss) that other code may have placed in the 'losses' collection.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
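# Because loss() sums the whole 'losses' collection, any variable created with the
# hypothetical weight_decay option sketched after create_variables() would
# contribute automatically:
#
#     total_loss = loss(logits, labels)   # cross entropy (+ L2 terms, if any)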
def training(total_loss, global_step):
    """Train CIFAR-10 model.
    Create an optimizer and apply it to all trainable variables.
    Args:
        total_loss: Total loss from loss().
        global_step: Integer Variable counting the number of training steps
            processed.
    Returns:
        train_op: op for training.
    """
    # Variables that affect learning rate.
    num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
    # Decay the learning rate exponentially based on the number of steps.
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                               global_step,
                                               decay_steps,
                                               LEARNING_RATE_DECAY_FACTOR,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)
    # Add a scalar summary for the snapshot loss.
    tf.summary.scalar('loss', total_loss)
    # Create the gradient descent optimizer with the given learning rate.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # Use the optimizer to apply the gradients that minimize the loss
    # (and also increment the global step counter) as a single training step.
    train_op = optimizer.minimize(total_loss, global_step=global_step)
    return train_op
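# Training-loop sketch (illustrative; the session handling below is an assumption,
# not part of this file):
#
#     global_step = tf.train.get_or_create_global_step()
#     images, labels = distorted_inputs()
#     total_loss = loss(inference(images, n=3), labels)
#     train_op = training(total_loss, global_step)
#     with tf.train.MonitoredTrainingSession() as sess:
#         while not sess.should_stop():
#             sess.run(train_op)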
def maybe_download_and_extract():
    """Download and extract the tarball from Alex's website."""
    dest_directory = FLAGS.data_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
                             float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        tarfile.open(filepath, 'r:gz').extractall(dest_directory)
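# End-to-end sketch: fetching the data and building the graph (illustrative only;
# this file defines the model, and the actual train script is assumed to live
# elsewhere):
#
#     if __name__ == '__main__':
#         maybe_download_and_extract()
#         images, labels = distorted_inputs()
#         logits = inference(images, n=3)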