Skip to content

Commit

Permalink
Merge pull request ruotianluo#4 from dmitriy-serdyuk/fix-checkpoint
Browse files Browse the repository at this point in the history
Fix checkpoint
  • Loading branch information
nke001 authored Oct 25, 2017
2 parents 73685cd + b691e40 commit c251be8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
6 changes: 5 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def setup(opt, reverse=False):
# check if all necessary files exist
assert os.path.isdir(opt.start_from)," %s must be a path" % opt.start_from
assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from
model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))
forward_dict, backward_dict = torch.load(os.path.join(opt.start_from, 'model.pth'))
if reverse:
model.load_state_dict(backward_dict)
else:
model.load_state_dict(forward_dict)

return model
3 changes: 2 additions & 1 deletion opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def parse_opt():
help='an id identifying this run/job. used in cross-val and appended when writing progress files')
parser.add_argument('--train_only', type=int, default=0,
help='if true then use 80k, else use 110k')
parser.add_argument('--param-l2', type=float, default=0.)

args = parser.parse_args()

Expand All @@ -117,4 +118,4 @@ def parse_opt():
assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"

return args
return args
8 changes: 5 additions & 3 deletions train_twinnet_para.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def train(opt):
with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
infos = cPickle.load(f)
saved_model_opt = infos['opt']
need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
for checkme in need_be_same:
assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

Expand Down Expand Up @@ -147,7 +147,9 @@ def train(opt):
l2_loss = ((affine_states - invert_backstates)** 2).sum(dim=1)[:, 0, :]
l2_loss = (l2_loss / masks[:, 1:-1].sum(dim=1).expand_as(l2_loss)).mean()

all_loss = loss + 3.0 * l2_loss + back_loss
if opt.param_l2 == 0:
back_loss = 0. * back_loss
all_loss = loss + opt.param_l2 * l2_loss + back_loss

all_loss.backward()
#back_loss.backward()
Expand Down Expand Up @@ -211,7 +213,7 @@ def train(opt):
best_val_score = current_score
best_flag = True
checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
torch.save(model.state_dict(), checkpoint_path)
torch.save([model.state_dict(), back_model.state_dict()], checkpoint_path)
print("model saved to {}".format(checkpoint_path))
optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
torch.save(optimizer.state_dict(), optimizer_path)
Expand Down

0 comments on commit c251be8

Please sign in to comment.