diff --git a/misc/rewards.py b/misc/rewards.py index 4f46f433..c82ee729 100644 --- a/misc/rewards.py +++ b/misc/rewards.py @@ -13,9 +13,13 @@ sys.path.append("cider") from pyciderevalcap.ciderD.ciderD import CiderD -CiderD_scorer = CiderD(df='coco-train-idxs') +CiderD_scorer = None #CiderD_scorer = CiderD(df='corpus') +def init_cider_scorer(cached_tokens): + global CiderD_scorer + CiderD_scorer = CiderD(df=cached_tokens) + def array_to_str(arr): out = '' for i in range(len(arr)): diff --git a/opts.py b/opts.py index 3db4148a..326da6ab 100644 --- a/opts.py +++ b/opts.py @@ -18,6 +18,8 @@ def parse_opt(): Note: this file contains absolute paths, be careful when moving files around; 'model.ckpt-*' : file(s) with model definition (created by tf) """) + parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', + help='Cached token file for calculating cider score during self critical training.') # Model settings parser.add_argument('--caption_model', type=str, default="show_tell", diff --git a/train.py b/train.py index f9fae8bb..08e30f09 100644 --- a/train.py +++ b/train.py @@ -18,7 +18,7 @@ from dataloader import * import eval_utils import misc.utils as utils -from misc.rewards import get_self_critical_reward +from misc.rewards import init_cider_scorer, get_self_critical_reward try: import tensorflow as tf @@ -101,6 +101,7 @@ def train(opt): # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True + init_cider_scorer(opt.cached_tokens) else: sc_flag = False