diff --git a/misc/utils.py b/misc/utils.py
index 3085810f..b2e208ad 100644
--- a/misc/utils.py
+++ b/misc/utils.py
@@ -112,6 +112,10 @@ def set_lr(optimizer, lr):
     for group in optimizer.param_groups:
         group['lr'] = lr
 
+def get_lr(optimizer):
+    for group in optimizer.param_groups:
+        return group['lr']
+
 def clip_gradient(optimizer, grad_clip):
     for group in optimizer.param_groups:
         for param in group['params']:
@@ -159,3 +163,38 @@ def length_average(length, logprobs, alpha=0.):
     Returns the average probability of tokens in a sequence.
     """
     return logprobs / length
+
+class ReduceLROnPlateau(object):
+    "Optim wrapper that bundles a ReduceLROnPlateau scheduler with its optimizer."
+    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
+        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
+        self.optimizer = optimizer
+        self.current_lr = get_lr(optimizer)
+
+    def step(self):
+        "Update parameters"
+        self.optimizer.step()
+
+    def scheduler_step(self, val):
+        self.scheduler.step(val)
+        self.current_lr = get_lr(self.optimizer)
+
+    def state_dict(self):
+        return {'current_lr': self.current_lr,
+                'scheduler_state_dict': self.scheduler.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict()}
+
+    def load_state_dict(self, state_dict):
+        if 'current_lr' not in state_dict:
+            # it's a plain optimizer state dict
+            self.optimizer.load_state_dict(state_dict)
+            set_lr(self.optimizer, self.current_lr)  # use the lr from the options
+        else:
+            # it's a wrapped scheduler state dict
+            self.current_lr = state_dict['current_lr']
+            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
+            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+            # current_lr is not used in this case
+
+    def __getattr__(self, name):
+        return getattr(self.optimizer, name)
diff --git a/opts.py b/opts.py
index b8a0907f..69c1f605 100644
--- a/opts.py
+++ b/opts.py
@@ -100,6 +100,8 @@ def parse_opt():
                         help='epsilon that goes into denominator for smoothing')
     parser.add_argument('--weight_decay', type=float, default=0,
                         help='weight_decay')
+    parser.add_argument('--reduce_on_plateau', action='store_true',
+                        help='use a ReduceLROnPlateau scheduler instead of the fixed step decay')
 
     parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
                         help='at what iteration to start decay gt probability')
diff --git a/train.py b/train.py
index f27c6207..6125a455 100644
--- a/train.py
+++ b/train.py
@@ -79,7 +79,11 @@ def train(opt):
     crit = utils.LanguageModelCriterion()
     rl_crit = utils.RewardCriterion()
 
-    optimizer = utils.build_optimizer(model.parameters(), opt)
+    if opt.reduce_on_plateau:
+        optimizer = utils.build_optimizer(model.parameters(), opt)
+        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
+    else:
+        optimizer = utils.build_optimizer(model.parameters(), opt)
     # Load the optimizer
     if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
         optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
@@ -87,13 +91,14 @@
     while True:
         if update_lr_flag:
             # Assign the learning rate
-            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
-                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
-                decay_factor = opt.learning_rate_decay_rate ** frac
-                opt.current_lr = opt.learning_rate * decay_factor
-            else:
-                opt.current_lr = opt.learning_rate
-            utils.set_lr(optimizer, opt.current_lr)
+            if not opt.reduce_on_plateau:
+                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
+                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
+                    decay_factor = opt.learning_rate_decay_rate ** frac
+                    opt.current_lr = opt.learning_rate * decay_factor
+                else:
+                    opt.current_lr = opt.learning_rate
+                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
             # Assign the scheduled sampling prob
             if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                 frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
@@ -151,6 +156,8 @@ def train(opt):
         # Write the training loss summary
         if (iteration % opt.losses_log_every == 0):
             add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
+            if opt.reduce_on_plateau:
+                opt.current_lr = optimizer.current_lr
             add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
             add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
             if sc_flag:
@@ -168,6 +175,11 @@ def train(opt):
             eval_kwargs.update(vars(opt))
             val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)
 
+            if opt.reduce_on_plateau:
+                if 'CIDEr' in lang_stats:
+                    optimizer.scheduler_step(-lang_stats['CIDEr'])
+                else:
+                    optimizer.scheduler_step(val_loss)
             # Write validation result into summary
             add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
             if lang_stats is not None:
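
For illustration, a minimal sketch of how the new wrapper is driven from a training loop, mirroring the train.py changes above. It assumes the patched misc/utils.py is importable (and that it imports torch.optim as optim for the wrapped scheduler); toy_model, toy_batches, and evaluate below are hypothetical stand-ins, not part of this patch.

import torch
import torch.nn as nn
import torch.nn.functional as F
import misc.utils as utils

# Hypothetical stand-ins for the captioning model and data loader.
toy_model = nn.Linear(8, 1)
toy_batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(5)]

def evaluate(model):
    # Hypothetical validation metric where higher is better (e.g. CIDEr).
    return 0.0

base_optimizer = torch.optim.Adam(toy_model.parameters(), lr=5e-4)
# Wrap the optimizer the same way train.py does when --reduce_on_plateau is set.
optimizer = utils.ReduceLROnPlateau(base_optimizer, factor=0.5, patience=3)

for epoch in range(10):
    for x, y in toy_batches:
        optimizer.zero_grad()              # forwarded to the wrapped optimizer via __getattr__
        loss = F.mse_loss(toy_model(x), y)
        loss.backward()
        optimizer.step()                   # delegates to the wrapped optimizer
    # Negate the score so a higher metric looks like a smaller value to mode='min'.
    optimizer.scheduler_step(-evaluate(toy_model))
    print(epoch, optimizer.current_lr)     # current_lr is refreshed inside scheduler_step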