-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
52 lines (37 loc) · 1.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import os
import utils
import model
# Command-line flags. tf.app.run() parses these before calling main(); their
# values are read through the module-level FLAGS object.
flags = tf.app.flags

# Data / model locations.
flags.DEFINE_string('data_dir', None, 'train and validation data directory')
flags.DEFINE_string('model_dir', None, 'where to store checkpoints')
flags.DEFINE_string('vocab', None, 'vocab pkl path')
# main() only dispatches on 'train' and 'inference'; the help text reflects that.
flags.DEFINE_string('mode', None, 'train or inference')

# Optimisation schedule.
flags.DEFINE_integer('num_epochs', 1, 'number of epochs')
flags.DEFINE_integer('batch_size', 32, 'batch size')
flags.DEFINE_integer('epochs_per_eval', 1, 'epochs between evaluation')
flags.DEFINE_float('learning_rate', 0.0001, 'learning rate')
flags.DEFINE_integer('summary_freq', 200, 'frequency to write summary on tensorboard')
flags.DEFINE_integer('save_freq', 4000, 'steps between saving two checkpoints')
flags.DEFINE_integer('decay_step', 100000, 'steps per decay')
flags.DEFINE_float('decay_rate', 0.1, 'decay rate')

# Restoring / inference.
flags.DEFINE_boolean('pretrained', None, 'if use pretrain model')
# The checkpoint identifies what to restore, so it must be a string flag. As a
# boolean, `--checkpoint=<path>` could not be parsed. Default None is unchanged.
flags.DEFINE_string('checkpoint', None, 'checkpoint to restore; if None, use latest')
flags.DEFINE_string('predict_image', None, 'image to be captioned')

FLAGS = flags.FLAGS
def main(unused_argv):
    """Run training or single-image inference, selected by ``--mode``.

    Invoked by ``tf.app.run`` after flags are parsed. Builds a TF session,
    loads the vocabulary pickle from ``--vocab``, constructs
    ``model.RNNAgent(sess, vocab, FLAGS)`` and then:

    * ``--mode=train``: loads ``train_dict.pkl`` and ``val_dict.pkl`` from
      ``--data_dir`` and calls ``RNNAgent.learn(train, val)``.
    * ``--mode=inference``: calls ``RNNAgent.inference(--predict_image)``.

    Args:
      unused_argv: leftover command-line arguments from ``tf.app.run``; ignored.

    Raises:
      ValueError: if ``--mode`` is not ``train`` or ``inference``, or if
        ``--predict_image`` is missing when ``--mode=inference``.
    """
    mode = FLAGS.mode
    # Validate flags before any expensive setup. Use explicit exceptions rather
    # than assert, because asserts are stripped when Python runs with -O. An
    # unknown mode used to be silently ignored.
    if mode not in ('train', 'inference'):
        raise ValueError(
            "--mode must be 'train' or 'inference', got %r" % (mode,))
    if mode == 'inference' and FLAGS.predict_image is None:
        raise ValueError('--predict_image is required when --mode=inference')

    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    vocab = utils.load_pickle(FLAGS.vocab)
    # The context manager guarantees the session is closed (and its resources
    # released) even if training or inference raises.
    with tf.Session(config=config) as sess:
        caption_model = model.RNNAgent(sess, vocab, FLAGS)
        if mode == 'train':
            coco_data_train = utils.load_pickle(
                os.path.join(FLAGS.data_dir, 'train_dict.pkl'))
            coco_data_val = utils.load_pickle(
                os.path.join(FLAGS.data_dir, 'val_dict.pkl'))
            print('Successfully loading data...')
            caption_model.learn(coco_data_train, coco_data_val)
        else:
            caption_model.inference(FLAGS.predict_image)
if __name__ == '__main__':
    # Parse the command-line flags into FLAGS, then hand the remaining argv to
    # main(). Passing main explicitly matches tf.app.run's default lookup of
    # __main__.main.
    tf.app.run(main=main)