diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ec32adf90000..e44c9677e00b 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -172,5 +172,6 @@ List of Contributors
 * [Thomas Delteil](https://github.com/ThomasDelteil)
 * [Jesse Brizzi](https://github.com/jessebrizzi)
 * [Hang Zhang](http://hangzh.com)
+* [Lin Yuan](https://github.com/apeforest)
 * [Kou Ding](https://github.com/chinakook)
 * [Istvan Fehervari](https://github.com/ifeherva)
diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py
index 8facde168408..1b5f8b56b924 100644
--- a/cpp-package/scripts/OpWrapperGenerator.py
+++ b/cpp-package/scripts/OpWrapperGenerator.py
@@ -95,6 +95,7 @@ class Arg:
                 'int or None':'dmlc::optional<int>',\
                 'long':'int64_t',\
                 'double':'double',\
+                'double or None':'dmlc::optional<double>',\
                 'Shape or None':'dmlc::optional<Shape>',\
                 'string':'const std::string&'}
     name = ''
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index ef58c519aa9c..0e7b63e58fb3 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -409,7 +409,7 @@ class CTCLossOp : public Operator {
 
     // since the input is activation before softmax and cudnn ctc takes softmax
     // apply softmax to inputs first.
-    mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2);
+    mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2, 1.0);
 
     CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
                             prob_desc_,
@@ -427,7 +427,7 @@ class CTCLossOp : public Operator {
 
       if (req_grad) {
         mxnet_op::SoftmaxGrad(s,
-            prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2);
+            prob.dptr_, grad.dptr_, grad.dptr_, data.shape_, 2, 1.0);
         Assign(grad, mxnet::kWriteInplace, grad * alphabet_size);
       }
     }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index f77d113dd1d7..bbfb873ee86a 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -146,9 +146,11 @@ namespace op {
 struct ActivationParam;
 struct ConvolutionParam;
 struct DeconvolutionParam;
+struct SoftmaxParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
+bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
 }
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index acfa358a796e..7268ed39339e 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -32,6 +32,15 @@ namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
+  // MKLDNN does not support the temperature argument in its softmax
+  // function yet. Update this once MKLDNN adds support for it.
+  if (param.temperature.has_value()) {
+    return false;
+  }
+  return true;
+}
+
 void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const NDArray &in_data, const OpReqType &req,
                           const NDArray &out_data) {
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 080bc08852c3..64b436e7ea0f 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -53,7 +53,7 @@ struct log_softmax_fwd {
 
 template<typename OP, typename DType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
-                    Shape<ndim> shape, int axis) {
+                    Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
@@ -71,12 +71,25 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
     }
 
     DType sum = DType(0);
-    for (index_t j = 0; j < M; ++j) {
-      sum += std::exp(in[base + j*sa] - mmax);
-    }
+    // The temperature is 1.0 by default, and users typically set other values
+    // only in reinforcement learning.
+    // Branch here to save the CPU 'divide-by-1' computation at runtime.
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        sum += std::exp(in[base + j*sa] - mmax);
+      }
+
+      for (index_t j = 0; j < M; ++j) {
+        out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        sum += std::exp((in[base + j*sa] - mmax)/temperature);
+      }
 
-    for (index_t j = 0; j < M; ++j) {
-      out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+      for (index_t j = 0; j < M; ++j) {
+        out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum);
+      }
     }
   }
 }
@@ -100,7 +113,8 @@ struct log_softmax_bwd {
 
 template<typename OP1, typename OP2, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
-                        DType *igrad, Shape<ndim> shape, int axis) {
+                        DType *igrad, Shape<ndim> shape, int axis,
+                        const DType temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
@@ -117,8 +131,17 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
       sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
     }
 
-    for (index_t j = 0; j < M; ++j) {
-      igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+    // The temperature is 1.0 by default, and users typically set other values
+    // only in reinforcement learning.
+    // Branch here to save the CPU 'divide-by-1' computation at runtime.
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature;
+      }
     }
   }
 }
@@ -127,7 +150,8 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 #ifdef __CUDACC__
 template<int x_bits, typename OP, typename DType, int ndim>
 __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
-                                       Shape<ndim> sshape, Shape<ndim> stride) {
+                                       Shape<ndim> sshape, Shape<ndim> stride,
+                                       const double temperature) {
   const unsigned x_size = 1 << x_bits;
   __shared__ DType smem[x_size];
   index_t sa = stride[axis];
@@ -146,7 +170,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
 
   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], static_cast<DType>(expf(in[base + i*sa] - smax)));
+    red::sum::Reduce(smem[x], static_cast<DType>(expf((in[base + i*sa] - smax)/
+                     static_cast<DType>(temperature))));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
@@ -155,13 +180,13 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    out[base + i*sa] = OP::Map(in[base + i*sa] - smax, ssum);
+    out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/static_cast<DType>(temperature), ssum);
   }
 }
 
 template<typename OP, typename DType, int ndim>
 inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
-                    Shape<ndim> shape, int axis) {
+                    Shape<ndim> shape, int axis, const double temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
@@ -172,7 +197,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
 
   softmax_compute_kernel<x_bits, OP, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-      in, out, M, axis, sshape, stride);
+      in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
@@ -180,7 +205,7 @@
 template<int x_bits, typename OP1, typename OP2, typename DType, int ndim>
 __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
-                                        Shape<ndim> stride) {
+                                        Shape<ndim> stride, const double temperature) {
   const unsigned x_size = 1 << x_bits;
   __shared__ DType smem[x_size];
   index_t sa = stride[axis];
@@ -198,14 +223,16 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/
+                         static_cast<DType>(temperature);
   }
 }
 
 
 template<typename OP1, typename OP2, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
-                        DType *igrad, Shape<ndim> shape, int axis) {
+                        DType *igrad, Shape<ndim> shape, int axis,
+                        const double temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
@@ -216,7 +243,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 
   softmax_gradient_kernel<x_bits, OP1, OP2, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-      out, ograd, igrad, M, axis, sshape, stride);
+      out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
 }
 #endif
@@ -226,9 +253,12 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 
 struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
+  dmlc::optional<double> temperature;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
     .describe("The axis along which to compute softmax.");
+    DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
+    .describe("Temperature parameter in softmax");
   }
 };
 
@@ -243,14 +273,18 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[0], kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+    param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (shape.ndim() == 2) {
       Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<2>(), axis);
+                  outputs[0].dptr<DType>(), shape.get<2>(), axis,
+                  static_cast<DType>(temperature));
     } else {
       Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<3>(), axis);
+                  outputs[0].dptr<DType>(), shape.get<3>(), axis,
+                  static_cast<DType>(temperature));
     }
   });
 }
@@ -267,16 +301,20 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[0], kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
+  const double temperature = param.temperature.has_value() ?
+    param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (shape.ndim() == 2) {
       SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
                             inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                            shape.get<2>(), axis);
+                            shape.get<2>(), axis,
+                            static_cast<DType>(temperature));
     } else {
       SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
                             inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                            shape.get<3>(), axis);
+                            shape.get<3>(), axis,
+                            static_cast<DType>(temperature));
     }
   });
 }
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e9b104f12868..e855608e7f28 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -39,7 +39,8 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
   // It seems MKLDNN softmax doesn't support training.
-  if (SupportMKLDNN(inputs[0]) && !ctx.is_train) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
@@ -77,10 +78,12 @@ MXNET_OPERATOR_REGISTER_UNARY(softmax)
 The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
 
 .. math::
-   softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}}
+   softmax(\mathbf{z}/t)_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}}
 
 for :math:`j = 1, ..., K`
 
+:math:`t` is the temperature parameter in the softmax function. By default, :math:`t` equals 1.0.
+
 Example::
 
   x = [[ 1.  1.  1.]
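For readers who want to sanity-check the formula documented above, the following is a minimal NumPy sketch of the temperature-scaled softmax. It is illustrative only and not part of the patch; the helper name `temperature_softmax` is ours. It mirrors the `np_softmax` reference used in the test changes below.

    import numpy as np

    def temperature_softmax(x, axis=-1, temperature=1.0):
        # Subtract the per-axis max for numerical stability, then scale by the
        # temperature; this is equivalent to softmax(x / temperature).
        z = (x - np.max(x, axis=axis, keepdims=True)) / temperature
        e = np.exp(z)
        return e / np.sum(e, axis=axis, keepdims=True)

    # A larger temperature flattens the distribution, a smaller one sharpens it:
    # temperature_softmax(np.array([1.0, 2.0, 3.0]), temperature=5.0)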
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 44fffdd3b474..1c50d7238e9c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -260,11 +260,11 @@ def test_rnnrelu_dropout():
     out = exe.forward(is_train=True)
     out[0].wait_to_read()
 
-def np_softmax(x, axis=-1):
+def np_softmax(x, axis=-1, temperature=1.0):
     # fix for old numpy on Travis not supporting keepdims
     # x = x - np.max(x, axis=-1, keepdims=True)
     x = x - np.max(x, axis=axis, keepdims=True)
-    x = np.exp(x)
+    x = np.exp(x/temperature)
     # x /= np.sum(x, axis=-1, keepdims=True)
     x /= np.sum(x, axis=axis, keepdims=True)
     return x
@@ -4319,6 +4319,18 @@ def test_new_softmax():
             check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
             check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
 
+@with_seed()
+def test_softmax_with_temperature():
+    for ndim in range(1, 5):
+        shape = np.random.randint(1, 5, size=ndim)
+        data = np.random.uniform(-2, 2, size=shape)
+        for temp in range(1, 11):
+            sym = mx.sym.softmax(axis=0, temperature=temp)
+            expected_fwd = np_softmax(data, axis=0, temperature=temp)
+            expected_bwd = np.zeros(shape)
+            check_symbolic_forward(sym, [data], [expected_fwd], rtol=0.05, atol=1e-3)
+            check_symbolic_backward(sym, [data], [np.ones(shape)], [expected_bwd], rtol=0.05, atol=1e-3)
+            check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
 
 @with_seed()
 def test_log_softmax():
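As a cross-check on the gradient scaling above (the division by `temperature` in SoftmaxGrad) and on why the new test compares the backward pass against `np.zeros(shape)`, here is a hedged NumPy sketch of the backward rule. The helper name is ours and nothing below is part of the patch.

    import numpy as np

    def temperature_softmax_grad(y, g, axis=-1, temperature=1.0):
        # For y = softmax(z / t) and upstream gradient g, the input gradient is
        # dL/dz = y * (g - sum(g * y)) / t along the softmax axis.
        inner = np.sum(g * y, axis=axis, keepdims=True)
        return y * (g - inner) / temperature

    # With g = np.ones_like(y), sum(g * y) == 1 because y sums to 1 along the
    # axis, so the input gradient is exactly zero. That is why
    # test_softmax_with_temperature checks the backward pass against np.zeros(shape).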