This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: add threshold for mish
Signed-off-by: Adnios <[email protected]>
Adnios committed Jun 10, 2021
1 parent 7740cca commit fe63505
Showing 3 changed files with 29 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -54,9 +54,9 @@ template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
 backward_mish(const DTypeGrad grad, const DType val) {
   const mixed_type<DTypeGrad, DType> v = val;
-  const auto softrelu = op::log(1 + exp(v));
-  const auto tanh = op::tanh(softrelu);
-  return grad * (tanh + v * sigmoid(v) * (1 - tanh * tanh));
+  const auto softrelu = op::softrelu(v);
+  const auto tanh_sr = op::tanh(softrelu);
+  return grad * (tanh_sr + v * sigmoid(v) * (1 - tanh_sr * tanh_sr));
 }
 template <typename DType, typename DTypeGrad>
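For reference, the expression kept above is the closed-form derivative of mish(x) = x * tanh(softrelu(x)): since d softrelu(x)/dx = sigmoid(x), the chain rule gives tanh(sr) + x * sigmoid(x) * (1 - tanh(sr)^2). A minimal host-side C++ sketch, not part of this commit and using illustrative names, that checks the closed form against a central finite difference:

```cpp
#include <cmath>
#include <cstdio>

// mish(x) = x * tanh(softrelu(x)), with softrelu(x) = log(1 + exp(x)).
double mish_ref(double x) { return x * std::tanh(std::log1p(std::exp(x))); }

// Closed-form gradient, mirroring backward_mish above (without the incoming grad factor).
double mish_grad_ref(double x) {
  const double sr  = std::log1p(std::exp(x));
  const double t   = std::tanh(sr);
  const double sig = 1.0 / (1.0 + std::exp(-x));  // d softrelu(x) / dx
  return t + x * sig * (1.0 - t * t);
}

int main() {
  const double h = 1e-6;
  for (double x : {-3.0, -0.5, 0.0, 2.0, 10.0}) {
    const double fd = (mish_ref(x + h) - mish_ref(x - h)) / (2.0 * h);  // central difference
    std::printf("x=%5.1f  closed=%.8f  finite-diff=%.8f\n", x, mish_grad_ref(x), fd);
  }
  return 0;
}
```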
4 changes: 2 additions & 2 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -697,9 +697,9 @@ __device__ inline DType log_sigmoid(const DType val) {
 template <typename DType>
 __device__ inline DType mish(const DType val) {
   if (type_util::has_double_or_integral<DType>::value) {
-    return val * ::tanh(::log(1 + ::exp(val)));
+    return val * ::tanh(op::softrelu(val));
   } else {
-    return val * ::tanhf(logf(1 + expf(val)));
+    return val * ::tanhf(op::softrelu(val));
   }
 }
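Numerically, the move away from logf(1 + expf(val)) matters because expf overflows to inf once val exceeds roughly 88.7, even though softrelu(val) ≈ val in that range. The sketch below is a standalone illustration of that behaviour, assuming a clamp at 20 like the mshadow_op.h change further down; softrelu_thresh is an illustrative stand-in, not MXNet's op::softrelu.

```cpp
#include <cmath>
#include <cstdio>

// Naive softrelu: std::exp(a) overflows to inf in float once a exceeds ~88.7.
float softrelu_naive(float a) { return std::log(1.0f + std::exp(a)); }

// Thresholded softrelu: for a > 20, log1p(exp(a)) and a agree to within float precision.
float softrelu_thresh(float a) { return a > 20.0f ? a : std::log1p(std::exp(a)); }

int main() {
  for (float a : {10.0f, 20.0f, 50.0f, 100.0f}) {
    std::printf("a = %6.1f  naive = %f  thresholded = %f\n",
                a, softrelu_naive(a), softrelu_thresh(a));
  }
  // At a = 100 the naive form prints inf (exp overflow); the thresholded form returns a.
  return 0;
}
```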
28 changes: 24 additions & 4 deletions src/operator/mshadow_op.h
@@ -415,11 +415,31 @@ MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

 MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

-MXNET_UNARY_MATH_OP(mish, a * math::tanh(math::log(1.0f + math::exp(a))));
+struct mish : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // softrelu(a) = log1p(exp(a)); for a > 20 fall back to softrelu(a) ≈ a,
+    // so the result does not depend on an overflowed exp(a).
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    return DType(a * math::tanh(softrelu));
+  }
+};

-MXNET_UNARY_MATH_OP(mish_grad, math::tanh(math::log(1.0f + math::exp(a))) +
-                    a * (1.0f / (1.0f + math::exp(-a))) *
-                    (1.0f - math::sqr(math::tanh(math::log(1.0f + math::exp(a))))));
+struct mish_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    // Note: the input (a) is x, not y = mish(x). Apply the same threshold as in mish above.
+    auto softrelu = math::log1p(math::exp(a));
+    if (a > DType(20.0f)) {
+      softrelu = a;
+    }
+    auto tanh_sr = math::tanh(softrelu);
+    auto sr_grad = 1.0f / (1.0f + math::exp(-a));
+    return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
+  }
+};

 MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

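On the choice of 20 as the threshold: the error of replacing softrelu(a) by a is softrelu(a) - a = log1p(exp(-a)), roughly 2e-9 at a = 20, which is far below the single-precision spacing near 20 (about 2e-6), so the clamp is exact in float. A small double-precision check, illustrative only:

```cpp
#include <cmath>
#include <cstdio>

// Error introduced by the clamp: softrelu(a) - a = log1p(exp(-a)), computed overflow-free.
int main() {
  for (double a : {10.0, 15.0, 20.0, 25.0}) {
    const double gap = std::log1p(std::exp(-a));
    std::printf("a = %4.0f  softrelu(a) - a = %.3e\n", a, gap);
  }
  return 0;
}
```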
