diff --git a/models/__init__.py b/models/__init__.py index fd917c06..4cff8d30 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -53,6 +53,10 @@ def setup(opt, reverse=False): # check if all necessary files exist assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from - model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) + forward_dict, backward_dict = torch.load(os.path.join(opt.start_from, 'model.pth')) + if reverse: + model.load_state_dict(backward_dict) + else: + model.load_state_dict(forward_dict) return model diff --git a/opts.py b/opts.py index a2addcdc..c3a6409d 100644 --- a/opts.py +++ b/opts.py @@ -100,6 +100,7 @@ def parse_opt(): help='an id identifying this run/job. used in cross-val and appended when writing progress files') parser.add_argument('--train_only', type=int, default=0, help='if true then use 80k, else use 110k') + parser.add_argument('--param-l2', type=float, default=0.) args = parser.parse_args() @@ -117,4 +118,4 @@ def parse_opt(): assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1" assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1" - return args \ No newline at end of file + return args diff --git a/train_twinnet_para.py b/train_twinnet_para.py index b9a25735..ffbb4ef2 100644 --- a/train_twinnet_para.py +++ b/train_twinnet_para.py @@ -45,7 +45,7 @@ def train(opt): with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] - need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] + need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme @@ -147,7 +147,9 @@ def train(opt): l2_loss = ((affine_states - invert_backstates)** 2).sum(dim=1)[:, 0, :] l2_loss = (l2_loss / masks[:, 1:-1].sum(dim=1).expand_as(l2_loss)).mean() - all_loss = loss + 3.0 * l2_loss + back_loss + if opt.param_l2 == 0: + back_loss = 0. * back_loss + all_loss = loss + opt.param_l2 * l2_loss + back_loss all_loss.backward() #back_loss.backward() @@ -211,7 +213,7 @@ def train(opt): best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') - torch.save(model.state_dict(), checkpoint_path) + torch.save([model.state_dict(), back_model.state_dict()], checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path)