Skip to content

Commit

Permalink
Add reduce on plateau.
Browse files Browse the repository at this point in the history
  • Loading branch information
ruotianluo committed Apr 10, 2019
1 parent 28bcbb3 commit a1d59f2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
39 changes: 39 additions & 0 deletions misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def set_lr(optimizer, lr):
for group in optimizer.param_groups:
group['lr'] = lr

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Mirrors set_lr: treats the first param group as authoritative.
    Returns None when the optimizer has no param groups.
    """
    groups = optimizer.param_groups
    if groups:
        return groups[0]['lr']

def clip_gradient(optimizer, grad_clip):
for group in optimizer.param_groups:
for param in group['params']:
Expand Down Expand Up @@ -159,3 +163,38 @@ def length_average(length, logprobs, alpha=0.):
Returns the average probability of tokens in a sequence.
"""
return logprobs / length

class ReduceLROnPlateau(object):
    """Optimizer wrapper that adds torch's ReduceLROnPlateau lr scheduling.

    Delegates unknown attributes (zero_grad, param_groups, ...) to the
    wrapped optimizer so existing training code works unchanged, while
    scheduler_step(val) drives learning-rate reduction from a validation
    metric. `current_lr` caches the most recently observed learning rate.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        # All scheduling hyper-parameters are forwarded unchanged to
        # torch.optim.lr_scheduler.ReduceLROnPlateau.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)

    def step(self):
        "Apply one optimizer update (gradient step); does not change the lr."
        self.optimizer.step()

    def scheduler_step(self, val):
        "Feed a validation metric to the scheduler, then cache the (possibly reduced) lr."
        self.scheduler.step(val)
        self.current_lr = get_lr(self.optimizer)

    def state_dict(self):
        "Serialize wrapper, scheduler, and optimizer state as one dict."
        return {'current_lr': self.current_lr,
                'scheduler_state_dict': self.scheduler.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        """Restore state.

        Also accepts a plain optimizer state_dict (old checkpoints saved
        before this wrapper existed), detected by the absence of the
        'current_lr' key.
        """
        if 'current_lr' not in state_dict:
            # it's a plain optimizer state_dict
            self.optimizer.load_state_dict(state_dict)
            set_lr(self.optimizer, self.current_lr)  # use the lr from the option
        else:
            # it's a scheduler-wrapper state_dict
            self.current_lr = state_dict['current_lr']
            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # current_lr is actually useless in this case

    def __getattr__(self, name):
        # Delegate everything else to the wrapped optimizer. Guard against
        # infinite recursion when 'optimizer' itself is not yet set in
        # __dict__ (e.g. during unpickling or copy, before __init__ ran).
        if name == 'optimizer':
            raise AttributeError(name)
        return getattr(self.optimizer, name)
2 changes: 2 additions & 0 deletions opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def parse_opt():
help='epsilon that goes into denominator for smoothing')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight_decay')
parser.add_argument('--reduce_on_plateau', action='store_true',
help='')

parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
help='at what iteration to start decay gt probability')
Expand Down
28 changes: 20 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,26 @@ def train(opt):
crit = utils.LanguageModelCriterion()
rl_crit = utils.RewardCriterion()

optimizer = utils.build_optimizer(model.parameters(), opt)
if opt.reduce_on_plateau:
optimizer = utils.build_optimizer(model.parameters(), opt)
optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
else:
optimizer = utils.build_optimizer(model.parameters(), opt)
# Load the optimizer
if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

while True:
if update_lr_flag:
# Assign the learning rate
if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
decay_factor = opt.learning_rate_decay_rate ** frac
opt.current_lr = opt.learning_rate * decay_factor
else:
opt.current_lr = opt.learning_rate
utils.set_lr(optimizer, opt.current_lr)
if not opt.reduce_on_plateau:
if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
decay_factor = opt.learning_rate_decay_rate ** frac
opt.current_lr = opt.learning_rate * decay_factor
else:
opt.current_lr = opt.learning_rate
utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
# Assign the scheduled sampling prob
if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
Expand Down Expand Up @@ -151,6 +156,8 @@ def train(opt):
# Write the training loss summary
if (iteration % opt.losses_log_every == 0):
add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
if opt.reduce_on_plateau:
opt.current_lr = optimizer.current_lr
add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
if sc_flag:
Expand All @@ -168,6 +175,11 @@ def train(opt):
eval_kwargs.update(vars(opt))
val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

if opt.reduce_on_plateau:
if 'CIDEr' in lang_stats:
optimizer.scheduler_step(-lang_stats['CIDEr'])
else:
optimizer.scheduler_step(val_loss)
# Write validation result into summary
add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
if lang_stats is not None:
Expand Down

0 comments on commit a1d59f2

Please sign in to comment.