Skip to content

Commit

Permalink
Allow training ResNet without moving-average batch norm.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarvinTeichmann committed Mar 16, 2017
1 parent 08dffdb commit 253d4d1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 27 deletions.
2 changes: 1 addition & 1 deletion encoder/fcn8_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def inference(hypes, images, train=True):

logits['feed4'] = vgg_fcn.pool3

logits['fcn_logits'] = vgg_fcn.upscore32
# logits['fcn_logits'] = vgg_fcn.upscore32

return logits
61 changes: 36 additions & 25 deletions encoder/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,29 @@ def inference(hypes, images, train=True,

with tf.variable_scope('scale1'):
x = _conv(x, 64, ksize=7, stride=2)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)
x = _relu(x)
scale1 = x

with tf.variable_scope('scale2'):
x = _max_pool(x, ksize=3, stride=2)
x = stack(x, num_blocks[0], 64, bottleneck, is_training, stride=1)
x = stack(x, num_blocks[0], 64, bottleneck, is_training, stride=1,
hypes=hypes)
scale2 = x

with tf.variable_scope('scale3'):
x = stack(x, num_blocks[1], 128, bottleneck, is_training, stride=2)
x = stack(x, num_blocks[1], 128, bottleneck, is_training, stride=2,
hypes=hypes)
scale3 = x

with tf.variable_scope('scale4'):
x = stack(x, num_blocks[2], 256, bottleneck, is_training, stride=2)
x = stack(x, num_blocks[2], 256, bottleneck, is_training, stride=2,
hypes=hypes)
scale4 = x

with tf.variable_scope('scale5'):
x = stack(x, num_blocks[3], 512, bottleneck, is_training, stride=2)
x = stack(x, num_blocks[3], 512, bottleneck, is_training, stride=2,
hypes=hypes)
scale5 = x

logits['images'] = images
Expand Down Expand Up @@ -153,19 +157,21 @@ def _imagenet_preprocess(rgb):
return bgr


def stack(x, num_blocks, filters_internal, bottleneck, is_training, stride):
def stack(x, num_blocks, filters_internal, bottleneck, is_training, stride,
hypes):
for n in range(num_blocks):
s = stride if n == 0 else 1
with tf.variable_scope('block%d' % (n + 1)):
x = block(x,
filters_internal,
bottleneck=bottleneck,
is_training=is_training,
stride=s)
stride=s,
hypes=hypes)
return x


def block(x, filters_internal, is_training, stride, bottleneck):
def block(x, filters_internal, is_training, stride, bottleneck, hypes):
filters_in = x.get_shape()[-1]

# Note: filters_out isn't how many filters are outputted.
Expand All @@ -183,31 +189,31 @@ def block(x, filters_internal, is_training, stride, bottleneck):
if bottleneck:
with tf.variable_scope('a'):
x = _conv(x, filters_internal, ksize=1, stride=stride)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)
x = _relu(x)

with tf.variable_scope('b'):
x = _conv(x, filters_internal, ksize=3, stride=1)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)
x = _relu(x)

with tf.variable_scope('c'):
x = _conv(x, filters_out, ksize=1, stride=1)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)
else:
with tf.variable_scope('A'):
x = _conv(x, filters_internal, ksize=3, stride=stride)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)
x = _relu(x)

with tf.variable_scope('B'):
x = _conv(x, filters_out, ksize=3, stride=1)
x = _bn(x, is_training)
x = _bn(x, is_training, hypes)

with tf.variable_scope('shortcut'):
if filters_out != filters_in or stride != 1:
shortcut = _conv(shortcut, filters_out, ksize=1, stride=stride)
shortcut = _bn(shortcut, is_training)
shortcut = _bn(shortcut, is_training, hypes)

return _relu(x + shortcut)

Expand All @@ -216,7 +222,7 @@ def _relu(x):
return tf.nn.relu(x)


def _bn(x, is_training):
def _bn(x, is_training, hypes):
x_shape = x.get_shape()
params_shape = x_shape[-1:]
axis = list(range(len(x_shape) - 1))
Expand All @@ -239,16 +245,21 @@ def _bn(x, is_training):

# These ops will only be performed when training.
mean, variance = tf.nn.moments(x, axis)
update_moving_mean = moving_averages.assign_moving_average(moving_mean,
mean, BN_DECAY)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, BN_DECAY)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

mean, variance = control_flow_ops.cond(
is_training, lambda: (mean, variance),
lambda: (moving_mean, moving_variance))

if hypes['use_moving_average_bn']:
update_moving_mean = moving_averages.assign_moving_average(moving_mean,
mean,
BN_DECAY)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, BN_DECAY)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

mean, variance = control_flow_ops.cond(
is_training, lambda: (mean, variance),
lambda: (moving_mean, moving_variance))
else:
mean, variance = mean, variance

x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
# x.set_shape(inputs.get_shape()) ??
Expand Down
2 changes: 1 addition & 1 deletion hypes/KittiRes.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
"reseize_image": true,
"image_height" : 384,
"image_width" : 1248,

"augment_level": 1
},

Expand All @@ -71,5 +70,6 @@
"clip_norm" : 1.0,
"wd": 5e-4,
"load_pretrained": true,
"use_moving_average_bn": false,
"scale_down": 0.5
}

0 comments on commit 253d4d1

Please sign in to comment.