From cb5bd4ea8b6fc9b568c13747bdb006ac047b72b5 Mon Sep 17 00:00:00 2001
From: Adnios <41060790+Adnios@users.noreply.github.com>
Date: Wed, 30 Jun 2021 09:38:33 +0800
Subject: [PATCH] [BUGFIX] fix log_sigmoid bugs (#20372)

* fix log_sigmoid bugs

* use forward interface

* add the forgotten rtc update (backward_log_sigmoid)
---
 src/common/cuda/rtc/backward_functions-inl.h |  2 +-
 src/operator/mshadow_op.h                    |  2 +-
 src/operator/nn/activation.cu                | 13 ++--
 .../tensor/elemwise_unary_op_basic.cc        | 28 ++++++-
 tests/python/unittest/test_numpy_op.py       | 74 ++++++++++++++++---
 5 files changed, 99 insertions(+), 20 deletions(-)

diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 85135ae6e888..50d8469571d8 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -47,7 +47,7 @@ backward_sigmoid(const DTypeGrad grad, const DType val) {
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
 backward_log_sigmoid(const DTypeGrad grad, const DType val) {
-  return grad * 1 / (1 + op::exp(val));
+  return grad * (1 - op::exp(val));
 }
 
 template <typename DType, typename DTypeGrad>
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 611ddbcad472..c3ce7332cc2c 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -413,7 +413,7 @@ MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));
 
 MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
 
-MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
+MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f - math::exp(a));
 
 struct mish : public mxnet_op::tunable {
   template <typename DType>
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index bb166243da50..1737f1ba412b 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -56,13 +56,16 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   const int act_type = param.act_type;
 
-  // SoftReLU, kSoftSign and Mish are not supported by CUDNN yet
+  // SoftReLU, SoftSign, Log_Sigmoid and Mish are not supported by CUDNN yet
   if (act_type == activation::kSoftReLU) {
     ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
         inputs[0], req[0], outputs[0]);
   } else if (act_type == activation::kSoftSign) {
     ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
         inputs[0], req[0], outputs[0]);
+  } else if (act_type == activation::kLogSigmoid) {
+    ActivationForward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(ctx,
+        inputs[0], req[0], outputs[0]);
   } else if (act_type == activation::kMish) {
     ActivationForward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(ctx,
         inputs[0], req[0], outputs[0]);
@@ -87,10 +90,13 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0);
 
-  // SoftReLU, SoftSign and Mish not supported by CUDNN yet
+  // SoftReLU, SoftSign, Log_Sigmoid and Mish not supported by CUDNN yet
   if (act_type == activation::kSoftReLU) {
     ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
         ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
+  } else if (act_type == activation::kLogSigmoid) {
+    ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
+        ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
   } else if (act_type == activation::kMish) {
     ActivationBackward<gpu, mshadow_op::mish, mshadow_op::mish_grad>(
         ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
@@ -121,9 +127,6 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
     } else if (act_type == activation::kSigmoid) {
       ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
           ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
-    } else if (act_type == activation::kLogSigmoid) {
-      ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
-          ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
     } else {
       LOG(FATAL) << "unknown activation type";
     }
   }
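The formula change above moves the log_sigmoid gradient from an input-based form to an output-based one: with y = log_sigmoid(x) we have exp(y) = sigmoid(x), so 1 - exp(y) equals the old 1 / (1 + exp(x)), and the CPU registration below can switch to ElemwiseGradUseOut. A minimal NumPy sketch, standalone rather than part of the patch (helper names here are only illustrative), that checks both forms against a finite difference:

import numpy as np

def log_sigmoid(x):
    return np.log(1.0 / (1.0 + np.exp(-x)))

x = np.linspace(-5.0, 5.0, 101)
y = log_sigmoid(x)

grad_from_output = 1.0 - np.exp(y)         # new functor, evaluated on the output y
grad_from_input = 1.0 / (1.0 + np.exp(x))  # old functor, evaluated on the input x
eps = 1e-6
grad_numeric = (log_sigmoid(x + eps) - log_sigmoid(x - eps)) / (2 * eps)

assert np.allclose(grad_from_output, grad_from_input)
assert np.allclose(grad_from_output, grad_numeric, rtol=1e-4, atol=1e-6)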
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 064c828210c7..079b9873d434 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -161,10 +161,34 @@ The storage type of ``log_sigmoid`` output is always dense
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_sigmoid"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
-                                               unary_bwd<mshadow_op::log_sigmoid_grad>);
+                                               unary_bwd<mshadow_op::log_sigmoid_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : f(x) = log_sigmoid(x)
+      // ograds[0] : head_grads
+      // f''(x) = f'(x) * (f'(x) - 1)
+      // NodeEntry{n} : y_grad * f'(x)
+      auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
+      auto grad_minus_one = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
+                                     {n->inputs[0], nnvm::NodeEntry{ones}}, nullptr, &n);
+      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+                                    {n->inputs[0], nnvm::NodeEntry{grad_minus_one}}, nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      // when building the gradient graph, the backward node of n->inputs[1] will be
+      // added to the graph again, therefore f'(x) will be multiplied in
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+      return ret;
+    });
 
 // mish
 MXNET_OPERATOR_REGISTER_UNARY(mish)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 572735f84e2b..49558d1e58cd 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3451,6 +3451,42 @@ def forward(self, a):
         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
+@use_np
+def test_npx_activation_log_sigmoid():
+    def np_log_sigmoid(x):
+        return _np.log(_np.divide(1.0, (1.0 + _np.exp(-x))))
+    def np_log_sigmoid_grad(x):
+        return _np.divide(1.0, _np.add(1.0, _np.exp(x)))
+
+    class TestLogSigmoid(HybridBlock):
+        def __init__(self):
+            super(TestLogSigmoid, self).__init__()
+
+        def forward(self, a):
+            return npx.activation(a, act_type='log_sigmoid')
+
+    shapes = [(), (2, 3, 4)]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            test_log_sigmoid = TestLogSigmoid()
+            if hybridize:
+                test_log_sigmoid.hybridize()
+            x = rand_ndarray(shape).as_np_ndarray()
+            x.attach_grad()
+            np_out = np_log_sigmoid(x.asnumpy())
+            with mx.autograd.record():
+                mx_out = test_log_sigmoid(x)
+            assert mx_out.shape == np_out.shape
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+            mx_out.backward()
+            np_backward = np_log_sigmoid_grad(x.asnumpy())
+            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+
+            mx_out = npx.activation(x, act_type='log_sigmoid')
+            np_out = np_log_sigmoid(x.asnumpy())
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
 @use_np
 def test_npx_activation_mish():
     def np_mish(a):
@@ -3461,17 +3497,33 @@ def np_mish_grad(a):
         sigmoid = _np.divide(1.0, (1.0 + _np.exp(-a)))
         return tanh + a * sigmoid * (1.0 - tanh * tanh)
 
-    shape = (3, 4)
-    A = mx.np.random.uniform(low=-1.0, high=1.0, size=shape)
-    A.attach_grad()
-    np_out = np_mish(A.asnumpy())
-    with mx.autograd.record():
-        B = mx.npx.activation(A, act_type='mish')
-    assert B.shape == np_out.shape
-    assert_almost_equal(B.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
-    B.backward()
-    np_backward = np_mish_grad(A.asnumpy())
-    assert_almost_equal(A.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+    class TestMish(HybridBlock):
+        def __init__(self):
+            super(TestMish, self).__init__()
+
+        def forward(self, a):
+            return npx.activation(a, act_type='mish')
+
+    shapes = [(), (2, 3, 4)]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            test_mish = TestMish()
+            if hybridize:
+                test_mish.hybridize()
+            x = rand_ndarray(shape).as_np_ndarray()
+            x.attach_grad()
+            np_out = np_mish(x.asnumpy())
+            with mx.autograd.record():
+                mx_out = test_mish(x)
+            assert mx_out.shape == np_out.shape
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+            mx_out.backward()
+            np_backward = np_mish_grad(x.asnumpy())
+            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+
+            mx_out = npx.activation(x, act_type='mish')
+            np_out = np_mish(x.asnumpy())
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @use_np
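The second-order FGradient added for _backward_log_sigmoid relies on the identity stated in its comments: with f(x) = log_sigmoid(x), f'(x) = 1 - exp(f(x)) = 1 - sigmoid(x) and f''(x) = f'(x) * (f'(x) - 1). A small NumPy sketch, standalone rather than part of the patch (helper names are only illustrative), that confirms the identity numerically:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-5.0, 5.0, 101)
fp = 1.0 - sigmoid(x)        # f'(x)
f2_closed = fp * (fp - 1.0)  # closed form from the patch comment
eps = 1e-5
f2_numeric = ((1.0 - sigmoid(x + eps)) - (1.0 - sigmoid(x - eps))) / (2 * eps)

assert np.allclose(f2_closed, f2_numeric, rtol=1e-4, atol=1e-6)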