-
Notifications
You must be signed in to change notification settings - Fork 125
/
layers.py
403 lines (349 loc) · 19.6 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
import tensorflow as tf
import numpy as np
############################################################################################################
# Convolution layer Methods
def __conv2d_p(name, x, w=None, num_filters=16, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
               initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=0.0):
    """
    Convolution 2D Wrapper
    :param name: (string) The name scope provided by the upper tf.name_scope('name') as scope.
    :param x: (tf.tensor) The input to the layer (N, H, W, C).
    :param w: (tf.tensor) pretrained weights (if None, a new 'weights' variable is created)
    :param num_filters: (integer) No. of filters (This is the output depth)
    :param kernel_size: (integer tuple) The size of the convolving kernel.
    :param padding: (string) The amount of padding required.
    :param stride: (integer tuple) The stride required.
    :param initializer: (tf.contrib initializer) The initialization scheme, He et al. normal or Xavier normal are recommended.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias. (if not float, it is used as-is as the pretrained bias)
    :return out: The output of the layer. (N, H', W', num_filters)
    """
    with tf.variable_scope(name):
        stride = [1, stride[0], stride[1], 1]
        kernel_shape = [kernel_size[0], kernel_size[1], x.shape[-1], num_filters]

        with tf.name_scope('layer_weights'):
            # Identity test on purpose: `w == None` is evaluated element-wise when the
            # pretrained weights are a numpy array, which makes the `if` ambiguous.
            if w is None:
                w = __variable_with_weight_decay(kernel_shape, initializer, l2_strength)
            __variable_summaries(w)
        with tf.name_scope('layer_biases'):
            # Only a Python float triggers variable creation; anything else is treated as
            # an already-built bias and fed to bias_add unchanged.
            if isinstance(bias, float):
                bias = tf.get_variable('biases', [num_filters], initializer=tf.constant_initializer(bias))
            __variable_summaries(bias)
        with tf.name_scope('layer_conv2d'):
            conv = tf.nn.conv2d(x, w, stride, padding)
            out = tf.nn.bias_add(conv, bias)

    return out
def conv2d(name, x, w=None, num_filters=16, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
           initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=0.0,
           activation=None, batchnorm_enabled=False, max_pool_enabled=False, dropout_keep_prob=-1,
           is_training=True):
    """
    This block is responsible for a convolution 2D layer followed by optional (batch normalization,
    non-linearity, dropout, max-pooling), applied in that order.
    Note that: "is_training" should be passed by a correct value based on being in either training or testing.
    :param name: (string) The name scope provided by the upper tf.name_scope('name') as scope.
    :param x: (tf.tensor) The input to the layer (N, H, W, C).
    :param w: (tf.tensor) pretrained weights (if None, new weights are created by the inner convolution)
    :param num_filters: (integer) No. of filters (This is the output depth)
    :param kernel_size: (integer tuple) The size of the convolving kernel.
    :param padding: (string) The amount of padding required.
    :param stride: (integer tuple) The stride required.
    :param initializer: (tf.contrib initializer) The initialization scheme, He et al. normal or Xavier normal are recommended.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias.
    :param activation: (tf.graph operator) The activation function applied after the convolution operation. If None, linear is applied.
    :param batchnorm_enabled: (boolean) for enabling batch normalization.
    :param max_pool_enabled: (boolean) for enabling max-pooling 2x2 to decrease width and height by a factor of 2.
    :param dropout_keep_prob: (float) for the probability of keeping neurons. If equals -1, it means no dropout
    :param is_training: (boolean or boolean tf.tensor) to diff. between training and testing (important for batch normalization and dropout)
    :return: The output tensor of the layer (N, H', W', C').
    """
    with tf.variable_scope(name):
        conv_o_b = __conv2d_p('conv', x=x, w=w, num_filters=num_filters, kernel_size=kernel_size, stride=stride,
                              padding=padding,
                              initializer=initializer, l2_strength=l2_strength, bias=bias)

        # Optional batch normalization, then optional non-linearity.
        pre_activation = conv_o_b
        if batchnorm_enabled:
            pre_activation = tf.layers.batch_normalization(conv_o_b, training=is_training, epsilon=1e-5)
        conv_a = activation(pre_activation) if activation else pre_activation

        conv_o_dr = conv_a
        if dropout_keep_prob != -1:
            def dropout_with_keep():
                return tf.nn.dropout(conv_a, dropout_keep_prob)

            def dropout_no_keep():
                return tf.nn.dropout(conv_a, 1.0)

            if isinstance(is_training, bool):
                # tf.cond refuses a Python bool predicate, so the default is_training=True
                # used to crash as soon as dropout was enabled. Pick the branch in Python.
                conv_o_dr = dropout_with_keep() if is_training else dropout_no_keep()
            else:
                conv_o_dr = tf.cond(is_training, dropout_with_keep, dropout_no_keep)

        conv_o = conv_o_dr
        if max_pool_enabled:
            conv_o = max_pool_2d(conv_o_dr)

    return conv_o
