Skip to content

Commit

Permalink
Added Mish activation
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Aug 4, 2020
1 parent 36f650a commit c14d8bd
Show file tree
Hide file tree
Showing 16 changed files with 217 additions and 10 deletions.
1 change: 1 addition & 0 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ enum algorithm {
eltwise_clamp = mkldnn_eltwise_clamp,
eltwise_not = mkldnn_eltwise_not,
eltwise_swish = mkldnn_eltwise_swish,
eltwise_mish = mkldnn_eltwise_mish,
depthwise_scale_shift = mkldnn_depthwise_scale_shift,
depthwise_prelu = mkldnn_depthwise_prelu,
lrn_across_channels = mkldnn_lrn_across_channels,
Expand Down
2 changes: 2 additions & 0 deletions include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,8 @@ typedef enum {
mkldnn_eltwise_not = 0xef,
/** Eltwise: swish */
mkldnn_eltwise_swish = 0xff,
/** Eltwise: mish */
mkldnn_eltwise_mish = 0x1f0,
/** Max pooling */
mkldnn_pooling_max = 0x1ff,
/** Average pooling include padding */
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ namespace alg_kind {
const alg_kind_t eltwise_clamp = mkldnn_eltwise_clamp;
const alg_kind_t eltwise_not = mkldnn_eltwise_not;
const alg_kind_t eltwise_swish = mkldnn_eltwise_swish;
const alg_kind_t eltwise_mish = mkldnn_eltwise_mish;
const alg_kind_t depthwise_scale_shift = mkldnn_depthwise_scale_shift;
const alg_kind_t depthwise_prelu = mkldnn_depthwise_prelu;
const alg_kind_t pooling_max = mkldnn_pooling_max;
Expand Down
3 changes: 2 additions & 1 deletion src/common/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
&& one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish)
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish)
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;

Expand Down
7 changes: 7 additions & 0 deletions src/common/math_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ inline U swish_bwd(T dd, T s, A alpha) {
return dd * (v + s * alpha * v * (1 - v));
}

/* Forward Mish activation: mish(s) = s * tanh(ln(1 + exp(s))).
 * The inner soft-plus term ln(1 + exp(s)) is evaluated in single precision
 * with the C math library; the product is then converted to U, which is T
 * with any reference qualifier stripped. */
template <typename T,
        typename U = typename utils::remove_reference<T>::type>
inline U mish_fwd(T s) {
    const float soft_plus = ::log1pf(::expf(static_cast<float>(s)));
    return static_cast<U>(s * ::tanhf(soft_plus));
}

template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U scale_shift_fwd(T s_val, A w_val, A b_val) {
Expand Down
1 change: 1 addition & 0 deletions src/common/mkldnn_debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
if (v == mkldnn_eltwise_clamp) return "eltwise_clamp";
if (v == mkldnn_eltwise_not) return "eltwise_not";
if (v == mkldnn_eltwise_swish) return "eltwise_swish";
if (v == mkldnn_eltwise_mish) return "eltwise_mish";
if (v == mkldnn_pooling_max) return "pooling_max";
if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
Expand Down
3 changes: 2 additions & 1 deletion src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish);
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish);
if (!known_alg)
return invalid_arguments;

Expand Down
177 changes: 175 additions & 2 deletions src/cpu/jit_uni_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,117 @@ void jit_uni_eltwise_injector_f32<isa>::swish_compute_vector(
h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

// Emits JIT code that replaces the contents of vmm_src with
//     mish(x) = x * tanh(ln(1 + exp(x)))
// computed lane-wise.
//
// Outline of the emitted sequence:
//   1. spill the original x to the stack (vlen bytes below rsp);
//   2. evaluate soft_relu(x) = ln(1 + exp(x)) with a range-reduced exp
//      polynomial followed by a frexp-style ln polynomial;
//   3. for lanes where x exceeds the upper clamp (table_val(25)) use x itself
//      instead of the polynomial result;
//   4. call tanh_compute_vector on that value;
//   5. reload x from the stack and multiply.
//
// Register roles inside this function (as written below):
//   vmm_aux0 - floor(fx), then fx * ln2 (kept until the final add)
//   vmm_aux1 - reduced argument, then 2^-fx, then ln-polynomial accumulator
//   vmm_aux2 - untouched copy of the input x (used for the final select)
//   vmm_aux3 - exp-polynomial accumulator, then mantissa of (exp + 2^-fx)
//   vmm_mask / k_mask - comparison result for the "x > max logf" select
// The table_val(k) indices are the [k] entries annotated in
// mish_prepare_table. Registers clobbered by tanh_compute_vector are not
// visible here.
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::mish_compute_vector(
const Vmm &vmm_src) {
// Save src data on stack for later usage (reloaded after the tanh step;
// presumably because no aux register is guaranteed to survive
// tanh_compute_vector - confirm against its implementation)
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);

// soft_relu - ln(1+exp(x))
// duplicate src (vmm_aux2 keeps the unclamped x for the final select)
h->uni_vmovups(vmm_aux2, vmm_src);

// clamp x to [table_val(26), table_val(25)] before evaluating exp
h->uni_vminps(vmm_src, vmm_src, table_val(25));
h->uni_vmaxps(vmm_src, vmm_src, table_val(26));
h->uni_vmovups(vmm_aux1, vmm_src);
// calculate exp(x)
// fx = x * log2ef + 0.5
h->uni_vmulps(vmm_src, vmm_src, table_val(2));
h->uni_vaddps(vmm_src, vmm_src, table_val(1));

// tmp = floorf(fx)
h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);

