diff --git a/encoder/resnet.py b/encoder/resnet.py index f78de5e..7a06e02 100644 --- a/encoder/resnet.py +++ b/encoder/resnet.py @@ -1,4 +1,4 @@ -''' +""" The MIT License (MIT) Original Work: Copyright (c) 2016 Ryan Dahl @@ -7,7 +7,7 @@ Modified Work: Copyright (c) 2017 Marvin Teichmann For details see 'licenses/RESNET_LICENSE.txt' -''' +""" import tensorflow as tf from tensorflow.python.ops import control_flow_ops from tensorflow.python.training import moving_averages @@ -33,12 +33,13 @@ IMAGENET_MEAN_BGR = [103.062623801, 115.902882574, 123.151630838, ] -network_file = os.path.join("tensorflow_resnet_convert_1.1", - "ResNet-L101.ckpt") - network_url = "Not yet uploaded." +def checkpoint_fn(layers): + return 'ResNet-L%d.ckpt' % layers + + def inference(hypes, images, train=True, num_classes=1000, num_blocks=[3, 4, 6, 3], # defaults to 50-layer network @@ -55,6 +56,8 @@ def inference(hypes, images, train=True, num_blocks = [3, 4, 23, 3] elif layers == 152: num_blocks = [3, 8, 36, 3] + else: + assert() if preprocess: x = _imagenet_preprocess(images) @@ -113,13 +116,18 @@ def _initalize_variables(hypes): saver = tf.train.Saver(var_list=restore) - filename = network_file + layers = hypes['arch']['layers'] + + assert layers in [50, 101, 152] + + filename = checkpoint_fn(layers) if 'TV_DIR_DATA' in os.environ: filename = os.path.join(os.environ['TV_DIR_DATA'], 'weights', - filename) + "tensorflow_resnet", filename) else: - filename = os.path.join('DATA', 'weights', filename) + filename = os.path.join('DATA', 'weights', "tensorflow_resnet", + filename) if not os.path.exists(filename): logging.error("File not found: {}".format(filename))