def grouped_conv2d(name, x, w=None, num_filters=16, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
                   initializer=tf.contrib.layers.xavier_initializer(), num_groups=1, l2_strength=0.0, bias=0.0,
                   activation=None, batchnorm_enabled=False, dropout_keep_prob=-1,
                   is_training=True):
    """
    Grouped convolution 2D: the channel axis of the input is cut into `num_groups` equal slices, every
    slice goes through its own conv2d, and the results are concatenated along the channel axis.
    Batch normalization and the activation are applied once on the concatenated tensor, not per group.
    :param name: (string) Variable scope name; group i uses the inner scope name + "_" + str(i).
    :param x: (tf.tensor) The input to the layer, indexed here as a 4-D tensor with channels last.
    :param w: Passed unchanged to every per-group conv2d call.
    :param num_filters: (integer) Total output depth; each group produces num_filters // num_groups filters.
    :param kernel_size: (integer tuple) The size of the convolving kernel.
    :param padding: (string) The amount of padding required.
    :param stride: (integer tuple) The stride required.
    :param initializer: (tf.contrib initializer) The initialization scheme of the group kernels.
    :param num_groups: (integer) No. of groups. Each slice holds channels // num_groups channels, so
                       channels beyond num_groups * (channels // num_groups) are not consumed.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias.
    :param activation: (tf.graph operator) Applied after the concatenation. If None, linear is applied.
    :param batchnorm_enabled: (boolean) for enabling batch normalization after the concatenation.
    :param dropout_keep_prob: (float) Forwarded to every per-group conv2d. -1 means no dropout.
    :param is_training: (boolean) to diff. between training and testing.
    :return: The output tensor of the layer.
    """
    with tf.variable_scope(name):
        group_width = x.get_shape()[3].value // num_groups
        filters_per_group = num_filters // num_groups

        group_outputs = []
        for group_idx in range(num_groups):
            first_channel = group_idx * group_width
            group_input = x[:, :, :, first_channel:first_channel + group_width]
            # The per-group convolutions stay linear and un-normalized on purpose.
            group_outputs.append(
                conv2d(name + "_" + str(group_idx), group_input, w, filters_per_group,
                       kernel_size=kernel_size, padding=padding, stride=stride, initializer=initializer,
                       l2_strength=l2_strength, bias=bias, activation=None, batchnorm_enabled=False,
                       max_pool_enabled=False, dropout_keep_prob=dropout_keep_prob, is_training=is_training))

        merged = tf.concat(group_outputs, axis=-1)
        if batchnorm_enabled:
            merged = tf.layers.batch_normalization(merged, training=is_training, epsilon=1e-5)

        if activation:
            return activation(merged)
        return merged
