diff --git a/src/example.sh b/src/example.sh index 05eb0fa..ca3d57a 100644 --- a/src/example.sh +++ b/src/example.sh @@ -28,7 +28,7 @@ python train.py \ ### Inference -# Default dataset and checkpoint directories +# Default dataset and checkpoint directories (MSCOCO, COMIC-256) python infer.py # Custom dataset and checkpoint directories @@ -37,3 +37,16 @@ python infer.py \ --dataset_dir '/home/jiahuei/Documents/3_Datasets/MSCOCO_captions' \ --gpu '1' +# InstaPIC +python infer.py \ + --infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \ + --annotations_file 'insta_testval_raw.json' + +# Custom InstaPIC directory +python infer.py \ + --infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \ + --dataset_dir '/home/jiahuei/Documents/3_Datasets/InstaPIC' \ + --annotations_file 'insta_testval_raw.json' + + + diff --git a/src/infer.py b/src/infer.py index c8e5831..b1ed983 100644 --- a/src/infer.py +++ b/src/infer.py @@ -35,10 +35,7 @@ def create_parser(): '--infer_checkpoints', type=str, default='all', help='The checkpoint numbers to be evaluated. Comma-separated.') parser.add_argument( - '--annotations_file', type=str, - default=pjoin( - os.path.dirname(CURR_DIR), 'common', 'coco_caption', - 'annotations', 'captions_val2014.json'), + '--annotations_file', type=str, default='captions_val2014.json', help='The annotations / reference file for calculating scores.') parser.add_argument( '--dataset_dir', type=str, @@ -86,6 +83,10 @@ def create_parser(): default_exp_dir = pjoin(os.path.dirname(CURR_DIR), 'experiments') args.infer_checkpoints_dir = pjoin(default_exp_dir, args.infer_checkpoints_dir) + args.annotations_file = pjoin( + os.path.dirname(CURR_DIR), + 'common', 'coco_caption', 'annotations', args.annotations_file) + if args.infer_checkpoints == 'all': files = sorted(os.listdir(args.infer_checkpoints_dir), key=nat_key) files = [f for f in files if ckpt_prefix in f]