fixing base lamb optimizer
Rohit Kumar Srivastava committed Nov 14, 2019
1 parent d8e397d commit e6ac0dc
Showing 5 changed files with 162 additions and 59 deletions.
24 changes: 16 additions & 8 deletions python/mxnet/optimizer/optimizer.py
@@ -26,15 +26,15 @@
import os
import numpy
from ..base import py_str
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, minimum, maximum, abs as NDabs, array, multiply,
multi_sum_sq, multi_lars, norm as NDnorm)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update, lamb_update)
preloaded_multi_mp_sgd_mom_update, lamb_update, lamb_weight_update)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array
@@ -1250,7 +1250,7 @@ class LAMB(Optimizer):
"""LAMB Optimizer.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
lower_bound=None, upper_bound=None, bias_correction=False, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
@@ -1259,13 +1259,14 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
self.upper_bound = upper_bound
self.bias_correction = bias_correction


def create_state(self, index, weight):
stype = weight.stype
dtype = weight.dtype
return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

def update(self, index, weight,grad, state):
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
@@ -1274,14 +1275,21 @@ def update(self, index, weight,grad, state):
t = self._index_update_count[index]

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
'bias_correction': self.bias_correction, 't': t,
'rescale_grad': self.rescale_grad}
mean, var = state
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
g = lamb_update(weight, grad, mean, var, wd=wd, **kwargs)

kwargs = {}
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound
r_1 = weight.norm()
r_2 = g.norm()
lamb_weight_update(weight, g, r_1, r_2, lr = lr, out=weight, **kwargs)


# pylint: enable=line-too-long
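For context, the revised update() now issues two operator calls per step: lamb_update produces the raw direction g, and lamb_weight_update applies the trust-ratio scaling from the weight and direction norms. A minimal sketch of driving the updated class end to end, assuming it is exposed as mx.optimizer.LAMB like the other optimizers in this module; shapes and hyperparameters are illustrative, not taken from the commit:

import mxnet as mx

opt = mx.optimizer.LAMB(learning_rate=0.01, bias_correction=True)
weight = mx.nd.random.uniform(shape=(64, 128))
grad = mx.nd.random.uniform(shape=(64, 128))
state = opt.create_state(0, weight)   # (mean, var), both initialized to zeros
opt.update(0, weight, grad, state)    # runs lamb_update, then lamb_weight_update, in place
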
142 changes: 104 additions & 38 deletions src/operator/optimizer_op-inl.h
@@ -1564,20 +1564,15 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}

struct LAMBParam : public dmlc::Parameter<LAMBParam> {
float lr;
float beta1;
float beta2;
float epsilon;
float lower_bound;
float upper_bound;
float t;
bool bias_correction;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(LAMBParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(beta1)
.set_default(0.9f)
.describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
DMLC_DECLARE_FIELD(epsilon)
.set_default(1e-6f)
.describe("A small constant for numerical stability.");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(1e-3f)
.describe("Lower limit of norm of weight.");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(10.0f)
.describe("Upper limit of norm of weight.");
DMLC_DECLARE_FIELD(t)
.describe("Index update count.");
DMLC_DECLARE_FIELD(bias_correction)
.set_default(false)
.describe("Whether to use bias correction.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
@@ -1614,44 +1602,48 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
}
};

struct LAMBWeightParam : public dmlc::Parameter<LAMBWeightParam> {
float lr;
float lower_bound;
float upper_bound;
DMLC_DECLARE_PARAMETER(LAMBWeightParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(1e-3f)
.describe("Lower limit of norm of weight.");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(10.0f)
.describe("Upper limit of norm of weight.");
}
};

struct LAMBUpdateKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta2,
DType lr, const DType wd,
const DType epsilon, const DType lower_bound,
const DType upper_bound, const DType t,
const DType beta1, const DType beta2, const DType wd,
const DType epsilon, const DType t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
DType grad_rescaled = grad_data[i] * rescale_grad;
if (clip_gradient >= 0.f) {
grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
}

mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

DType r1 = square_root::Map(square::Map(weight_data[i]));

r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
}
DType r2 = square_root::Map(square::Map(g));
if (r1 == 0.0f || r2 == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * r1 / r2;
g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
KERNEL_ASSIGN(out_data[i], req, g);
}
};
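A NumPy transcription of what the revised kernel computes per element may help: weight decay is no longer folded into the rescaled gradient, epsilon is added outside the square root, and the kernel now emits the raw direction g instead of the updated weight. This is a sketch of the kernel body only; gradient clipping is elided:

import numpy as np

def lamb_phase1(weight, grad, mean, var, rescale_grad, beta1, beta2,
                epsilon, wd, t, bias_correction=False):
    grad = grad * rescale_grad                      # clip_gradient handling elided
    mean[:] = beta1 * mean + (1.0 - beta1) * grad
    var[:] = beta2 * var + (1.0 - beta2) * grad * grad
    if bias_correction:
        mean_hat = mean / (1.0 - beta1 ** t)
        var_hat = var / (1.0 - beta2 ** t)
        return mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight
    return mean / (np.sqrt(var) + epsilon) + wd * weight
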

@@ -1661,9 +1653,9 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
using namespace mxnet_op;
const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
@@ -1675,13 +1667,87 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
static_cast<bool>(param.bias_correction), req[0]);
});
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
});
}

inline bool LambWeightShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 4U);
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& weight_shape = in_attrs->at(0);
mxnet::TShape& g_shape = in_attrs->at(1);
CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
<< "total no. of dimensions for weights and g must match";
for (int i=0; i < weight_shape.ndim(); ++i) {
CHECK_EQ(weight_shape[i], g_shape[i])
<< "weight and g dimension size mismatch at " << i << "-th index";
}
mxnet::TShape& r1_shape = in_attrs->at(2);
mxnet::TShape& r2_shape = in_attrs->at(3);
CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
for (int i=0; i < expected_out.ndim(); ++i) {
expected_out[i] = weight_shape[i];
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
return shape_is_known(expected_out);
}

struct LAMBWeightUpdateKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* g,
const DType* r1, const DType* r2,
DType lr, const DType lower_bound,
const DType upper_bound, const OpReqType req) {
using namespace mshadow_op;

DType new_r1 = r1[0];
if (lower_bound >= 0) {
new_r1 = maximum::Map(new_r1, lower_bound);
}
if (upper_bound >= 0) {
new_r1 = minimum::Map(new_r1, upper_bound);
}
if (new_r1 == 0.0f || r2[0] == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * new_r1 / r2[0];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
}
};
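The trust-ratio logic that used to sit inside the fused kernel now lives here. A sketch of the same semantics in plain Python, where r1 and r2 are the scalar norms of the weight and of g, and a negative bound means no clamping, mirroring the >= 0 checks above (the -1.0 defaults are only a stand-in for "unset", not the operator's defaults):

def lamb_phase2(weight, g, r1, r2, lr, lower_bound=-1.0, upper_bound=-1.0):
    if lower_bound >= 0:
        r1 = max(r1, lower_bound)
    if upper_bound >= 0:
        r1 = min(r1, upper_bound)
    if r1 != 0.0 and r2 != 0.0:    # apply the trust ratio only when both norms are non-zero
        lr = lr * r1 / r2
    return weight - lr * g
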

template<typename xpu>
inline void LAMBWeightUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LAMBWeightParam& param = nnvm::get<LAMBWeightParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LAMBWeightUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), req[0]);
});
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
24 changes: 22 additions & 2 deletions src/operator/optimizer_op.cc
@@ -44,6 +44,7 @@ DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LAMBParam);
DMLC_REGISTER_PARAMETER(LAMBWeightParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.
@@ -928,14 +929,33 @@ NNVM_REGISTER_OP(lamb_update)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LAMBParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4,1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4,1>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LAMBUpdate<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LAMBParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_weight_update)
.describe(R"code(Update function for lamb optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LAMBWeightParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LambWeightShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LAMBWeightUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "g")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_arguments(LAMBWeightParam::__FIELDS__());
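
Both ops registered above are also reachable from the NDArray frontend. A hedged sketch of one step through the raw operators, assuming they are exposed as mx.nd.lamb_update and mx.nd.lamb_weight_update; r1 and r2 are expected to be single-element NDArrays per LambWeightShape, and mean and var are mutated in place via FMutateInputs:

import mxnet as mx

w = mx.nd.random.uniform(shape=(8, 8))
grad = mx.nd.random.uniform(shape=(8, 8))
mean = mx.nd.zeros_like(w)
var = mx.nd.zeros_like(w)

# phase 1: update the moments in place and return the raw direction g
g = mx.nd.lamb_update(w, grad, mean, var, beta1=0.9, beta2=0.999, epsilon=1e-6,
                      t=1, bias_correction=True, wd=0.0, rescale_grad=1.0)

# phase 2: clamp the weight norm, form the trust ratio, and take the step
r1 = w.norm()
r2 = g.norm()
new_w = mx.nd.lamb_weight_update(w, g, r1, r2, lr=0.01)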

} // namespace op
} // namespace mxnet
6 changes: 5 additions & 1 deletion src/operator/optimizer_op.cu
@@ -278,7 +278,11 @@ NNVM_REGISTER_OP(_sparse_adagrad_update)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

NNVM_REGISTER_OP(lamb_update)
.set_attr<FCompute>("FCompute<gpu>", LambUpdate<gpu>);
.set_attr<FCompute>("FCompute<gpu>", LAMBUpdate<gpu>);

NNVM_REGISTER_OP(lamb_weight_update)
.set_attr<FCompute>("FCompute<gpu>", LAMBWeightUpdate<gpu>);


} // namespace op
} // namespace mxnet
25 changes: 15 additions & 10 deletions tests/python/unittest/test_optimizer.py
@@ -454,31 +454,34 @@ def update(self, index, weight, grad, state):

grad *= self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)

mean, var = state
mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad)

mean_hat = mean
var_hat = var
r1 = weight.norm()
if not self.bias_correction:
r1 = mx.nd.minimum(mx.nd.maximum(r1, self.lower_bound), self.upper_bound)
g = mean / (mx.nd.sqrt(var) + self.epsilon) + wd * weight

else:
if self.lower_bound:
r1 = mx.nd.maximum(r1, self.lower_bound)
if self.upper_bound:
r1 = mx.nd.minimum(r1, self.upper_bound)
if self.bias_correction:
mean_hat = mean / (1. - mx.nd.power(self.beta1, t))
var_hat = var / (1. - mx.nd.power(self.beta2, t))
g = mean_hat / mx.nd.sqrt(var_hat + self.epsilon) + wd * weight

g = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight
r2 = g.norm()

# calculate lamb_trust_ratio
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr *= r

# update weight
weight[:] -= lr * g

def update_multi_precision(self, index, weight, grad, state):
self.update(index, weight, grad, state)

@with_seed()
def test_lamb():
opt1 = PyLAMB
@@ -488,7 +491,9 @@ def test_lamb():
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
bc_options = [{}, {'bias_correction': False}, {'bias_correction': True}]
for params in itertools.product(cg_options, rg_options, wd_options, bc_options):
lb_options = [{}, {'lower_bound': None}]
ub_options = [{}, {'upper_bound': None}]
for params in itertools.product(cg_options, rg_options, wd_options, bc_options, lb_options, ub_options):
kwarg = {k: v for param in params for k, v in param.items()}
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)