def __depthwise_conv2d_p(name, x, w=None, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
                         initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=0.0):
    """
    Depthwise Convolution 2D Wrapper (channel multiplier of 1).
    :param name: (string) Variable scope name.
    :param x: (tf.tensor) The input to the layer; its last dimension is taken as the channel count.
    :param w: Pretrained depthwise kernel (if None, a new 'weights' variable is created).
    :param kernel_size: (integer tuple) The size of the convolving kernel.
    :param padding: (string) The amount of padding required.
    :param stride: (integer tuple) The stride required.
    :param initializer: (tf.contrib initializer) The initialization scheme of a newly created kernel.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias. (if not float, it is used as-is)
    :return: The depthwise convolution output with the bias added.
    """
    with tf.variable_scope(name):
        strides = [1, stride[0], stride[1], 1]
        in_channels = x.shape[-1]

        with tf.name_scope('layer_weights'):
            depthwise_kernel = w
            if depthwise_kernel is None:
                depthwise_kernel = __variable_with_weight_decay(
                    [kernel_size[0], kernel_size[1], in_channels, 1], initializer, l2_strength)
            __variable_summaries(depthwise_kernel)

        with tf.name_scope('layer_biases'):
            biases = bias
            if isinstance(biases, float):
                # One bias per input channel, since the depthwise output keeps the channel count.
                biases = tf.get_variable('biases', [in_channels], initializer=tf.constant_initializer(bias))
            __variable_summaries(biases)

        with tf.name_scope('layer_conv2d'):
            convolved = tf.nn.depthwise_conv2d(x, depthwise_kernel, strides, padding)
            return tf.nn.bias_add(convolved, biases)
def depthwise_conv2d(name, x, w=None, kernel_size=(3, 3), padding='SAME', stride=(1, 1),
                     initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0, bias=0.0, activation=None,
                     batchnorm_enabled=False, is_training=True):
    """
    Depthwise convolution 2D layer followed by optional batch normalization and optional non-linearity.
    :param name: (string) Variable scope name.
    :param x: (tf.tensor) The input to the layer.
    :param w: Pretrained depthwise kernel (if None, new weights are created by the inner convolution).
    :param kernel_size: (integer tuple) The size of the convolving kernel.
    :param padding: (string) The amount of padding required.
    :param stride: (integer tuple) The stride required.
    :param initializer: (tf.contrib initializer) The initialization scheme of a newly created kernel.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias.
    :param activation: (tf.graph operator) Applied last. If None, linear is applied.
    :param batchnorm_enabled: (boolean) for enabling batch normalization.
    :param is_training: (boolean) to diff. between training and testing (used by batch normalization).
    :return: The output tensor of the layer.
    """
    with tf.variable_scope(name):
        features = __depthwise_conv2d_p(name='conv', x=x, w=w, kernel_size=kernel_size, padding=padding,
                                        stride=stride, initializer=initializer, l2_strength=l2_strength,
                                        bias=bias)
        if batchnorm_enabled:
            features = tf.layers.batch_normalization(features, training=is_training, epsilon=1e-5)

        if activation:
            features = activation(features)
    return features
