This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

[BUGFIX] fix log_sigmoid bugs (#20372)
* fix log_sigmoid bugs

* use forward interface

* forgot to update rtc (backward_log_sigmoid)
Adnios authored Jun 30, 2021
1 parent 0104e5d commit cb5bd4e
Showing 5 changed files with 99 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/common/cuda/rtc/backward_functions-inl.h
@@ -47,7 +47,7 @@ backward_sigmoid(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
return grad * (1 - op::exp(val));
}
template <typename DType, typename DTypeGrad>
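Why the new expression is correct: the backward kernel now receives the forward output rather than the input (see the ElemwiseGradUseOut change further down), so the derivative has to be rewritten in terms of y = log_sigmoid(x). A short derivation, for reference:

y = \log \sigma(x) = -\log\bigl(1 + e^{-x}\bigr), \qquad
\frac{dy}{dx} = \frac{e^{-x}}{1 + e^{-x}} = \frac{1}{1 + e^{x}} = 1 - \sigma(x) = 1 - e^{y}.

The old code evaluated 1 / (1 + exp(val)) as if val were the input x; once val is the output y, the equivalent form is 1 - exp(val).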
2 changes: 1 addition & 1 deletion src/operator/mshadow_op.h
@@ -413,7 +413,7 @@ MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f - math::exp(a));

struct mish : public mxnet_op::tunable {
template<typename DType>
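The CPU-side change mirrors the RTC one above: log_sigmoid_grad is now evaluated on the forward output instead of the input. A quick NumPy sanity check (illustrative only, not part of the patch) that the two forms agree:

import numpy as np

x = np.linspace(-5.0, 5.0, 11)
y = np.log(1.0 / (1.0 + np.exp(-x)))        # log_sigmoid forward
grad_wrt_input = 1.0 / (1.0 + np.exp(x))    # old formula, argument is the input x
grad_wrt_output = 1.0 - np.exp(y)           # new formula, argument is the output y
assert np.allclose(grad_wrt_input, grad_wrt_output)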
13 changes: 8 additions & 5 deletions src/operator/nn/activation.cu
@@ -56,13 +56,16 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU, kSoftSign and Mish are not supported by CUDNN yet
// SoftReLU, SoftSign, Log_Sigmoid and Mish are not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationForward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationForward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(ctx,
inputs[0], req[0], outputs[0]);
@@ -87,10 +90,13 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0);

// SoftReLU, SoftSign and Mish not supported by CUDNN yet
// SoftReLU, SoftSign, Log_Sigmoid and Mish not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kMish) {
ActivationBackward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(
ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
@@ -121,9 +127,6 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
} else if (act_type == activation::kSigmoid) {
ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else {
LOG(FATAL) << "unknown activation type";
}
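Since cuDNN has no log_sigmoid activation, forward and backward now take the native ActivationForward/ActivationBackward path on GPU. A minimal usage sketch (assuming an MXNet build with GPU support; the numpy-style API matches the tests below):

import mxnet as mx
from mxnet import npx

npx.set_np()
x = mx.np.random.uniform(-1.0, 1.0, size=(2, 3), ctx=mx.gpu(0))
x.attach_grad()
with mx.autograd.record():
    y = npx.activation(x, act_type='log_sigmoid')
y.backward()
print(x.grad)  # expected to equal 1 / (1 + exp(x)) elementwise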
28 changes: 26 additions & 2 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -161,10 +161,34 @@ The storage type of ``log_sigmoid`` output is always dense
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);
unary_bwd<mshadow_op::log_sigmoid_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = log_sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (f'(x) - 1)
// NodeEntry{n} : y_grad * f'(x)
auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
auto grad_minus_one = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
{n->inputs[0], nnvm::NodeEntry{ones}}, nullptr, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
{n->inputs[0], nnvm::NodeEntry{grad_minus_one}}, nullptr, &n);
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
{nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

// when building the gradient graph, the backward node of n->inputs[1] will be
// added to the graph again, therefore f'(x) will be multiplied in
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
{ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});

// mish
MXNET_OPERATOR_REGISTER_UNARY(mish)
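The new FGradient entry supplies a second-order gradient for _backward_log_sigmoid, built from the identity noted in the comments: with f(x) = log_sigmoid(x), f'(x) = 1 - exp(f(x)) and f''(x) = f'(x) * (f'(x) - 1). A finite-difference check of that identity (illustrative sketch, independent of the MXNet code):

import numpy as np

def log_sigmoid(v):
    return np.log(1.0 / (1.0 + np.exp(-v)))

x = np.linspace(-4.0, 4.0, 9)
fp = 1.0 - np.exp(log_sigmoid(x))           # f'(x) = 1 - exp(f(x)) == 1 / (1 + exp(x))
eps = 1e-5
fp_num = (log_sigmoid(x + eps) - log_sigmoid(x - eps)) / (2 * eps)
fpp_num = (1.0 / (1.0 + np.exp(x + eps)) - 1.0 / (1.0 + np.exp(x - eps))) / (2 * eps)
assert np.allclose(fp, fp_num, atol=1e-6)                 # first-order identity
assert np.allclose(fp * (fp - 1.0), fpp_num, atol=1e-6)   # f''(x) = f'(x) * (f'(x) - 1)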
74 changes: 63 additions & 11 deletions tests/python/unittest/test_numpy_op.py
@@ -3451,6 +3451,42 @@ def forward(self, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
def test_npx_activation_log_sigmoid():
def np_log_sigmoid(x):
return _np.log(_np.divide(1.0, (1.0 + _np.exp(-x))))
def np_log_sigmoid_grad(x):
return _np.divide(1.0, _np.add(1.0, _np.exp(x)))

class TestLogSigmoid(HybridBlock):
def __init__(self):
super(TestLogSigmoid, self).__init__()

def forward(self, a):
return npx.activation(a, act_type='log_sigmoid')

shapes = [(), (2, 3, 4)]
for hybridize in [True, False]:
for shape in shapes:
test_log_sigmoid = TestLogSigmoid()
if hybridize:
test_log_sigmoid.hybridize()
x = rand_ndarray(shape).as_np_ndarray()
x.attach_grad()
np_out = np_log_sigmoid(x.asnumpy())
with mx.autograd.record():
mx_out = test_log_sigmoid(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
np_backward = np_log_sigmoid_grad(x.asnumpy())
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

mx_out = npx.activation(x, act_type='log_sigmoid')
np_out = np_log_sigmoid(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
def test_npx_activation_mish():
def np_mish(a):
@@ -3461,17 +3497,33 @@ def np_mish_grad(a):
sigmoid = _np.divide(1.0, (1.0 + _np.exp(-a)))
return tanh + a * sigmoid * (1.0 - tanh * tanh)

shape = (3, 4)
A = mx.np.random.uniform(low=-1.0, high=1.0, size=shape)
A.attach_grad()
np_out = np_mish(A.asnumpy())
with mx.autograd.record():
B = mx.npx.activation(A, act_type='mish')
assert B.shape == np_out.shape
assert_almost_equal(B.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
B.backward()
np_backward = np_mish_grad(A.asnumpy())
assert_almost_equal(A.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
class TestMish(HybridBlock):
def __init__(self):
super(TestMish, self).__init__()

def forward(self, a):
return npx.activation(a, act_type='mish')

shapes = [(), (2, 3, 4)]
for hybridize in [True, False]:
for shape in shapes:
test_mish = TestMish()
if hybridize:
test_mish.hybridize()
x = rand_ndarray(shape).as_np_ndarray()
x.attach_grad()
np_out = np_mish(x.asnumpy())
with mx.autograd.record():
mx_out = test_mish(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
np_backward = np_mish_grad(x.asnumpy())
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

mx_out = npx.activation(x, act_type='mish')
np_out = np_mish(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
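To exercise the updated tests locally, something along these lines should work in an MXNet development environment (the pytest invocation is an assumption, not part of the commit):

import pytest
pytest.main(["tests/python/unittest/test_numpy_op.py", "-k", "log_sigmoid or mish", "-v"])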
