From 26a56232d6bf9d31a7e5554b05bb10aa63e17c64 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 28 Jun 2018 10:14:26 -0700 Subject: [PATCH 01/19] Add temperature parameter in softmax operator and add a unit test --- src/operator/contrib/ctc_loss-inl.h | 4 +-- src/operator/nn/softmax-inl.h | 42 +++++++++++++++----------- src/operator/nn/softmax.cc | 3 ++ tests/python/unittest/test_operator.py | 15 +++++++-- 4 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h index ef58c519aa9c..f96118c8646a 100644 --- a/src/operator/contrib/ctc_loss-inl.h +++ b/src/operator/contrib/ctc_loss-inl.h @@ -409,7 +409,7 @@ class CTCLossOp : public Operator { // since the input is activation before softmax and cudnn ctc takes softmax // apply softmax to inputs first. - mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2); + mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0f); CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_, prob_desc_, @@ -427,7 +427,7 @@ class CTCLossOp : public Operator { if (req_grad) { mxnet_op::SoftmaxGrad(s, - prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2); + prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0f); Assign(grad, mxnet::kWriteInplace, grad * alphabet_size); } } diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 080bc08852c3..70bf301b0815 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -53,7 +53,7 @@ struct log_softmax_fwd { template inline void Softmax(Stream *s, DType *in, DType *out, - Shape shape, int axis) { + Shape shape, int axis, const float temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; Shape stride = calc_stride(shape); @@ -72,11 +72,11 @@ inline void Softmax(Stream *s, DType *in, DType *out, DType sum = DType(0); for (index_t j = 0; j < M; ++j) { - sum += std::exp(in[base + j*sa] - mmax); + sum += std::exp((in[base + j*sa] - mmax)/temperature); } for (index_t j = 0; j < M; ++j) { - out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum); + out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum); } } } @@ -100,7 +100,8 @@ struct log_softmax_bwd { template inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, - DType *igrad, Shape shape, int axis) { + DType *igrad, Shape shape, int axis, + const float temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; Shape stride = calc_stride(shape); @@ -118,7 +119,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, } for (index_t j = 0; j < M; ++j) { - igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); + igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature; } } } @@ -127,7 +128,8 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, #ifdef __CUDACC__ template __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis, - Shape sshape, Shape stride) { + Shape sshape, Shape stride, + const float temperature) { const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -146,7 +148,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi red::sum::SetInitValue(smem[x]); for (index_t i = x; i < M; i += x_size) { - red::sum::Reduce(smem[x], static_cast(expf(in[base + i*sa] - smax))); + red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/temperature)); } __syncthreads(); cuda::Reduce1D(smem); @@ -155,13 +157,13 @@ __global__ void 
softmax_compute_kernel(DType *in, DType *out, index_t M, int axi __syncthreads(); for (index_t i = x; i < M; i += x_size) { - out[base + i*sa] = OP::Map(in[base + i*sa] - smax, ssum); + out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/temperature, ssum); } } template inline void Softmax(Stream *s, DType *in, DType *out, - Shape shape, int axis) { + Shape shape, int axis, const float temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -172,7 +174,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, softmax_compute_kernel <<::GetStream(s)>>>( - in, out, M, axis, sshape, stride); + in, out, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); } @@ -180,7 +182,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, template __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, index_t M, int axis, Shape sshape, - Shape stride) { + Shape stride, const float temperature) { const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -198,14 +200,15 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, __syncthreads(); for (index_t i = x; i < M; i += x_size) { - igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum); + igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/temperature; } } template inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, - DType *igrad, Shape shape, int axis) { + DType *igrad, Shape shape, int axis, + const float temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -216,7 +219,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, softmax_gradient_kernel <<::GetStream(s)>>>( - out, ograd, igrad, M, axis, sshape, stride); + out, ograd, igrad, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel); } #endif @@ -226,9 +229,12 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, struct SoftmaxParam : public dmlc::Parameter { int axis; + float temperature; DMLC_DECLARE_PARAMETER(SoftmaxParam) { DMLC_DECLARE_FIELD(axis).set_default(-1) .describe("The axis along which to compute softmax."); + DMLC_DECLARE_FIELD(temperature).set_default(1.0f) + .describe("Temperature parameter in softmax"); } }; @@ -247,10 +253,10 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (shape.ndim() == 2) { Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis); + outputs[0].dptr(), shape.get<2>(), axis, param.temperature); } else { Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis); + outputs[0].dptr(), shape.get<3>(), axis, param.temperature); } }); } @@ -272,11 +278,11 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, if (shape.ndim() == 2) { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis); + shape.get<2>(), axis, param.temperature); } else { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis); + shape.get<3>(), axis, param.temperature); } }); } diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index e9b104f12868..5500511ca2db 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -77,10 +77,13 @@ MXNET_OPERATOR_REGISTER_UNARY(softmax) The resulting array contains 
elements in the range (0,1) and the elements along the given axis sum up to 1. .. math:: + softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}} softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}} for :math:`j = 1, ..., K` +t is the temperature parameter in softmax function. By default, t equals 1.0 + Example:: x = [[ 1. 1. 1.]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1a1c548c595b..30869fbc4a06 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -258,11 +258,11 @@ def test_rnnrelu_dropout(): out = exe.forward(is_train=True) out[0].wait_to_read() -def np_softmax(x, axis=-1): +def np_softmax(x, axis=-1, temperature=1.0): # fix for old numpy on Travis not supporting keepdims # x = x - np.max(x, axis=-1, keepdims=True) x = x - np.max(x, axis=axis, keepdims=True) - x = np.exp(x) + x = np.exp(x/temperature) # x /= np.sum(x, axis=-1, keepdims=True) x /= np.sum(x, axis=axis, keepdims=True) return x @@ -4226,6 +4226,17 @@ def test_new_softmax(): check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)]) check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3) +@with_seed() +def test_softmax_with_temperature(): + for ndim in range(1, 5): + for temp in range(1, 11): + shape = np.random.randint(1, 5, size=ndim) + axis = np.random.randint(-ndim, ndim) + data = np.random.uniform(-2, 2, size=shape) + sym = mx.sym.softmax(axis=axis, temperature=temp) + expected = np_softmax(data, axis=axis, temperature=temp) + check_symbolic_forward(sym, [data], [expected]) + check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3) @with_seed() def test_log_softmax():
From bc2b9ffc4222df9c7402fec2089629c903e97091 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 28 Jun 2018 11:04:40 -0700 Subject: [PATCH 02/19] Optimize runtime when temperature is set to default 1.0 --- src/operator/nn/softmax-inl.h | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 70bf301b0815..b4fcd794c163 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -71,12 +71,22 @@ inline void Softmax(Stream *s, DType *in, DType *out, } DType sum = DType(0); - for (index_t j = 0; j < M; ++j) { - sum += std::exp((in[base + j*sa] - mmax)/temperature); - } + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + sum += std::exp(in[base + j*sa] - mmax); + } + + for (index_t j = 0; j < M; ++j) { + out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum); + } + } else { + for (index_t j = 0; j < M; ++j) { + sum += std::exp((in[base + j*sa] - mmax)/temperature); + } - for (index_t j = 0; j < M; ++j) { - out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum); + for (index_t j = 0; j < M; ++j) { + out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum); + } } } }
@@ -118,8 +128,14 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); } - for (index_t j = 0; j < M; ++j) { - igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature; + if (temperature == 1.0) { + for (index_t j = 0; j < M; ++j) { + igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); + } + } else { + for (index_t j = 0; j < M; ++j) { + igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature; + } } } } From 6dc4e42cd7b8d17d2fb178a97db7f1e5c9df41c3 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 28
Jun 2018 11:27:28 -0700 Subject: [PATCH 05/19] Fix lint error --- src/operator/nn/softmax-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index b4fcd794c163..3d69c0d0b9c1 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -135,7 +135,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, } else { for (index_t j = 0; j < M; ++j) { igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature; - } + } } } } From c4e2de7846c83bfb3456b0ebd5a6a2266e6874ca Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 28 Jun 2018 11:33:40 -0700 Subject: [PATCH 06/19] Add my name to contributor list --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f1ab129288a1..87a629f7798d 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -172,3 +172,4 @@ List of Contributors * [Thomas Delteil](https://github.com/ThomasDelteil) * [Jesse Brizzi](https://github.com/jessebrizzi) * [Hang Zhang](http://hangzh.com) +* [Lin Yuan](https://github.com/apeforest) From 595130d65f9aa87f0aa560cf81fa8f1b691f49d2 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 09:34:41 -0700 Subject: [PATCH 07/19] Fix build error in CUDA --- src/operator/contrib/ctc_loss-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h index f96118c8646a..0e7b63e58fb3 100644 --- a/src/operator/contrib/ctc_loss-inl.h +++ b/src/operator/contrib/ctc_loss-inl.h @@ -409,7 +409,7 @@ class CTCLossOp : public Operator { // since the input is activation before softmax and cudnn ctc takes softmax // apply softmax to inputs first. - mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0f); + mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0); CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_, prob_desc_, @@ -427,7 +427,7 @@ class CTCLossOp : public Operator { if (req_grad) { mxnet_op::SoftmaxGrad(s, - prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0f); + prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0); Assign(grad, mxnet::kWriteInplace, grad * alphabet_size); } } From fd83576269a952430396c6038461aa611de45d6a Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 09:59:20 -0700 Subject: [PATCH 08/19] Fix build error in CUDA --- src/operator/contrib/ctc_loss-inl.h | 4 ++-- src/operator/nn/softmax-inl.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h index 0e7b63e58fb3..f96118c8646a 100644 --- a/src/operator/contrib/ctc_loss-inl.h +++ b/src/operator/contrib/ctc_loss-inl.h @@ -409,7 +409,7 @@ class CTCLossOp : public Operator { // since the input is activation before softmax and cudnn ctc takes softmax // apply softmax to inputs first. 
- mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0); + mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0f); CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_, prob_desc_, @@ -427,7 +427,7 @@ class CTCLossOp : public Operator { if (req_grad) { mxnet_op::SoftmaxGrad(s, - prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0); + prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0f); Assign(grad, mxnet::kWriteInplace, grad * alphabet_size); } } diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 3d69c0d0b9c1..bc61a03d66e7 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -164,7 +164,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi red::sum::SetInitValue(smem[x]); for (index_t i = x; i < M; i += x_size) { - red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/temperature)); + red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/temperature))); } __syncthreads(); cuda::Reduce1D(smem); From 8bef2dfb05ebcc1726d83c5975db05179e361318 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 15:15:17 -0700 Subject: [PATCH 09/19] Fall back to regular CPU when setting temperature parameter in MKLDNN --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 ++ src/operator/nn/mkldnn/mkldnn_softmax.cc | 7 +++++++ src/operator/nn/softmax.cc | 3 ++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index c6e7f9bdefdc..d602c250d89b 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -146,9 +146,11 @@ namespace op { struct ActivationParam; struct ConvolutionParam; struct DeconvolutionParam; +struct SoftmaxParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input); bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input); +bool SupportMKLDNNSoftmax(const SoftmaxParam& param); } static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc index acfa358a796e..2943c84670e6 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -32,6 +32,13 @@ namespace mxnet { namespace op { +bool SupportMKLDNNSoftmax(const SoftmaxParam ¶m) { + if (param.temperature != 1.0) { + return false; + } + return true; +} + void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data) { diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 5500511ca2db..3fbd20d42a03 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -39,7 +39,8 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { // It seems MKLDNN softmax doesn't support training. 
- if (SupportMKLDNN(inputs[0]) && !ctx.is_train) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); + if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]); auto fn = SoftmaxCompute; From 01466506d2dbe25eef2dc3be91584f91f8c7f757 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 15:44:23 -0700 Subject: [PATCH 10/19] Change temperature param type to generic DType --- src/operator/contrib/ctc_loss-inl.h | 4 ++-- src/operator/nn/softmax-inl.h | 28 ++++++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h index f96118c8646a..0e7b63e58fb3 100644 --- a/src/operator/contrib/ctc_loss-inl.h +++ b/src/operator/contrib/ctc_loss-inl.h @@ -409,7 +409,7 @@ class CTCLossOp : public Operator { // since the input is activation before softmax and cudnn ctc takes softmax // apply softmax to inputs first. - mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0f); + mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0); CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_, prob_desc_, @@ -427,7 +427,7 @@ class CTCLossOp : public Operator { if (req_grad) { mxnet_op::SoftmaxGrad(s, - prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0f); + prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0); Assign(grad, mxnet::kWriteInplace, grad * alphabet_size); } } diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index bc61a03d66e7..cad36e209fd7 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -53,7 +53,7 @@ struct log_softmax_fwd { template inline void Softmax(Stream *s, DType *in, DType *out, - Shape shape, int axis, const float temperature) { + Shape shape, int axis, const DType temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; Shape stride = calc_stride(shape); @@ -111,7 +111,7 @@ struct log_softmax_bwd { template inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, DType *igrad, Shape shape, int axis, - const float temperature) { + const DType temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; Shape stride = calc_stride(shape); @@ -145,7 +145,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, template __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis, Shape sshape, Shape stride, - const float temperature) { + const DType temperature) { const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -179,7 +179,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi template inline void Softmax(Stream *s, DType *in, DType *out, - Shape shape, int axis, const float temperature) { + Shape shape, int axis, const DType temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -198,7 +198,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, template __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, index_t M, int axis, Shape sshape, - Shape stride, const float temperature) { + Shape stride, const DType temperature) { const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -224,7 +224,7 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, template inline void SoftmaxGrad(Stream *s, 
DType *out, DType *ograd, DType *igrad, Shape shape, int axis, - const float temperature) { + const DType temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -245,11 +245,11 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, struct SoftmaxParam : public dmlc::Parameter { int axis; - float temperature; + double temperature; DMLC_DECLARE_PARAMETER(SoftmaxParam) { DMLC_DECLARE_FIELD(axis).set_default(-1) .describe("The axis along which to compute softmax."); - DMLC_DECLARE_FIELD(temperature).set_default(1.0f) + DMLC_DECLARE_FIELD(temperature).set_default(1.0) .describe("Temperature parameter in softmax"); } }; @@ -269,10 +269,12 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (shape.ndim() == 2) { Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, param.temperature); + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(param.temperature)); } else { Softmax(ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, param.temperature); + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(param.temperature)); } }); } @@ -294,11 +296,13 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, if (shape.ndim() == 2) { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, param.temperature); + shape.get<2>(), axis, + static_cast(param.temperature)); } else { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, param.temperature); + shape.get<3>(), axis, + static_cast(param.temperature)); } }); } From ec94e0712aa4971e8dc4c43ffab2bd31250ec451 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 17:12:15 -0700 Subject: [PATCH 11/19] Fix build error in CUDA --- src/operator/nn/softmax-inl.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index cad36e209fd7..04c53c19f6cf 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -145,7 +145,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, template __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis, Shape sshape, Shape stride, - const DType temperature) { + const double temperature) { const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -164,7 +164,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi red::sum::SetInitValue(smem[x]); for (index_t i = x; i < M; i += x_size) { - red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/temperature))); + red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/static_cast(temperature)))); } __syncthreads(); cuda::Reduce1D(smem); @@ -179,7 +179,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi template inline void Softmax(Stream *s, DType *in, DType *out, - Shape shape, int axis, const DType temperature) { + Shape shape, int axis, const double temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; @@ -198,7 +198,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, template __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, index_t M, int axis, Shape sshape, - Shape stride, const DType temperature) { + Shape stride, const double temperature) { 
const unsigned x_size = 1 << x_bits; __shared__ DType smem[x_size]; index_t sa = stride[axis]; @@ -216,7 +216,7 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, __syncthreads(); for (index_t i = x; i < M; i += x_size) { - igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/temperature; + igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/static_cast(temperature); } } @@ -224,7 +224,7 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, template inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, DType *igrad, Shape shape, int axis, - const DType temperature) { + const double temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; index_t M = shape[axis]; From 631a58a06a103d2f7355d574fbe65b0da702b0d5 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Jun 2018 21:47:08 -0700 Subject: [PATCH 12/19] Fix lint error --- src/operator/nn/softmax-inl.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 04c53c19f6cf..9ce78b317b1c 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -164,7 +164,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi red::sum::SetInitValue(smem[x]); for (index_t i = x; i < M; i += x_size) { - red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/static_cast(temperature)))); + red::sum::Reduce(smem[x], static_cast(expf((in[base + i*sa] - smax)/ + static_cast(temperature)))); } __syncthreads(); cuda::Reduce1D(smem); @@ -216,7 +217,8 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, __syncthreads(); for (index_t i = x; i < M; i += x_size) { - igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/static_cast(temperature); + igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/ + static_cast(temperature); } } From a072d662708b1d91e9c6596723e802283ad96445 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 2 Jul 2018 11:30:39 -0700 Subject: [PATCH 13/19] Fix build error in GPU --- src/operator/nn/softmax-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 9ce78b317b1c..462a1d0d1c4d 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -174,7 +174,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi __syncthreads(); for (index_t i = x; i < M; i += x_size) { - out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/temperature, ssum); + out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/static_cast(temperature), ssum); } } From 44c8b0a1761ea156cc2cc505c2a6db051cf5ae32 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 2 Jul 2018 11:33:57 -0700 Subject: [PATCH 14/19] Remove redundant line in description --- src/operator/nn/softmax.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 3fbd20d42a03..e855608e7f28 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -79,7 +79,6 @@ The resulting array contains elements in the range (0,1) and the elements along .. 
math:: softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}} - softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}} for :math:`j = 1, ..., K` From 5c7f8d59f01ebc6fd0cfb454bdd3d69e202abb3f Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 3 Jul 2018 15:58:20 -0700 Subject: [PATCH 15/19] Add check_symbolic_backward in unittest --- tests/python/unittest/test_operator.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ecda2348b6a5..b85217d05aea 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4293,13 +4293,14 @@ def test_new_softmax(): @with_seed() def test_softmax_with_temperature(): for ndim in range(1, 5): + shape = np.random.randint(1, 5, size=ndim) + data = np.random.uniform(-2, 2, size=shape) for temp in range(1, 11): - shape = np.random.randint(1, 5, size=ndim) - axis = np.random.randint(-ndim, ndim) - data = np.random.uniform(-2, 2, size=shape) - sym = mx.sym.softmax(axis=axis, temperature=temp) - expected = np_softmax(data, axis=axis, temperature=temp) - check_symbolic_forward(sym, [data], [expected]) + sym = mx.sym.softmax(axis=0, temperature=temp) + expected_fwd = np_softmax(data, axis=0, temperature=temp) + expected_bwd = np.zeros(shape) + check_symbolic_forward(sym, [data], [expected_fwd], rtol=0.05, atol=1e-3) + check_symbolic_backward(sym, [data], [np.ones(shape)], [expected_bwd], rtol=0.05, atol=1e-3) check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3) @with_seed() From affa46d7d45efee9879f5cbec0848d6fb1438a34 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 12 Jul 2018 11:46:57 -0700 Subject: [PATCH 16/19] Make temperature argument optional for backward compatibility:w --- src/operator/nn/mkldnn/mkldnn_softmax.cc | 4 +++- src/operator/nn/softmax-inl.h | 16 ++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc index 2943c84670e6..7268ed39339e 100644 --- a/src/operator/nn/mkldnn/mkldnn_softmax.cc +++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc @@ -33,7 +33,9 @@ namespace mxnet { namespace op { bool SupportMKLDNNSoftmax(const SoftmaxParam ¶m) { - if (param.temperature != 1.0) { + // MKLDNN does not support temperature argument in their softmax function + // now. Need update this once they start to support it. + if (param.temperature.has_value()) { return false; } return true; diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 462a1d0d1c4d..6d8cf9bb0b1e 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -247,11 +247,11 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, struct SoftmaxParam : public dmlc::Parameter { int axis; - double temperature; + dmlc::optional temperature; DMLC_DECLARE_PARAMETER(SoftmaxParam) { DMLC_DECLARE_FIELD(axis).set_default(-1) .describe("The axis along which to compute softmax."); - DMLC_DECLARE_FIELD(temperature).set_default(1.0) + DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional()) .describe("Temperature parameter in softmax"); } }; @@ -267,16 +267,18 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, CHECK_NE(req[0], kAddTo); const SoftmaxParam& param = nnvm::get(attrs.parsed); int axis = CheckAxis(param.axis, inputs[0].ndim()); + const double temperature = param.temperature.has_value() ? 
+ param.temperature.value() : 1.0; TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (shape.ndim() == 2) { Softmax(ctx.get_stream(), inputs[0].dptr(), outputs[0].dptr(), shape.get<2>(), axis, - static_cast(param.temperature)); + static_cast(temperature)); } else { Softmax(ctx.get_stream(), inputs[0].dptr(), outputs[0].dptr(), shape.get<3>(), axis, - static_cast(param.temperature)); + static_cast(temperature)); } }); } @@ -293,18 +295,20 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, CHECK_NE(req[0], kAddTo); const SoftmaxParam& param = nnvm::get(attrs.parsed); int axis = CheckAxis(param.axis, inputs[0].ndim()); + const double temperature = param.temperature.has_value() ? + param.temperature.value() : 1.0; TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { if (shape.ndim() == 2) { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), shape.get<2>(), axis, - static_cast(param.temperature)); + static_cast(temperature)); } else { SoftmaxGrad(ctx.get_stream(), inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr(), shape.get<3>(), axis, - static_cast(param.temperature)); + static_cast(temperature)); } }); } From 3a40247129f1dfcc39531e8c31df72f1861b6e96 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 12 Jul 2018 14:06:03 -0700 Subject: [PATCH 17/19] Fix build error in cpp package due to dmlc::optional --- cpp-package/scripts/OpWrapperGenerator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 8facde168408..1b5f8b56b924 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -95,6 +95,7 @@ class Arg: 'int or None':'dmlc::optional',\ 'long':'int64_t',\ 'double':'double',\ + 'double or None':'dmlc::optional',\ 'Shape or None':'dmlc::optional',\ 'string':'const std::string&'} name = '' From f739a58d00e72842d1221bd6a117f36ca825efc4 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 19 Jul 2018 09:03:57 -0700 Subject: [PATCH 18/19] Add a comment in code --- src/operator/nn/softmax-inl.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 6d8cf9bb0b1e..6d2937b75773 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -71,6 +71,9 @@ inline void Softmax(Stream *s, DType *in, DType *out, } DType sum = DType(0); + // By default temperature is 1.0, and only in reinforcement training + // users would set it to other values. + // Adding a branch here to save the CPU 'divide-by-1' computation at runtime if (temperature == 1.0) { for (index_t j = 0; j < M; ++j) { sum += std::exp(in[base + j*sa] - mmax); @@ -128,6 +131,9 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); } + // By default temperature is 1.0, and only in reinforcement training + // users would set it to other values. 
+ // Adding a branch here to save the CPU 'divide-by-1' computation at runtime if (temperature == 1.0) { for (index_t j = 0; j < M; ++j) { igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum); From d20bfde73ec8fa71603dfee63c3cbf15f6786183 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 19 Jul 2018 09:11:42 -0700 Subject: [PATCH 19/19] Fix lint error --- src/operator/nn/softmax-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 6d2937b75773..64b436e7ea0f 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -72,7 +72,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, DType sum = DType(0); // By default temperature is 1.0, and only in reinforcement training - // users would set it to other values. + // users would set it to other values. // Adding a branch here to save the CPU 'divide-by-1' computation at runtime if (temperature == 1.0) { for (index_t j = 0; j < M; ++j) { @@ -132,7 +132,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, } // By default temperature is 1.0, and only in reinforcement training - // users would set it to other values. + // users would set it to other values. // Adding a branch here to save the CPU 'divide-by-1' computation at runtime if (temperature == 1.0) { for (index_t j = 0; j < M; ++j) {
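Taken together, the series computes a temperature-scaled softmax, y_j = exp((x_j - max(x))/t) / sum_k exp((x_k - max(x))/t), and on the backward pass divides the usual softmax Jacobian-vector product by t. Below is a minimal NumPy sketch of that math for reference only; it is not part of the patch. It mirrors np_softmax from the unit test and the softmax_bwd/OP2::Map gradient path; the helper names and the small self-check at the end are assumptions added for illustration::

    import numpy as np

    def softmax_with_temperature(x, axis=-1, temperature=1.0):
        # Same math as np_softmax in tests/python/unittest/test_operator.py:
        # subtract the max for numerical stability, then scale by 1/temperature.
        z = (x - np.max(x, axis=axis, keepdims=True)) / temperature
        e = np.exp(z)
        return e / np.sum(e, axis=axis, keepdims=True)

    def softmax_grad_with_temperature(out, ograd, axis=-1, temperature=1.0):
        # Mirrors SoftmaxGrad: igrad = out * (ograd - sum(ograd * out)) / temperature,
        # i.e. the softmax Jacobian-vector product with the extra 1/t from x/t.
        s = np.sum(ograd * out, axis=axis, keepdims=True)
        return out * (ograd - s) / temperature

    # Illustrative sanity check (assumed, not from the patch): outputs along the
    # axis sum to 1, and the gradient of sum(softmax(x)) is ~0, which is why the
    # unit test in PATCH 15 uses np.zeros(shape) as expected_bwd.
    x = np.random.uniform(-2, 2, size=(3, 4))
    y = softmax_with_temperature(x, axis=-1, temperature=5.0)
    g = softmax_grad_with_temperature(y, np.ones_like(y), axis=-1, temperature=5.0)
    assert np.allclose(y.sum(axis=-1), 1.0)
    assert np.allclose(g, 0.0, atol=1e-12)

Through the Python API the same behaviour is exercised as in the new unit test, e.g. mx.sym.softmax(axis=axis, temperature=temp) checked against np_softmax(data, axis=axis, temperature=temp). When no temperature argument is given the operator keeps its previous behaviour (the parameter is an optional with an implied value of 1.0), and with MKLDNN the operator falls back to the plain CPU path whenever a temperature is specified.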