############################################################################################################
# ShuffleNet unit methods
def shufflenet_unit(name, x, w=None, num_groups=1, group_conv_bottleneck=True, num_filters=16, stride=(1, 1),
                    l2_strength=0.0, bias=0.0, batchnorm_enabled=True, is_training=True, fusion='add'):
    """
    ShuffleNet unit: 1x1 (grouped) bottleneck convolution -> channel shuffle -> depthwise convolution ->
    1x1 grouped convolution, fused with the (optionally average-pooled) input and passed through ReLU.
    :param name: (string) Variable scope name.
    :param x: (tf.tensor) The input to the unit, read here as a 4-D tensor with channels last.
    :param w: Unused; every inner layer is built with w=None.
    :param num_groups: (integer) No. of groups of the grouped convolutions and of the channel shuffle.
    :param group_conv_bottleneck: (boolean) If True the bottleneck is a grouped convolution followed by a
                                  channel shuffle, otherwise a plain 1x1 convolution without shuffle.
    :param num_filters: (integer) Output depth of the unit.
    :param stride: (integer pair) Stride of the depthwise convolution. With (2, 2) the residual is
                   average-pooled (3x3, SAME) using the same stride.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias.
    :param batchnorm_enabled: (boolean) for enabling batch normalization in the inner layers.
    :param is_training: (boolean) to diff. between training and testing.
    :param fusion: (string) 'add' sums the branch with the residual, 'concat' concatenates them on the channel axis.
    :return: The output tensor of the unit.
    :raises ValueError: if fusion is neither 'add' nor 'concat'.
    """
    # Reject a bad fusion mode before any op is created. It used to be detected only after the
    # whole branch had been added to the graph, with the bottleneck silently sized as for 'concat'.
    if fusion not in ('add', 'concat'):
        raise ValueError("Specify whether the fusion is 'concat' or 'add'")

    # Accept any pair (e.g. a list); a list never compares equal to the tuple (2, 2) below.
    stride = tuple(stride)

    # Paper parameters. If you want to change them feel free to pass them as method parameters.
    activation = tf.nn.relu

    with tf.variable_scope(name):
        residual = x
        residual_channels = residual.get_shape()[3].value

        # With 'concat' the residual supplies part of the output depth, so the branch only
        # has to produce the remaining channels.
        branch_filters = num_filters if fusion == 'add' else num_filters - residual_channels
        bottleneck_filters = branch_filters // 4

        if group_conv_bottleneck:
            bottleneck = grouped_conv2d('Gbottleneck', x=x, w=None, num_filters=bottleneck_filters,
                                        kernel_size=(1, 1),
                                        padding='VALID',
                                        num_groups=num_groups, l2_strength=l2_strength, bias=bias,
                                        activation=activation,
                                        batchnorm_enabled=batchnorm_enabled, is_training=is_training)
            shuffled = channel_shuffle('channel_shuffle', bottleneck, num_groups)
        else:
            bottleneck = conv2d('bottleneck', x=x, w=None, num_filters=bottleneck_filters, kernel_size=(1, 1),
                                padding='VALID', l2_strength=l2_strength, bias=bias, activation=activation,
                                batchnorm_enabled=batchnorm_enabled, is_training=is_training)
            shuffled = bottleneck

        # Explicit one-pixel zero border on each spatial side, followed by a VALID depthwise convolution.
        padded = tf.pad(shuffled, [[0, 0], [1, 1], [1, 1], [0, 0]], "CONSTANT")
        depthwise = depthwise_conv2d('depthwise', x=padded, w=None, stride=stride, l2_strength=l2_strength,
                                     padding='VALID', bias=bias,
                                     activation=None, batchnorm_enabled=batchnorm_enabled,
                                     is_training=is_training)

        if stride == (2, 2):
            residual_pooled = avg_pool_2d(residual, size=(3, 3), stride=stride, padding='SAME')
        else:
            residual_pooled = residual

        group_conv1x1 = grouped_conv2d('Gconv1x1', x=depthwise, w=None,
                                       num_filters=branch_filters,
                                       kernel_size=(1, 1),
                                       padding='VALID',
                                       num_groups=num_groups, l2_strength=l2_strength, bias=bias,
                                       activation=None,
                                       batchnorm_enabled=batchnorm_enabled, is_training=is_training)

        if fusion == 'concat':
            return activation(tf.concat([residual_pooled, group_conv1x1], axis=-1))

        residual_match = residual_pooled
        # This is used if the number of filters of the residual block is different from that
        # of the group convolution.
        if num_filters != residual_pooled.get_shape()[3].value:
            residual_match = conv2d('residual_match', x=residual_pooled, w=None, num_filters=num_filters,
                                    kernel_size=(1, 1),
                                    padding='VALID', l2_strength=l2_strength, bias=bias, activation=None,
                                    batchnorm_enabled=batchnorm_enabled, is_training=is_training)
        return activation(group_conv1x1 + residual_match)
def channel_shuffle(name, x, num_groups):
    """
    Reorder the channels of a 4-D tensor by viewing them as (num_groups, channels // num_groups),
    swapping those two axes and flattening them back.
    :param name: (string) Variable scope name.
    :param x: (tf.tensor) 4-D input; height, width and channels are read from its static shape.
    :param num_groups: (integer) No. of groups the channel axis is split into.
    :return: A tensor reshaped to (-1, height, width, channels) with shuffled channels.
    """
    with tf.variable_scope(name):
        _, height, width, channels = x.shape.as_list()
        channels_per_group = channels // num_groups

        grouped = tf.reshape(x, [-1, height, width, num_groups, channels_per_group])
        # Swap the group axis with the within-group axis.
        swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
        return tf.reshape(swapped, [-1, height, width, channels])
