From 12532b6cc7d31b1b8fb1a6888e1928bf1cd92757 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Fri, 15 Nov 2019 01:54:16 +0000 Subject: [PATCH] adding API doc for Lamb Phase 1 and 2 --- src/operator/optimizer_op.cc | 49 ++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index ff248861788a..9cf32778b15c 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -924,7 +924,32 @@ Note that non-zero values for the weight decay option are not supported. .add_arguments(AdagradParam::__FIELDS__()); NNVM_REGISTER_OP(lamb_update_phase1) -.describe(R"code(Update function for lamb optimizer. +.describe(R"code(Phase I of lamb update it performs the following operations and returns g:. + +Link to paper: https://arxiv.org/pdf/1904.00962.pdf + +.. math:: + \begin{gather*} + grad = grad * rescale_grad + if (grad < -clip_gradient) + then + grad = -clip_gradient + if (grad > clip_gradient) + then + grad = clip_gradient + + mean = beta1 * mean + (1 - beta1) * grad; + variance = beta2 * variance + (1. - beta2) * grad ^ 2; + + if (bias_correction) + then + mean_hat = mean / (1. - beta1^t); + var_hat = var / (1 - beta2^t); + g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight; + else + g = mean / (var_data^(1/2) + epsilon) + wd * weight_data[i]; + \end{gather*} + )code" ADD_FILELINE) .set_num_inputs(4) .set_num_outputs(1) @@ -943,7 +968,27 @@ NNVM_REGISTER_OP(lamb_update_phase1) .add_arguments(LambUpdatePhaseOneParam::__FIELDS__()); NNVM_REGISTER_OP(lamb_update_phase2) -.describe(R"code(Update function for lamb optimizer. +.describe(R"code(Phase II of lamb update it performs the following operations and updates grad. + +Link to paper: https://arxiv.org/pdf/1904.00962.pdf + +.. math:: + \begin{gather*} + if (lower_bound >= 0) + then + r1 = max(r1, lower_bound) + if (upper_bound >= 0) + then + r1 = max(r1, upper_bound) + + if (r1 == 0 or r2 == 0) + then + lr = lr + else + lr = lr * (r1/r2) + weight = weight - lr * g + \end{gather*} + )code" ADD_FILELINE) .set_num_inputs(4) .set_num_outputs(1)