Skip to content

Commit

Permalink
Make encoder compatible with other Versions
Browse files Browse the repository at this point in the history
  • Loading branch information
MarvinTeichmann committed Mar 26, 2017
1 parent 65a3e12 commit 5237590
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
3 changes: 3 additions & 0 deletions encoder/fcn8_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ def inference(hypes, images, train=True):

logits['fcn_logits'] = vgg_fcn.upscore32

logits['deep_feat'] = vgg_fcn.pool5
logits['early_feat'] = vgg_fcn.conv4_3

return logits
13 changes: 8 additions & 5 deletions encoder/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def inference(hypes, images, train=True,
logits['feed2'] = scale4
logits['feed4'] = scale3

logits['early_feat'] = scale3
logits['deep_feat'] = scale5

if train:
restore = tf.global_variables()
hypes['init_function'] = _initalize_variables
Expand Down Expand Up @@ -246,12 +249,12 @@ def _bn(x, is_training, hypes):
# These ops will only be preformed when training.
mean, variance = tf.nn.moments(x, axis)

update_moving_mean = moving_averages.assign_moving_average(moving_mean,
mean,
BN_DECAY)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, BN_DECAY)
if hypes['use_moving_average_bn']:
update_moving_mean = moving_averages.assign_moving_average(moving_mean,
mean,
BN_DECAY)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, BN_DECAY)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

Expand Down

0 comments on commit 5237590

Please sign in to comment.