diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index dcc0103cc..a218552bd 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -125,9 +125,11 @@ Dtype SGDSolver::GetLearningRate() { rate = this->param_.base_lr() * pow(this->param_.gamma(), this->current_step_); } else if (lr_policy == "poly") { - rate = this->param_.base_lr() * pow(Dtype(1.) - - (Dtype(this->iter_) / Dtype(this->param_.max_iter())), - this->param_.power()); + float end_learning_rate = 0.0001; + rate = (this->param_.base_lr() - end_learning_rate) * pow(Dtype(1.) - + ((Dtype(this->iter_) - Dtype(this->param_.warmup_iter())) / + (Dtype(this->param_.max_iter()) - Dtype(this->param_.warmup_iter()) + Dtype(1.))), + this->param_.power()) + end_learning_rate; } else if (lr_policy == "sigmoid") { rate = this->param_.base_lr() * (Dtype(1.) / (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) - @@ -741,7 +743,8 @@ Dtype SGDSolver::GetLocalRate(int param_id) const { float weight_decay = this->param_.weight_decay(); if (w_norm > 0.F && wgrad_norm > 0.F) { - rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm); + float lars_epsilon = 0.00001; + rate = gw_ratio * w_norm / (wgrad_norm + weight_decay * w_norm + lars_epsilon); } if (local_lr > 0.F) { local_lr = rate;