sparse Adam optimizer #164

Merged · 4 commits · Aug 14, 2017
141 changes: 141 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -700,6 +700,147 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
});
}

/*!
 * \brief Kernel for the sparse Adam update. For each row slice in the row_sparse
 * gradient, it locates the corresponding elements in weight, mean and var and
 * applies the Adam update to them.
 * The kernel assumes dense weight/mean/var and a row_sparse gradient.
 */
template<int req>
struct AdamDnsRspDnsKernel {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
    const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
    using nnvm::dim_t;
    using namespace mshadow_op;
    const dim_t row_offset = grad_idx[i] * row_length;
    for (dim_t j = 0; j < row_length; j++) {
      // index in data/mean/var
      const dim_t data_i = row_offset + j;
      // index in grad
      const dim_t grad_i = i * row_length + j;
      const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd;
      if (clip_gradient >= 0.0f) {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
                            clip::Map(grad_rescaled, clip_gradient);
        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
                           clip::Map(grad_rescaled, clip_gradient));
      } else {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
        var_data[data_i] = beta2 * var_data[data_i] +
                           (1.f - beta2) * grad_rescaled * grad_rescaled;
      }
      // w <- w - lr * m / (sqrt(v) + epsilon), with epsilon inside the denominator
      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
                    (square_root::Map(var_data[data_i]) + epsilon));
    }
  }
};


template<typename xpu>
inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
                                    const OpContext& ctx,
                                    const TBlob& weight,
                                    const NDArray& grad,
                                    const TBlob& mean,
                                    const TBlob& var,
                                    const OpReqType& req,
                                    TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  if (!grad.storage_initialized() || req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mean.shape_.Size(), 0);
  CHECK_GT(var.shape_.Size(), 0);

  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        const DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* mean_data = mean.dptr<DType>();
        DType* var_data = var.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        Kernel<AdamDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
          out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}

template<typename xpu>
inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
                                    const OpContext& ctx,
                                    const NDArray& weight,
                                    const NDArray& grad,
                                    const NDArray& mean,
                                    const NDArray& var,
                                    const OpReqType& req,
                                    NDArray *out) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  using namespace rowsparse;
  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamUpdate", "weights");
  Stream<xpu>* s = ctx.get_stream<xpu>();
  // fill mean and variance with zero values in order to reuse the dense adam update impl
  if (!mean.storage_initialized()) {
    NDArray mean_zeros = mean;
    FillDnsZerosRspImpl(s, &mean_zeros);
  }
  if (!var.storage_initialized()) {
    NDArray var_zeros = var;
    FillDnsZerosRspImpl(s, &var_zeros);
  }
  TBlob out_blob = out->data();
  // reuse the dns rsp implementation when storage_shape == shape
  AdamUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
                               var.data(), req, &out_blob);
}


template<typename xpu>
inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
                         const OpContext &ctx,
                         const std::vector<NDArray> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<NDArray> &outputs) {
  const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
  const auto weight_stype = inputs[0].storage_type();
  const auto grad_stype = inputs[1].storage_type();
  const auto mean_stype = inputs[2].storage_type();
  const auto var_stype = inputs[3].storage_type();

  const auto out_stype = outputs[0].storage_type();
  CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected between "
    << "mean.stype = " << mean_stype << " and weight.stype = " << weight_stype;
  CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected between "
    << "var.stype = " << var_stype << " and weight.stype = " << weight_stype;
  if (weight_stype == kRowSparseStorage && mean_stype == kRowSparseStorage &&
      var_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
      out_stype == kRowSparseStorage) {
    NDArray out = outputs[0];
    AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
                                 inputs[3], req[0], &out);
  } else {
    LOG(FATAL) << "Unexpected storage types: weight.stype = " << weight_stype
               << ", var.stype = " << var_stype << ", mean.stype = " << mean_stype
               << ", grad.stype = " << grad_stype;
  }
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
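For reference, the row-wise update performed by AdamDnsRspDnsKernel can be written out as a minimal NumPy sketch. This is illustrative only; the function and argument names below are made up for this note and are not part of the patch. It mirrors the kernel's order of operations: weight decay is folded into the rescaled gradient before clipping, and epsilon sits inside the denominator.

import numpy as np

def sparse_adam_rows(weight, mean, var, grad_rows, grad_idx, lr, beta1, beta2,
                     epsilon, wd, rescale_grad, clip_gradient=-1.0):
    # Only the rows listed in grad_idx are touched; all other rows keep their
    # stale mean/var, which is what makes the update lazy/sparse.
    for i, row in enumerate(grad_idx):
        g = grad_rows[i] * rescale_grad + wd * weight[row]
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        mean[row] = beta1 * mean[row] + (1. - beta1) * g
        var[row] = beta2 * var[row] + (1. - beta2) * g * g
        weight[row] -= lr * mean[row] / (np.sqrt(var[row]) + epsilon)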
1 change: 1 addition & 0 deletions src/operator/optimizer_op.cc
@@ -160,6 +160,7 @@ It updates the weights using::
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdamUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
3 changes: 2 additions & 1 deletion src/operator/optimizer_op.cu
@@ -42,7 +42,8 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
.set_attr<FCompute>("FCompute<gpu>", MP_SGDMomUpdate<gpu>);

NNVM_REGISTER_OP(adam_update)
-.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", AdamUpdateEx<gpu>);

NNVM_REGISTER_OP(rmsprop_update)
.set_attr<FCompute>("FCompute<gpu>", RMSPropUpdate<gpu>);
2 changes: 1 addition & 1 deletion src/operator/tensor/init_op.h
@@ -162,7 +162,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
      auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
      auto val = dst->data();
      Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
-      ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
+      ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
    });
  });
}
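AdamUpdateRspRspRspImpl above relies on this helper to zero-initialize an empty row_sparse mean/var so that storage_shape == shape and the dense kernel can index them directly. A rough NumPy analogue of what the zero-fill produces follows; the names are made up for illustration and are not MXNet API.

import numpy as np

def fill_dns_zeros_rsp(num_rows, row_length, dtype=np.float32, itype=np.int64):
    # A dense-sized value blob plus an index array covering every row,
    # i.e. a row_sparse array whose storage_shape equals its shape.
    data = np.zeros((num_rows, row_length), dtype=dtype)
    idx = np.arange(num_rows, dtype=itype)
    return data, idx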
40 changes: 23 additions & 17 deletions tests/python/unittest/test_optimizer.py
@@ -312,12 +312,13 @@ def test_sparse_sgd():
class PyAdam(mx.optimizer.Optimizer):
    """python reference implementation of adam"""
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 decay_factor=(1 - 1e-8), **kwargs):
+                 decay_factor=(1 - 1e-8), sparse_update=False, **kwargs):
        super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.decay_factor = decay_factor
+        self.sparse_update = sparse_update

    def create_state(self, index, weight):
        """Create additional optimizer state: mean, variance
@@ -355,21 +356,28 @@ def update(self, index, weight, grad, state):
        mean, variance = state

        wd = self._get_wd(index)
-        grad = grad * self.rescale_grad + wd * weight
-        if self.clip_gradient is not None:
-            mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)
-
-        mean *= self.beta1
-        mean += grad * (1. - self.beta1)
-
-        variance *= self.beta2
-        variance += (1 - self.beta2) * mx.nd.square(grad, out=grad)

+        num_rows = weight.shape[0]
        coef1 = 1. - self.beta1**t
        coef2 = 1. - self.beta2**t
        lr *= math.sqrt(coef2)/coef1

-        weight -= lr*mean/(mx.nd.sqrt(variance) + self.epsilon)
+        for row in range(num_rows):
+            # check row slices of all zeros
+            all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
+            # skip zeros during sparse update
+            if all_zeros and self.sparse_update:
+                continue
+            grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
+            # clip gradients
+            if self.clip_gradient is not None:
+                mx.nd.clip(grad[row], -self.clip_gradient, self.clip_gradient, out=grad[row])
+            # update mean
+            mean[row] *= self.beta1
+            mean[row] += grad[row] * (1. - self.beta1)
+            # update variance
+            variance[row] *= self.beta2
+            variance[row] += (1 - self.beta2) * mx.nd.square(grad[row], out=grad[row])
+            # update weight
+            weight[row] -= lr*mean[row]/(mx.nd.sqrt(variance[row]) + self.epsilon)


@@ -386,10 +394,8 @@ def test_adam():
              {'rescale_grad': 0.8, 'wd': 0.05}]
    for kwarg in kwargs:
        compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
-        # test operator fallback on cpu
-        if (default_context() == mx.cpu()):
-            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape,
-                              np.float32, g_stype='row_sparse')
+        compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+                          np.float32, w_stype='row_sparse', g_stype='row_sparse')

# RMSProp
class PyRMSProp(mx.optimizer.Optimizer):
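As a usage sketch (not part of this PR's tests), the new code path is reached when weight, grad, mean, var and the output are all row_sparse, in which case adam_update dispatches through AdamUpdateEx into the row-sparse implementation. The Python calls below (mx.nd.cast_storage with stype='row_sparse', and mx.nd.adam_update with out=weight for the in-place write) are assumed from the MXNet sparse NDArray API of the same era.

import mxnet as mx

shape = (10, 4)
weight = mx.nd.cast_storage(mx.nd.ones(shape), stype='row_sparse')
# casting all-zero dense arrays yields empty row_sparse mean/var,
# which exercises the FillDnsZerosRspImpl branch above
mean = mx.nd.cast_storage(mx.nd.zeros(shape), stype='row_sparse')
var = mx.nd.cast_storage(mx.nd.zeros(shape), stype='row_sparse')

# gradient with only two non-zero rows
grad_dense = mx.nd.zeros(shape)
grad_dense[1:2] = 0.5
grad_dense[3:4] = -0.5
grad = mx.nd.cast_storage(grad_dense, stype='row_sparse')

# writes back into `weight` (kWriteInplace), touching only rows 1 and 3
mx.nd.adam_update(weight, grad, mean, var, lr=0.01, beta1=0.9, beta2=0.999,
                  epsilon=1e-8, wd=0.0, rescale_grad=1.0, out=weight)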