Skip to content

Commit

Permalink
Merge commit '510b0e02d1fcf43a5281bebb147ca1bce5db45f1' into self_critical_bottom_up
Browse files Browse the repository at this point in the history

* commit '510b0e02d1fcf43a5281bebb147ca1bce5db45f1':
  Fix a typo in FCModel _sample
  Remove att2in dependency and fix typo
  Fix #26.
  fix bug when lang_stats not set to 1
  Only initialize cider_score at the first time.
  • Loading branch information
ruotianluo committed Feb 14, 2019
2 parents a258cac + 510b0e0 commit a99a760
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion misc/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ def get_self_critical_reward(model, fc_feats, att_feats, att_masks, data, gen_re

rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

return rewards
return rewards
2 changes: 1 addition & 1 deletion models/FCModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
beam_size = opt.get('beam_size', 1)
temperature = opt.get('temperature', 1.0)
if beam_size > 1:
return self.sample_beam(fc_feats, att_feats, opt)
return self._sample_beam(fc_feats, att_feats, opt)

batch_size = fc_feats.size(0)
state = self.init_hidden(batch_size)
Expand Down
7 changes: 3 additions & 4 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
from .ShowTellModel import ShowTellModel
from .FCModel import FCModel
from .OldModel import ShowAttendTellModel, AllImgModel
from .Att2inModel import Att2inModel
# from .Att2inModel import Att2inModel
from .AttModel import *

def setup(opt):

if opt.caption_model == 'fc':
model = FCModel(opt)
if opt.caption_model == 'show_tell':
elif opt.caption_model == 'show_tell':
model = ShowTellModel(opt)
# Att2in model in self-critical
elif opt.caption_model == 'att2in':
Expand Down Expand Up @@ -54,4 +53,4 @@ def setup(opt):
assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from
model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))

return model
return model
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def train(opt):
opt.current_lr = opt.learning_rate * decay_factor
else:
opt.current_lr = opt.learning_rate
utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
utils.set_lr(optimizer, opt.current_lr)
# Assign the scheduled sampling prob
if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
Expand Down Expand Up @@ -169,8 +169,9 @@ def train(opt):

# Write validation result into summary
add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
for k,v in lang_stats.items():
add_summary_value(tb_summary_writer, k, v, iteration)
if lang_stats is not None:
for k,v in lang_stats.items():
add_summary_value(tb_summary_writer, k, v, iteration)
val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

# Save model if is improving on validation result
Expand Down

0 comments on commit a99a760

Please sign in to comment.