Skip to content

Commit

Permalink
Make cached_token an option that can change; and lazy intialize the c…
Browse files Browse the repository at this point in the history
…ider scorer.
  • Loading branch information
ruotianluo committed Oct 3, 2017
1 parent 6f3c78c commit 52ec5c8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
6 changes: 5 additions & 1 deletion misc/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
sys.path.append("cider")
from pyciderevalcap.ciderD.ciderD import CiderD

CiderD_scorer = CiderD(df='coco-train-idxs')
CiderD_scorer = None
#CiderD_scorer = CiderD(df='corpus')

def init_cider_scorer(cached_tokens):
global CiderD_scorer
CiderD_scorer = CiderD(df=cached_tokens)

def array_to_str(arr):
out = ''
for i in range(len(arr)):
Expand Down
2 changes: 2 additions & 0 deletions opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def parse_opt():
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)
""")
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
help='Cached token file for calculating cider score during self critical training.')

# Model settings
parser.add_argument('--caption_model', type=str, default="show_tell",
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataloader import *
import eval_utils
import misc.utils as utils
from misc.rewards import get_self_critical_reward
from misc.rewards import init_cider_scorer, get_self_critical_reward

try:
import tensorflow as tf
Expand Down Expand Up @@ -101,6 +101,7 @@ def train(opt):
# If start self critical training
if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
sc_flag = True
init_cider_scorer(opt.cached_tokens)
else:
sc_flag = False

Expand Down

0 comments on commit 52ec5c8

Please sign in to comment.