############################################################################################################
# Fully Connected layer Methods
def __dense_p(name, x, w=None, output_dim=128, initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0,
              bias=0.0):
    """
    Fully connected layer
    :param name: (string) The name scope provided by the upper tf.name_scope('name') as scope.
    :param x: (tf.tensor) The input to the layer (N, D).
    :param w: (tf.tensor) pretrained weights (if None, a new 'weights' variable is created)
    :param output_dim: (integer) It specifies H, the output second dimension of the fully connected layer [ie:(N, H)]
    :param initializer: (tf.contrib initializer) The initialization scheme, He et al. normal or Xavier normal are recommended.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias. (if not float, it is used as-is as the pretrained bias)
    :return out: The output of the layer. (N, H)
    """
    n_in = x.get_shape()[-1].value
    with tf.variable_scope(name):
        # Identity test on purpose: `w == None` is evaluated element-wise when the
        # pretrained weights are a numpy array, which makes the `if` ambiguous.
        if w is None:
            w = __variable_with_weight_decay([n_in, output_dim], initializer, l2_strength)
        __variable_summaries(w)
        # Only a Python float triggers variable creation; anything else is fed to bias_add unchanged.
        if isinstance(bias, float):
            bias = tf.get_variable("layer_biases", [output_dim], tf.float32, tf.constant_initializer(bias))
        __variable_summaries(bias)
        output = tf.nn.bias_add(tf.matmul(x, w), bias)
        return output
def dense(name, x, w=None, output_dim=128, initializer=tf.contrib.layers.xavier_initializer(), l2_strength=0.0,
          bias=0.0,
          activation=None, batchnorm_enabled=False, dropout_keep_prob=-1,
          is_training=True
          ):
    """
    This block is responsible for a fully connected layer followed by optional (batch normalization,
    non-linearity, dropout), applied in that order.
    Note that: "is_training" should be passed by a correct value based on being in either training or testing.
    :param name: (string) The name scope provided by the upper tf.name_scope('name') as scope.
    :param x: (tf.tensor) The input to the layer (N, D).
    :param w: (tf.tensor) pretrained weights (if None, new weights are created by the inner layer)
    :param output_dim: (integer) It specifies H, the output second dimension of the fully connected layer [ie:(N, H)]
    :param initializer: (tf.contrib initializer) The initialization scheme, He et al. normal or Xavier normal are recommended.
    :param l2_strength:(weight decay) (float) L2 regularization parameter.
    :param bias: (float) Amount of bias.
    :param activation: (tf.graph operator) The activation function applied after the matrix multiplication. If None, linear is applied.
    :param batchnorm_enabled: (boolean) for enabling batch normalization.
    :param dropout_keep_prob: (float) for the probability of keeping neurons. If equals -1, it means no dropout
    :param is_training: (boolean or boolean tf.tensor) to diff. between training and testing (important for batch normalization and dropout)
    :return out: The output of the layer. (N, H)
    """
    with tf.variable_scope(name):
        dense_o_b = __dense_p(name='dense', x=x, w=w, output_dim=output_dim, initializer=initializer,
                              l2_strength=l2_strength,
                              bias=bias)

        # Optional batch normalization, then optional non-linearity.
        pre_activation = dense_o_b
        if batchnorm_enabled:
            pre_activation = tf.layers.batch_normalization(dense_o_b, training=is_training, epsilon=1e-5)
        dense_a = activation(pre_activation) if activation else pre_activation

        dense_o = dense_a
        if dropout_keep_prob != -1:
            def dropout_with_keep():
                return tf.nn.dropout(dense_a, dropout_keep_prob)

            def dropout_no_keep():
                return tf.nn.dropout(dense_a, 1.0)

            if isinstance(is_training, bool):
                # tf.cond refuses a Python bool predicate, so the default is_training=True
                # used to crash as soon as dropout was enabled. Pick the branch in Python.
                dense_o = dropout_with_keep() if is_training else dropout_no_keep()
            else:
                dense_o = tf.cond(is_training, dropout_with_keep, dropout_no_keep)

    return dense_o
