Skip to content

Commit

Permalink
Fixes NAG optimizer apache#15543 (apache#16053)
Browse files (browse the repository at this point in the history)
* fix update rules

* readable updates in unit test

* mom (momentum) update
  • Loading branch information
anirudhacharya authored and larroy committed Sep 28, 2019
1 parent 64bde04 commit 62dce19
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 37 deletions.
50 changes: 23 additions & 27 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1066,21 +1066,18 @@ struct NAGMomKernel {
const DType param_lr, const DType param_wd,
const DType param_rescale_grad, const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient)
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i]
- param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient)));
mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
*(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad
*grad_data[i], param_clip_gradient)+(param_wd*weight_data[i])))));
mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
param_clip_gradient))+(param_wd*weight_data[i])));
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]
+ (param_wd*weight_data[i]);
KERNEL_ASSIGN(out_data[i], req, weight_data[i]
- param_lr*(param_momentum*mom_data[i]
+ param_rescale_grad*grad_data[i]));
mom_data[i] = param_momentum*mom_data[i];
KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
*(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));
mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i])
+(param_wd*weight_data[i]));
}
}
};
Expand Down Expand Up @@ -1119,22 +1116,21 @@ struct MP_NAGMomKernel {
const OpReqType req) {
float w = weight32[i];
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad
*static_cast<float>(grad_data[i]), param_clip_gradient)
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i]
+ mshadow_op::clip::Map(param_rescale_grad
*static_cast<float>(grad_data[i]),
param_clip_gradient));
mom_data[i] = param_momentum*mom_data[i];
w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
*(mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
param_clip_gradient)+(param_wd*w)));
mom_data[i] = mom_data[i] - param_lr
*((mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
param_clip_gradient))+(param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
} else {
mom_data[i] = param_momentum*mom_data[i]
+ param_rescale_grad*static_cast<float>(grad_data[i])
+ (param_wd*w);
w = w - param_lr*(param_momentum*mom_data[i]
+ param_rescale_grad*static_cast<float>(grad_data[i]));
mom_data[i] = param_momentum*mom_data[i];
w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
*(param_rescale_grad*static_cast<float>(grad_data[i])+(param_wd*w)));
mom_data[i] = mom_data[i] - param_lr
*((param_rescale_grad*static_cast<float>(grad_data[i]))+(param_wd*w));
weight32[i] = w;
KERNEL_ASSIGN(out_data[i], req, w);
}
Expand Down
14 changes: 4 additions & 10 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,8 @@ def update(self, index, weight, grad, state):
weight[:] += -lr * (grad + wd * weight)
else:
mom = state
mom[:] *= self.momentum
mom[:] += grad
mom[:] += wd * weight
grad[:] += self.momentum * mom
weight[:] -= lr * grad
weight[:] += (self.momentum**2 * mom) - lr*(self.momentum + 1)*(grad + wd*weight)
mom[:] = (self.momentum*mom) - lr*(grad + wd*weight)
else:
grad32 = array(grad, ctx=grad.context, dtype=np.float32)
grad32 = grad32 * self.rescale_grad
Expand All @@ -399,11 +396,8 @@ def update(self, index, weight, grad, state):
if self.momentum == 0.0:
weight32[:] += -lr * (grad32 + wd * weight32)
else:
mom[:] *= self.momentum
mom[:] += grad32
mom[:] += wd * weight32
grad32[:] += self.momentum * mom
weight32[:] -= lr * grad32
weight32[:] += (self.momentum**2 * mom) - lr*(self.momentum+1)*(grad32 + wd*weight32)
mom[:] = (self.momentum*mom) - lr*(grad32 + wd*weight32)
tmp = weight32.astype(weight.dtype)
tmp.copyto(weight)

Expand Down

0 comments on commit 62dce19

Please sign in to comment.