// keep fx for further computations
h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
// calculation fx * ln2
h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
// x = x - fx * ln2
h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
// Horner evaluation of the exp polynomial in the reduced x (vmm_aux1)
// y = p5
h->uni_vmovups(vmm_aux3, table_val(9));
// y = y * x + p4
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(8));
// y = y * x + p3
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(7));
// y = y * x + p2
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(6));
// y = y * x + p1 (p1 is the 1.0f entry at table index 0)
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
// y = y * x + p0
h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(5));

// compute 2^(-n): negate fx as an integer, either by multiplying with the
// table_val(27) constant before conversion (avx512_common) or by applying
// its sign after conversion (other ISAs)
if (isa == avx512_common) {
h->vmulps(vmm_aux1, vmm_src, table_val(27));
h->vcvtps2dq(vmm_aux1, vmm_aux1);
} else {
h->uni_vcvtps2dq(vmm_aux1, vmm_src);
h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(27));
}

// add the 0x7f constant and shift into the float exponent field
h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
// calculate ln(1 + y)
// vmm_aux3 = (polynomial value) + 2^-fx; the fx * ln2 term still held in
// vmm_aux0 is added back at the end of the ln computation
h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
// x = y; y is free; keep x for further computations
h->uni_vmovups(vmm_src, vmm_aux3);
// frexp(): extract the biased exponent as a float
h->uni_vpsrld(vmm_src, vmm_src, 23);
h->uni_vcvtdq2ps(vmm_src, vmm_src);
// got n. where n is x = 2^n * y. y = 0.5 .. 1
h->uni_vsubps(vmm_src, vmm_src, table_val(28));

// keep sign and mantissa bits, then force the exponent of 0.5
h->uni_vandps(vmm_aux3, vmm_aux3, table_val(29));
// got y. (mantissa) 0.5 < y < 1
h->uni_vorps(vmm_aux3, vmm_aux3, table_val(30));
// y = y - 1
h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
// Horner evaluation of the ln polynomial: accumulator in vmm_aux1,
// argument (y - 1) in vmm_aux3
// acc = p8
h->uni_vmovups(vmm_aux1, table_val(39));
// acc = acc * y + p7
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(38));
// acc = acc * y + p6
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(37));
// acc = acc * y + p5
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(36));
// acc = acc * y + p4
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(35));
// acc = acc * y + p3
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(34));
// acc = acc * y + p2
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(33));
// acc = acc * y + p1
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(32));
// acc = acc * y + p0
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(31));
//calculate ln(2) * n
h->uni_vmulps(vmm_src, vmm_src, table_val(3));
// result = polynomial + n * ln2 + fx * ln2 (vmm_aux0 from the exp step)
h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);

// get vmm_mask = src > max logf (compared against the unclamped input)
h->uni_vmovups(vmm_mask, vmm_aux2);
if (isa == avx512_common) {
// y = (x > max logf) ? x : soft_relu(x)
h->vcmpps(k_mask, vmm_mask, table_val(25), _cmp_nle_us);
h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
} else {
// y = (x > max logf) ? x : soft_relu(x)
h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(25));
h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
}
h->uni_vmovups(vmm_src, vmm_aux1);

// tanh(ln(1+exp(x)))
tanh_compute_vector(vmm_src);
// x*tanh(ln(1+exp(x))): reload the original x saved at function entry and
// release the stack slot
h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
Expand Down Expand Up @@ -725,6 +836,63 @@ void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
}

