This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport #16715 and #16903 to 1.6 #17036

Merged — 2 commits merged on Dec 10, 2019
52 changes: 50 additions & 2 deletions python/mxnet/optimizer/optimizer.py
@@ -34,14 +34,14 @@
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update)
preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array

__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]

@@ -1244,6 +1244,54 @@ def update(self, index, weight, grad, state):
kwargs = {}
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)


@register
class LAMB(Optimizer):
"""LAMB Optimizer.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.bias_correction = bias_correction


def create_state(self, index, weight):
stype = weight.stype
dtype = weight.dtype
return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

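        # Phase 1: Adam-style moment update that produces the raw update direction g.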
kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'bias_correction': self.bias_correction, 't': t,
'rescale_grad': self.rescale_grad}
mean, var = state
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)

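        # Phase 2: scale the step by the trust ratio r1/r2 (optionally clamping r1
        # to [lower_bound, upper_bound]) and write the new weight in place.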
kwargs = {}
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound
r_1 = weight.norm()
r_2 = g.norm()
lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)


# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
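Reviewer note (not part of the diff): once the class above is registered, the optimizer should be selectable by its lowercase name, 'lamb', like the other optimizers in this module. A minimal usage sketch with an illustrative network and hyperparameters:

```python
import mxnet as mx
from mxnet import gluon, nd

# Assumption: @register exposes the LAMB class above under the name 'lamb'.
net = gluon.nn.Dense(10)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'lamb',
                        {'learning_rate': 0.001, 'wd': 0.01, 'upper_bound': 10.0})

data = nd.random.normal(shape=(32, 20))
label = nd.random.randint(0, 10, shape=(32,)).astype('float32')
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
with mx.autograd.record():
    loss = loss_fn(net(data), label)
loss.backward()
trainer.step(batch_size=32)  # runs lamb_update_phase1/phase2 for each parameter
```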
188 changes: 188 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -1563,6 +1563,194 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}
}

struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
float beta1;
float beta2;
float epsilon;
int t;
bool bias_correction;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
DMLC_DECLARE_FIELD(beta1)
.set_default(0.9f)
.describe("The decay rate for the 1st moment estimates.");
DMLC_DECLARE_FIELD(beta2)
.set_default(0.999f)
.describe("The decay rate for the 2nd moment estimates.");
DMLC_DECLARE_FIELD(epsilon)
.set_default(1e-6f)
.describe("A small constant for numerical stability.");
DMLC_DECLARE_FIELD(t)
.describe("Index update count.");
DMLC_DECLARE_FIELD(bias_correction)
.set_default(true)
.describe("Whether to use bias correction.");
DMLC_DECLARE_FIELD(wd)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
float lr;
float lower_bound;
float upper_bound;
DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(-1.0f)
.describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(-1.0f)
.describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
}
};

struct LambUpdatePhaseOneKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
const DType wd, const DType epsilon, const int t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

DType grad_rescaled = grad_data[i] * rescale_grad;
if (clip_gradient >= 0.f) {
grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
}

mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

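    // Adam-style direction with decoupled weight decay; recomputed from the
    // bias-corrected moments below when bias_correction is enabled.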
DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

if (bias_correction) {
      DType mean_hat = mean_data[i] / (1.f - beta1_t);
      DType var_hat = var_data[i] / (1.f - beta2_t);
g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}
KERNEL_ASSIGN(out_data[i], req, g);
}
};

template<typename xpu>
inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
DType beta1_t = std::pow(param.beta1, param.t);
DType beta2_t = std::pow(param.beta2, param.t);
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
});
}

inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 4U);
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& weight_shape = in_attrs->at(0);
mxnet::TShape& g_shape = in_attrs->at(1);
CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
<< "total no. of dimensions for weights and g must match";
for (int i=0; i < weight_shape.ndim(); ++i) {
CHECK_EQ(weight_shape[i], g_shape[i])
<< "weight and g dimension size mismatch at " << i << "-th index";
}
mxnet::TShape& r1_shape = in_attrs->at(2);
mxnet::TShape& r2_shape = in_attrs->at(3);
CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
for (int i=0; i < expected_out.ndim(); ++i) {
expected_out[i] = weight_shape[i];
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
return shape_is_known(expected_out);
}

struct LambUpdatePhaseTwoKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* g,
const DType* r1, const DType* r2,
DType lr, const DType lower_bound,
const DType upper_bound, const OpReqType req) {
using namespace mshadow_op;

DType new_r1 = r1[0];
if (lower_bound >= 0) {
new_r1 = maximum::Map(new_r1, lower_bound);
}
if (upper_bound >= 0) {
new_r1 = minimum::Map(new_r1, upper_bound);
}
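    // Scale the learning rate by the trust ratio new_r1 / r2, falling back to the
    // plain learning rate when either norm is zero.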
if (new_r1 == 0.0f || r2[0] == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * new_r1 / r2[0];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
}
};

template<typename xpu>
inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), req[0]);
});
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
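As a reading aid (not from the diff), here is a NumPy sketch of the element-wise math implemented by LambUpdatePhaseOneKernel and LambUpdatePhaseTwoKernel; the function and argument names below are illustrative only:

```python
import numpy as np

def lamb_phase1(weight, grad, mean, var, t, beta1=0.9, beta2=0.999, epsilon=1e-6,
                wd=0.0, rescale_grad=1.0, clip_gradient=-1.0, bias_correction=True):
    g = grad * rescale_grad
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1 - beta1) * g      # moments updated in place, like the kernel
    var[:] = beta2 * var + (1 - beta2) * g * g
    if bias_correction:
        mean_hat = mean / (1 - beta1 ** t)
        var_hat = var / (1 - beta2 ** t)
        return mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight
    return mean / (np.sqrt(var) + epsilon) + wd * weight

def lamb_phase2(weight, g, lr, lower_bound=-1.0, upper_bound=-1.0):
    r1 = np.linalg.norm(weight)                   # layer-wise weight norm
    r2 = np.linalg.norm(g)                        # layer-wise update norm
    if lower_bound >= 0:
        r1 = max(r1, lower_bound)
    if upper_bound >= 0:
        r1 = min(r1, upper_bound)
    ratio = 1.0 if (r1 == 0 or r2 == 0) else r1 / r2  # trust ratio
    return weight - lr * ratio * g
```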
81 changes: 81 additions & 0 deletions src/operator/optimizer_op.cc
@@ -43,6 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.
@@ -921,5 +923,84 @@ Note that non-zero values for the weight decay option are not supported.
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase1)
.describe(R"code(Phase I of lamb update it performs the following operations and returns g:.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
\begin{gather*}
grad = grad * rescale_grad
if (grad < -clip_gradient)
then
grad = -clip_gradient
if (grad > clip_gradient)
then
grad = clip_gradient

mean = beta1 * mean + (1 - beta1) * grad;
variance = beta2 * variance + (1. - beta2) * grad ^ 2;

if (bias_correction)
then
mean_hat = mean / (1. - beta1^t);
var_hat = var / (1 - beta2^t);
g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
else
          g = mean / (var^(1/2) + epsilon) + wd * weight;
\end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase2)
.describe(R"code(Phase II of lamb update it performs the following operations and updates grad.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
\begin{gather*}
if (lower_bound >= 0)
then
r1 = max(r1, lower_bound)
if (upper_bound >= 0)
then
          r1 = min(r1, upper_bound)

if (r1 == 0 or r2 == 0)
then
lr = lr
else
lr = lr * (r1/r2)
weight = weight - lr * g
\end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());

} // namespace op
} // namespace mxnet
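A hedged sketch (not in the diff) of driving the two newly registered operators directly through the imperative NDArray API, mirroring what LAMB.update does on the Python side; shapes and hyperparameters are illustrative:

```python
import mxnet as mx

weight = mx.nd.random.normal(shape=(4, 4))
grad = mx.nd.random.normal(shape=(4, 4))
mean = mx.nd.zeros((4, 4))   # 1st-moment state, updated in place
var = mx.nd.zeros((4, 4))    # 2nd-moment state, updated in place

# Phase 1: returns the raw update direction g.
g = mx.nd.lamb_update_phase1(weight, grad, mean, var,
                             beta1=0.9, beta2=0.999, epsilon=1e-6,
                             t=1, bias_correction=True, wd=0.01, rescale_grad=1.0)

# Phase 2: scales g by the trust ratio r1/r2 and writes the new weight in place.
r1 = weight.norm()
r2 = g.norm()
mx.nd.lamb_update_phase2(weight, g, r1, r2, lr=0.001, out=weight)
```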
7 changes: 7 additions & 0 deletions src/operator/optimizer_op.cu
@@ -277,5 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
NNVM_REGISTER_OP(_sparse_adagrad_update)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

NNVM_REGISTER_OP(lamb_update_phase1)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);

NNVM_REGISTER_OP(lamb_update_phase2)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);


} // namespace op
} // namespace mxnet
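Since the same compute functions are registered for GPU here, the imperative calls shown above should work unchanged on a CUDA context (sketch; assumes a CUDA-enabled build):

```python
import mxnet as mx

ctx = mx.gpu(0)  # dispatches to LambUpdatePhaseOne<gpu> / LambUpdatePhaseTwo<gpu>
weight = mx.nd.random.normal(shape=(4, 4), ctx=ctx)
grad = mx.nd.random.normal(shape=(4, 4), ctx=ctx)
mean = mx.nd.zeros((4, 4), ctx=ctx)
var = mx.nd.zeros((4, 4), ctx=ctx)

g = mx.nd.lamb_update_phase1(weight, grad, mean, var, t=1, wd=0.01)
mx.nd.lamb_update_phase2(weight, g, weight.norm(), g.norm(), lr=0.001, out=weight)
```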