[MXNET-560] Add temperature parameter in Softmax operator #11466
Changes from 9 commits
@@ -53,7 +53,7 @@ struct log_softmax_fwd {
 
 template<typename OP, typename DType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
-                    Shape<ndim> shape, int axis) {
+                    Shape<ndim> shape, int axis, const float temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
@@ -71,12 +71,22 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
     }
 
     DType sum = DType(0);
-    for (index_t j = 0; j < M; ++j) {
-      sum += std::exp(in[base + j*sa] - mmax);
-    }
-
-    for (index_t j = 0; j < M; ++j) {
-      out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        sum += std::exp(in[base + j*sa] - mmax);
+      }
+
+      for (index_t j = 0; j < M; ++j) {
+        out[base + j*sa] = OP::Map(in[base + j*sa] - mmax, sum);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        sum += std::exp((in[base + j*sa] - mmax)/temperature);
+      }
+
+      for (index_t j = 0; j < M; ++j) {
+        out[base + j*sa] = OP::Map((in[base + j*sa] - mmax)/temperature, sum);
+      }
     }
   }
 }

Review discussion on the `if (temperature == 1.0)` branch:

Reviewer: Could you add a comment on why you are branching for 1.0, and also note that this is not useful for the GPU?

Author: By default the value of temperature is 1.0; users will only use other values in reinforcement-learning cases. On CPU, the compiler cannot optimize the divide-by-1.0 away at runtime, so I added a branch here. The performance difference was measured using an example shown in the description of this PR. The branch is not added to the GPU kernel because branching would add extra overhead on the GPU.

Reviewer: I meant in the code :)

Author: Thanks for the suggestion; I have added two comments in the code to make this clear for other developers in the future.
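The benchmark itself lives in the PR description and is not reproduced here. Purely as an illustration, a rough way to compare the two CPU code paths from Python could look like the sketch below; the array shape, repeat count, and use of the imperative API are all assumptions, and the `temperature` argument assumes this PR's change:

```python
import time
import mxnet as mx

data = mx.nd.random.uniform(shape=(128, 1000))  # arbitrary shape, CPU context by default

def time_softmax(temperature, repeat=1000):
    # Warm-up run so one-time initialization does not skew the measurement.
    mx.nd.softmax(data, axis=-1, temperature=temperature)
    mx.nd.waitall()
    start = time.time()
    for _ in range(repeat):
        mx.nd.softmax(data, axis=-1, temperature=temperature)
    mx.nd.waitall()  # operators are asynchronous; wait before reading the clock
    return time.time() - start

print("temperature=1.0:", time_softmax(1.0))  # takes the divide-free branch on CPU
print("temperature=2.0:", time_softmax(2.0))  # takes the branch that divides by temperature
```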
@@ -100,7 +110,8 @@ struct log_softmax_bwd {
 
 template<typename OP1, typename OP2, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
-                        DType *igrad, Shape<ndim> shape, int axis) {
+                        DType *igrad, Shape<ndim> shape, int axis,
+                        const float temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
@@ -117,8 +128,14 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
       sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
     }
 
-    for (index_t j = 0; j < M; ++j) {
-      igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        igrad[base + j*sa] = OP2::Map(ograd[base + j*sa], out[base + j*sa], sum)/temperature;
+      }
     }
   }
 }
@@ -127,7 +144,8 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 #ifdef __CUDACC__
 template<int x_bits, typename OP, typename DType, int ndim>
 __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
-                                       Shape<ndim> sshape, Shape<ndim> stride) {
+                                       Shape<ndim> sshape, Shape<ndim> stride,
+                                       const float temperature) {
   const unsigned x_size = 1 << x_bits;
   __shared__ DType smem[x_size];
   index_t sa = stride[axis];
@@ -146,7 +164,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
 
   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], static_cast<DType>(expf(in[base + i*sa] - smax)));
+    red::sum::Reduce(smem[x], static_cast<DType>(expf((in[base + i*sa] - smax)/temperature)));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
@@ -155,13 +173,13 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    out[base + i*sa] = OP::Map(in[base + i*sa] - smax, ssum);
+    out[base + i*sa] = OP::Map((in[base + i*sa] - smax)/temperature, ssum);
   }
 }
 
 template<typename OP, typename DType, int ndim>
 inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
-                    Shape<ndim> shape, int axis) {
+                    Shape<ndim> shape, int axis, const float temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
@@ -172,15 +190,15 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
 
   softmax_compute_kernel<x_bits, OP, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-    in, out, M, axis, sshape, stride);
+    in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
 
 template<int x_bits, typename OP1, typename OP2, typename DType, int ndim>
 __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
-                                        Shape<ndim> stride) {
+                                        Shape<ndim> stride, const float temperature) {
   const unsigned x_size = 1 << x_bits;
   __shared__ DType smem[x_size];
   index_t sa = stride[axis];
@@ -198,14 +216,15 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
-    igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    igrad[base + i*sa] = OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum)/temperature;
   }
 }
 
 
 template<typename OP1, typename OP2, typename DType, int ndim>
 inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
-                        DType *igrad, Shape<ndim> shape, int axis) {
+                        DType *igrad, Shape<ndim> shape, int axis,
+                        const float temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
@@ -216,7 +235,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 
   softmax_gradient_kernel<x_bits, OP1, OP2, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-    out, ograd, igrad, M, axis, sshape, stride);
+    out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
 }
 #endif
@@ -226,9 +245,12 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 
 struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
+  float temperature;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
     .describe("The axis along which to compute softmax.");
+    DMLC_DECLARE_FIELD(temperature).set_default(1.0f)
+    .describe("Temperature parameter in softmax");
   }
 };
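Once the parameter is registered it surfaces directly in the front-end APIs, with the default of 1.0 leaving existing behaviour unchanged. A minimal usage sketch, assuming this PR's `temperature` argument (the input values are arbitrary):

```python
import mxnet as mx

x = mx.nd.array([[1.0, 2.0, 3.0]])
print(mx.nd.softmax(x, axis=-1))                   # default temperature=1.0
print(mx.nd.softmax(x, axis=-1, temperature=5.0))  # higher temperature, flatter distribution
```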
@@ -247,10 +269,10 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     if (shape.ndim() == 2) {
       Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<2>(), axis);
+                  outputs[0].dptr<DType>(), shape.get<2>(), axis, param.temperature);
     } else {
       Softmax<OP>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<DType>(), shape.get<3>(), axis);
+                  outputs[0].dptr<DType>(), shape.get<3>(), axis, param.temperature);
     }
   });
 }
@@ -272,11 +294,11 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
     if (shape.ndim() == 2) {
       SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
                             inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                            shape.get<2>(), axis);
+                            shape.get<2>(), axis, param.temperature);
     } else {
       SoftmaxGrad<OP1, OP2>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
                             inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                            shape.get<3>(), axis);
+                            shape.get<3>(), axis, param.temperature);
     }
   });
 }
@@ -77,10 +77,13 @@ MXNET_OPERATOR_REGISTER_UNARY(softmax)
 The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
 
 .. math::
+   softmax(\mathbf{z/t})_j = \frac{e^{z_j/t}}{\sum_{k=1}^K e^{z_k/t}}
    softmax(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}}
 
 for :math:`j = 1, ..., K`
 
+t is the temperature parameter in softmax function. By default, t equals 1.0
+
 Example::
 
   x = [[ 1.  1.  1.]

Review discussion on the docstring:

Reviewer: Why not remove line 82 (the old formula without the temperature term)?

Author: Removed.
@@ -258,11 +258,11 @@ def test_rnnrelu_dropout():
     out = exe.forward(is_train=True)
     out[0].wait_to_read()
 
-def np_softmax(x, axis=-1):
+def np_softmax(x, axis=-1, temperature=1.0):
     # fix for old numpy on Travis not supporting keepdims
     # x = x - np.max(x, axis=-1, keepdims=True)
     x = x - np.max(x, axis=axis, keepdims=True)
-    x = np.exp(x)
+    x = np.exp(x/temperature)
     # x /= np.sum(x, axis=-1, keepdims=True)
     x /= np.sum(x, axis=axis, keepdims=True)
     return x
@@ -4226,6 +4226,17 @@ def test_new_softmax():
     check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
     check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
 
+@with_seed()
+def test_softmax_with_temperature():
+    for ndim in range(1, 5):
+        for temp in range(1, 11):
+            shape = np.random.randint(1, 5, size=ndim)
+            axis = np.random.randint(-ndim, ndim)
+            data = np.random.uniform(-2, 2, size=shape)
+            sym = mx.sym.softmax(axis=axis, temperature=temp)
+            expected = np_softmax(data, axis=axis, temperature=temp)
+            check_symbolic_forward(sym, [data], [expected])
+            check_numeric_gradient(sym, [data], rtol=0.05, atol=1e-3)
 
 @with_seed()
 def test_log_softmax():

Review discussion on test_softmax_with_temperature:

Reviewer: Any reason why check_symbolic_backward is not tested?

Author: Added a check_symbolic_backward test.
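The check_symbolic_backward test mentioned above was added in a later commit and is not part of this 9-commit diff. A sketch of what such a test could look like, fixing axis=-1 and deriving the expected gradient analytically (an illustration under those assumptions, not the committed code):

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import check_symbolic_backward

def np_softmax_grad(out, ograd, temperature=1.0):
    # Analytical gradient of temperature softmax along the last axis:
    # d/dz_j = out_j * (ograd_j - sum_k ograd_k * out_k) / temperature
    s = np.sum(ograd * out, axis=-1, keepdims=True)
    return out * (ograd - s) / temperature

data = np.random.uniform(-2, 2, size=(3, 4))
ograd = np.random.uniform(-1, 1, size=(3, 4))
for temp in [1.0, 2.0, 5.0]:
    # Forward output computed in NumPy (same formula as np_softmax in the test file).
    z = (data - data.max(axis=-1, keepdims=True)) / temp
    out = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    sym = mx.sym.softmax(axis=-1, temperature=temp)
    check_symbolic_backward(sym, [data], [ograd],
                            [np_softmax_grad(out, ograd, temperature=temp)],
                            rtol=1e-3, atol=1e-5)
```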
General review discussion on the temperature data type:

Reviewer: Should we support double precision and half precision too?

Reviewer: This is a parameter, so it should be of a specific type. I think that to ensure the highest precision we should take this parameter in as a double and cast it to DType at runtime.

Author: Thanks for the suggestion. I have changed the data type from float to the generic DType.
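On the precision question: the kernels are templated on DType and dispatched through MSHADOW_REAL_TYPE_SWITCH, so the input array's dtype already selects the compute precision; only the C++ type of the temperature argument itself is at issue. A rough Python-level sanity check across input precisions might look like the sketch below (the tolerance values are guesses):

```python
import numpy as np
import mxnet as mx

x = np.random.uniform(-2, 2, size=(2, 5))

# Same temperature, different input dtypes; results should agree to roughly float32 accuracy.
out64 = mx.nd.softmax(mx.nd.array(x, dtype='float64'), temperature=3.0).asnumpy()
out32 = mx.nd.softmax(mx.nd.array(x, dtype='float32'), temperature=3.0).asnumpy()
assert np.allclose(out32, out64, rtol=1e-5, atol=1e-6)
```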