Skip to content

Commit

Permalink
update infer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahuei committed Jun 5, 2019
1 parent ef1a479 commit 41254c3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
15 changes: 14 additions & 1 deletion src/example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ python train.py \


### Inference
# Default dataset and checkpoint directories
# Default dataset and checkpoint directories (MSCOCO, COMIC-256)
python infer.py

# Custom dataset and checkpoint directories
Expand All @@ -37,3 +37,16 @@ python infer.py \
--dataset_dir '/home/jiahuei/Documents/3_Datasets/MSCOCO_captions' \
--gpu '1'

# InstaPIC
python infer.py \
--infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \
--annotations_file 'insta_testval_raw.json'

# Custom InstaPIC directory
python infer.py \
--infer_checkpoints_dir 'insta/word_add_softmax_h8_tie_lstm_run_01' \
--dataset_dir '/home/jiahuei/Documents/3_Datasets/InstaPIC' \
--annotations_file 'insta_testval_raw.json'



9 changes: 5 additions & 4 deletions src/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def create_parser():
'--infer_checkpoints', type=str, default='all',
help='The checkpoint numbers to be evaluated. Comma-separated.')
parser.add_argument(
'--annotations_file', type=str,
default=pjoin(
os.path.dirname(CURR_DIR), 'common', 'coco_caption',
'annotations', 'captions_val2014.json'),
'--annotations_file', type=str, default='captions_val2014.json',
help='The annotations / reference file for calculating scores.')
parser.add_argument(
'--dataset_dir', type=str,
Expand Down Expand Up @@ -86,6 +83,10 @@ def create_parser():
default_exp_dir = pjoin(os.path.dirname(CURR_DIR), 'experiments')
args.infer_checkpoints_dir = pjoin(default_exp_dir, args.infer_checkpoints_dir)

args.annotations_file = pjoin(
os.path.dirname(CURR_DIR),
'common', 'coco_caption', 'annotations', args.annotations_file)

if args.infer_checkpoints == 'all':
files = sorted(os.listdir(args.infer_checkpoints_dir), key=nat_key)
files = [f for f in files if ckpt_prefix in f]
Expand Down

0 comments on commit 41254c3

Please sign in to comment.