def flatten(x):
    """
    Collapse every dimension after the first into one, e.g. (N,H,W,C) -> (N,H*W*C).
    Used for fully connected layers after convolution layers.
    :param x: (tf.tensor) input whose non-leading dimensions are taken from its static shape
    :return: the input reshaped to (-1, product of the non-leading dimensions)
    """
    trailing_dims = [dim.value for dim in x.get_shape()[1:]]
    flat_size = np.prod(trailing_dims)
    return tf.reshape(x, [-1, flat_size])
############################################################################################################
# Pooling Methods
def max_pool_2d(x, size=(2, 2), stride=(2, 2), name='pooling'):
    """
    Max pooling 2D Wrapper (always uses 'VALID' padding).
    :param x: (tf.tensor) The input to the layer (N,H,W,C).
    :param size: (tuple) The two spatial extents of the pooling window.
    :param stride: (tuple) The two spatial strides of the pooling window.
    :param name: (string) Name given to the pooling op.
    :return: The pooled tensor; with the default 2x2 window and stride this is (N,H/2,W/2,C).
    """
    (window_a, window_b), (step_a, step_b) = size, stride
    window = [1, window_a, window_b, 1]
    steps = [1, step_a, step_b, 1]
    return tf.nn.max_pool(x, ksize=window, strides=steps, padding='VALID', name=name)
def avg_pool_2d(x, size=(2, 2), stride=(2, 2), name='avg_pooling', padding='VALID'):
    """
    Average pooling 2D Wrapper
    :param x: (tf.tensor) The input to the layer (N,H,W,C).
    :param size: (tuple) The two spatial extents of the pooling window.
    :param stride: (tuple) The two spatial strides of the pooling window.
    :param name: (string) Name given to the pooling op.
    :param padding: (string) Padding scheme handed to the pooling op.
    :return: The pooled tensor; with the default 2x2 window, stride and padding this is (N,H/2,W/2,C).
    """
    (window_a, window_b), (step_a, step_b) = size, stride
    window = [1, window_a, window_b, 1]
    steps = [1, step_a, step_b, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=steps, padding=padding, name=name)
############################################################################################################
# Utilities for layers
def __variable_with_weight_decay(kernel_shape, initializer, wd):
    """
    Get (or create) the float32 variable named 'weights' and register its L2 penalty.
    :param kernel_shape: shape of the weight variable.
    :param initializer: The initialization scheme, He et al. normal or Xavier normal are recommended.
    :param wd:(weight decay) L2 regularization parameter. A falsy value disables the penalty.
    :return: The weight variable. Unless the current variable scope is in reuse mode, wd * l2_loss(weights)
             is added to the REGULARIZATION_LOSSES collection.
    """
    weights = tf.get_variable('weights', kernel_shape, tf.float32, initializer=initializer)

    # No penalty without a decay factor, and none for a reused variable.
    if not wd or tf.get_variable_scope().reuse:
        return weights

    l2_penalty = tf.multiply(tf.nn.l2_loss(weights), wd, name='w_loss')
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, l2_penalty)
    return weights
# Summaries for variables
def __variable_summaries(var):
    """
    Register TensorBoard summaries for a tensor: scalar mean, stddev, max and min plus a histogram.
    :param var: variable to be summarized
    :return: None
    """
    with tf.name_scope('summaries'):
        average = tf.reduce_mean(var)
        tf.summary.scalar('mean', average)

        with tf.name_scope('stddev'):
            # Population standard deviation: sqrt(mean((var - mean)^2)).
            squared_deviation = tf.square(var - average)
            spread = tf.sqrt(tf.reduce_mean(squared_deviation))
        tf.summary.scalar('stddev', spread)

        largest = tf.reduce_max(var)
        tf.summary.scalar('max', largest)
        smallest = tf.reduce_min(var)
        tf.summary.scalar('min', smallest)

        tf.summary.histogram('histogram', var)