-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_graph.py
61 lines (50 loc) · 2.55 KB
/
export_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
""" Freeze variables and convert 2 generator networks to 2 GraphDef files.
This makes file size smaller and can be used for inference in production.
An example of command-line usage is:
python export_graph.py --checkpoint_dir checkpoints/20170424-1152 \
--XtoY_model apple2orange.pb \
--YtoX_model orange2apple.pb \
--image_size 256
"""
import tensorflow as tf
import os
from tensorflow.python.tools.freeze_graph import freeze_graph
from model import CycleGAN
import utils
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('checkpoint_dir', '', 'checkpoints directory path')
tf.flags.DEFINE_string('XtoY_model', 'apple2orange.pb', 'XtoY model name, default: apple2orange.pb')
tf.flags.DEFINE_string('YtoX_model', 'orange2apple.pb', 'YtoX model name, default: orange2apple.pb')
tf.flags.DEFINE_integer('image_size', '128', 'image size, default: 256')
tf.flags.DEFINE_integer('ngf', 64,
'number of gen filters in first conv layer, default: 64')
tf.flags.DEFINE_string('norm', 'instance',
'[instance, batch] use instance norm or batch norm, default: instance')
def export_graph(model_name, checkpoint_name, XtoY=True):
graph = tf.Graph()
with graph.as_default():
cycle_gan = CycleGAN(ngf=FLAGS.ngf, norm=FLAGS.norm, image_size=FLAGS.image_size, is_training=False)
input_image = tf.placeholder(tf.float32, shape=[FLAGS.image_size, FLAGS.image_size, 3], name='input_image')
cycle_gan.model()
if XtoY:
output_image = cycle_gan.G.encode_image(tf.expand_dims(input_image, 0))
# else:
# output_image = cycle_gan.F.sample(tf.expand_dims(input_image, 0))
output_image = tf.identity(cycle_gan.G.decode_image(output_image), name='output_image')
restore_saver = tf.train.Saver()
export_saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
latest_ckpt = tf.train.latest_checkpoint(checkpoint_name)
restore_saver.restore(sess, latest_ckpt)
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [output_image.op.name])
tf.train.write_graph(output_graph_def, 'pretrained', model_name, as_text=False)
tf.train.write_graph(sess.graph_def, 'pretrained', model_name+'.pbtxt')
def main(unused_argv):
print('Export XtoY model...')
export_graph(FLAGS.XtoY_model, FLAGS.checkpoint_dir, XtoY=True)
print('Export YtoX model...')
export_graph(FLAGS.YtoX_model, FLAGS.checkpoint_dir, XtoY=False)
if __name__ == '__main__':
tf.app.run()