Skip to content

Commit

Permalink
Implement Layer selection
Browse files Browse the repository at this point in the history
  • Loading branch information
MarvinTeichmann committed Mar 16, 2017
1 parent ce038be commit 08dffdb
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions encoder/resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'''
"""
The MIT License (MIT)
Original Work: Copyright (c) 2016 Ryan Dahl
Expand All @@ -7,7 +7,7 @@
Modified Work: Copyright (c) 2017 Marvin Teichmann
For details see 'licenses/RESNET_LICENSE.txt'
'''
"""
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages
Expand All @@ -33,12 +33,13 @@
IMAGENET_MEAN_BGR = [103.062623801, 115.902882574, 123.151630838, ]


network_file = os.path.join("tensorflow_resnet_convert_1.1",
"ResNet-L101.ckpt")

network_url = "Not yet uploaded."


def checkpoint_fn(layers):
return 'ResNet-L%d.ckpt' % layers


def inference(hypes, images, train=True,
num_classes=1000,
num_blocks=[3, 4, 6, 3], # defaults to 50-layer network
Expand All @@ -55,6 +56,8 @@ def inference(hypes, images, train=True,
num_blocks = [3, 4, 23, 3]
elif layers == 152:
num_blocks = [3, 8, 36, 3]
else:
assert()

if preprocess:
x = _imagenet_preprocess(images)
Expand Down Expand Up @@ -113,13 +116,18 @@ def _initalize_variables(hypes):

saver = tf.train.Saver(var_list=restore)

filename = network_file
layers = hypes['arch']['layers']

assert layers in [50, 101, 152]

filename = checkpoint_fn(layers)

if 'TV_DIR_DATA' in os.environ:
filename = os.path.join(os.environ['TV_DIR_DATA'], 'weights',
filename)
"tensorflow_resnet", filename)
else:
filename = os.path.join('DATA', 'weights', filename)
filename = os.path.join('DATA', 'weights', "tensorflow_resnet",
filename)

if not os.path.exists(filename):
logging.error("File not found: {}".format(filename))
Expand Down

0 comments on commit 08dffdb

Please sign in to comment.