// Emits the constant table read by mish_compute_vector through table_val(k).
// Every 32-bit pattern below is written vlen / sizeof(float) times in a row,
// so each table entry occupies one full vector. The bracketed number in each
// comment is the entry index k.
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::mish_prepare_table() {
    const unsigned int mish_consts[] = {
        // general constants
        0x3f800000, // [0] 1.0f
        0x3f000000, // [1] 0.5f
        0x3fb8aa3b, // [2] log2ef = 1.44269502f
        0x3f317218, // [3] ln2f = 0.69314718f
        0x0000007f, // [4] 0x7f
        // exp(x) polynomial
        0x3f800001, // [5] p0 = 1.0000001f
        0x3efffe85, // [6] p2 = 0.4999887f
        0x3e2aaa3e, // [7] p3 = 0.16666505f
        0x3d2bb1b1, // [8] p4 = 0.041917507f
        0x3c091ec1, // [9] p5 = 0.008369149f
        0x42b17218, // [10] logf(FLT_MAX)
        0xc2aeac50, // [11] logf(FLT_MIN)
        // tanh(x) constants
        0x80000000, // [12] mask to extract sign
        0x39ddb3d7, // [13] arg below which tanh(x) = x
        0x3f0c9f54, // [14] arg below which pol approx is valid
        0x41102cb4, // [15] arg after which tanh(x) = 1
        0xc0000000, // [16] -2.0f
        0x7fffffff, // [17] mask to make positive
        // tanh polynomial approximation
        0x3f7fffff, // [18] p0
        0xbeaaa9cf, // [19] p1
        0x3e085f1f, // [20] p2
        0xbd572bda, // [21] p3
        0x3c84fd08, // [22] p4
        // gelu approximation constants
        0x3d372713, // [23] 0.044715
        0x3f4c4229, // [24] sqrt(2/pi)
        // soft_relu input clamp
        // TODO: consider switching [25] and [26] to the more precise
        // logf(FLT_MAX) / logf(FLT_MIN) patterns (0x42b17218 / 0xc2aeac50,
        // the same values as entries [10] and [11])
        0x42b0c0a5, // [25] max logf = 88.3762589f
        // NOTE(review): 0xc1766666 decodes to about -15.4f, although this
        // entry has been labelled "min logf = -14.5f" - confirm the intended
        // lower bound
        0xc1766666, // [26] min logf
        // ln(1 + exp(x)) helpers
        0xbf800000, // [27] -1.0f, is required for sign changing
        0x42fc0000, // [28] 126
        0x807fffff, // [29] and with (to get 0.5 * mantissa)
        0x3f000000, // [30] or with (to get 0.5 * mantissa)
        // ln(1 + x) polynomial coefficients (decode the bit patterns for the
        // exact values)
        0xb2b4637d, // [31] p0
        0x3f7fff8e, // [32] p1
        0xbf001759, // [33] p2
        0x3ea70608, // [34] p3
        0xbea3d7bf, // [35] p4
        0xbe361d04, // [36] p5
        0xbfa8f1e6, // [37] p6
        0xbfe1e812, // [38] p7
        0xbfc4d30e, // [39] p8
    };

    // Number of 32-bit lanes in one vector register.
    const size_t lanes = vlen / sizeof(float);

    // Broadcast each constant across a whole vector's worth of dwords.
    for (const unsigned int pattern : mish_consts) {
        for (size_t lane = 0; lane < lanes; ++lane)
            h->dd(pattern);
    }
}

template <cpu_isa_t isa>
int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
switch (alg_) {
Expand All @@ -742,6 +910,7 @@ int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
case alg_kind::eltwise_gelu: return 5;
case alg_kind::eltwise_clamp: return 0;
case alg_kind::eltwise_swish: return 4;
case alg_kind::eltwise_mish: return 5;
default: assert(!"unsupported eltwise algorithm");
}

Expand Down Expand Up @@ -771,6 +940,7 @@ void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
case eltwise_gelu: gelu_compute_vector(Vmm(idx)); break;
case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
case eltwise_swish: swish_compute_vector(Vmm(idx)); break;
case eltwise_mish: mish_compute_vector(Vmm(idx)); break;
default: assert(!"unsupported eltwise algorithm");
}
}
Expand Down Expand Up @@ -812,6 +982,7 @@ void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
case eltwise_square: break;
case eltwise_clamp: clamp_prepare_table(); break;
case eltwise_mish: mish_prepare_table(); break;
default: assert(!"unsupported eltwise algorithm");
}
}
Expand Down Expand Up @@ -1124,7 +1295,8 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish,
eltwise_mish));

