diff --git a/encoder/fcn8_vgg.py b/encoder/fcn8_vgg.py index d067c39..f57584c 100644 --- a/encoder/fcn8_vgg.py +++ b/encoder/fcn8_vgg.py @@ -55,4 +55,7 @@ def inference(hypes, images, train=True): logits['fcn_logits'] = vgg_fcn.upscore32 + logits['deep_feat'] = vgg_fcn.pool5 + logits['early_feat'] = vgg_fcn.conv4_3 + return logits diff --git a/encoder/resnet.py b/encoder/resnet.py index abf0917..f447f9a 100644 --- a/encoder/resnet.py +++ b/encoder/resnet.py @@ -101,6 +101,9 @@ def inference(hypes, images, train=True, logits['feed2'] = scale4 logits['feed4'] = scale3 + logits['early_feat'] = scale3 + logits['deep_feat'] = scale5 + if train: restore = tf.global_variables() hypes['init_function'] = _initalize_variables @@ -246,12 +249,12 @@ def _bn(x, is_training, hypes): # These ops will only be preformed when training. mean, variance = tf.nn.moments(x, axis) + update_moving_mean = moving_averages.assign_moving_average(moving_mean, + mean, + BN_DECAY) + update_moving_variance = moving_averages.assign_moving_average( + moving_variance, variance, BN_DECAY) if hypes['use_moving_average_bn']: - update_moving_mean = moving_averages.assign_moving_average(moving_mean, - mean, - BN_DECAY) - update_moving_variance = moving_averages.assign_moving_average( - moving_variance, variance, BN_DECAY) tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean) tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)