preamble();

Expand Down Expand Up @@ -1289,7 +1461,8 @@ status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init() {
&& utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish)
eltwise_logistic, eltwise_exp, eltwise_gelu, eltwise_clamp,
eltwise_swish, eltwise_mish)
&& memory_desc_wrapper(src_pd()).is_dense(true)
&& IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
Expand Down
5 changes: 4 additions & 1 deletion src/cpu/jit_uni_eltwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ struct jit_uni_eltwise_injector_f32 {
assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_swish,
eltwise_mish));
}

// note that eltwise.scale is ignored
Expand Down Expand Up @@ -120,6 +121,7 @@ struct jit_uni_eltwise_injector_f32 {
void gelu_compute_vector(const Vmm &vmm_src);
void clamp_compute_vector(const Vmm &vmm_src);
void swish_compute_vector(const Vmm &vmm_src);
void mish_compute_vector(const Vmm &vmm_src);

void relu_prepare_table();
void elu_prepare_table();
Expand All @@ -129,6 +131,7 @@ struct jit_uni_eltwise_injector_f32 {
void linear_prepare_table();
void bounded_relu_prepare_table();
void clamp_prepare_table();
void mish_prepare_table();
};

struct jit_uni_eltwise_kernel_f32;
Expand Down
7 changes: 6 additions & 1 deletion src/cpu/ref_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish));
eltwise_exp, eltwise_gelu, eltwise_clamp, eltwise_not, eltwise_swish,
eltwise_mish));
}

ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
Expand All @@ -61,6 +62,7 @@ float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
case eltwise_clamp: return clamp_fwd(s, alpha_, beta_);
case eltwise_not: return not_fwd(s);
case eltwise_swish: return swish_fwd(s, alpha_);
case eltwise_mish: return mish_fwd(s);
default: assert(!"unknown eltwise alg_kind");
}

Expand Down Expand Up @@ -96,6 +98,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
};
Expand Down Expand Up @@ -205,6 +208,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_generic() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
Expand Down Expand Up @@ -299,6 +303,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_dense() const {
case eltwise_clamp: d = clamp_fwd(s, alpha, beta); break;
case eltwise_not: d = not_fwd(s); break;
case eltwise_swish: d = swish_fwd(s, alpha); break;
case eltwise_mish: d = mish_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
});
Expand Down
3 changes: 2 additions & 1 deletion tests/gtests/test_convolution_eltwise_forward_bf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ TEST_P(convolution_test, TestConvolutionEltwise)
EXPAND_ARGS(PARAMS_CONV(eltwise_soft_relu, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_logistic, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_exp, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__))
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_mish, __VA_ARGS__))

#define ELTWISE_ALPHA 0.5f
#define ELTWISE_BETA 1.5f
Expand Down
1 change: 1 addition & 0 deletions tests/gtests/test_convolution_eltwise_forward_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ void compute_ref_conv_eltwise_fwd(const test_convolution_sizes_t &c,
case eltwise_gelu: d = gelu_fwd(d); break;
case eltwise_clamp: d = clamp_fwd(d, elt_alpha, elt_beta); break;
case eltwise_swish: d = swish_fwd(d, elt_alpha); break;
case eltwise_mish: d = mish_fwd(d); break;
default: assert(!"unknown alg_kind");
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/gtests/test_convolution_eltwise_forward_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ TEST_P(convolution_test, TestConvolutionEltwise)
EXPAND_ARGS(PARAMS_CONV(eltwise_bounded_relu, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_soft_relu, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_logistic, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__))
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_mish, __VA_ARGS__))

#define ELTWISE_ALPHA 0.5f
#define ELTWISE_BETA 1.5f
Expand Down
3 changes: 2 additions & 1 deletion tests/gtests/test_convolution_eltwise_forward_x8s8f32s32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ using convolution_test_s8s8s32f32 =
EXPAND_ARGS(PARAMS_CONV(eltwise_soft_relu, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_logistic, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_clamp, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__))
EXPAND_ARGS(PARAMS_CONV(eltwise_swish, __VA_ARGS__)), \
EXPAND_ARGS(PARAMS_CONV(eltwise_mish, __VA_ARGS__))
// EXPAND_ARGS(PARAMS_CONV(eltwise_exp, __VA_ARGS__))

#define ELTWISE_ALPHA 0.5f
Expand Down
Loading

0 comments on commit c14d8bd

Please sign in to comment.