
add other 15 activation ops #32622

Merged
merged 8 commits from activation_op_impl into PaddlePaddle:develop on May 7, 2021

Conversation

@ZzSean ZzSean commented Apr 27, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Unify the implementation of the activation operators.
This PR modifies 16 activation operators in total, covering both their forward and backward kernels:
3 log-type operators: log1p, log2, log10
5 relu-type operators: brelu, soft_relu, relu6, thresholded_relu, elu
8 others: stanh, soft_plus, soft_sign, tanh_shrink, hard_shrink, hard_sigmoid, swish, hard_swish

The performance gain is similar for operators of the same type, so one operator per type is reported as a representative in the table below. (A simplified sketch of the unified functor pattern follows the table.)
Benchmark case shape: [16, 128, 257, 257]

| OP Name | FP32 old | FP32 new | FP32 improvement | FP16 old | FP16 new | FP16 improvement |
| --- | --- | --- | --- | --- | --- | --- |
| log1p fwd | 1.4057ms | 1.3126ms | 7.1% | | | |
| log1p bwd | 2.0137ms | 1.9063ms | 5.6% | | | |
| brelu fwd | 1.4010ms | 1.3113ms | 6.8% | 890.51us | 670.06us | 32.9% |
| brelu bwd | 2.0610ms | 1.9057ms | 8.1% | 1.2630ms | 963.01us | 31.1% |
| softplus fwd | 1.6655ms | 1.3196ms | 26.2% | 1.1606ms | 827.52us | 40.25% |
| softplus bwd | 2.1733ms | 1.9072ms | 14.0% | 1.6444ms | 973.62us | 68.9% |
| softsign fwd | 1.4066ms | 1.3132ms | 7.1% | 979.56us | 675.96us | 44.9% |
| softsign bwd | 2.0145ms | 1.9064ms | 5.7% | 1.3245ms | 968.44us | 36.8% |
| stanh fwd | 1.6389ms | 1.3141ms | 24.7% | 1.0028ms | 678.49us | 47.8% |
| stanh bwd | 2.1097ms | 1.9065ms | 10.7% | 1.4248ms | 971.33us | 46.7% |
| tanh_shrink fwd | 1.6618ms | 1.3141ms | 26.5% | 991.72us | 678.24us | 46.22% |
| tanh_shrink bwd | 2.1485ms | 1.9066ms | 12.7% | 1.4042ms | 970.67us | 44.7% |
| hard_shrink fwd | 1.4996ms | 1.3114ms | 14.4% | 885.03us | 671.11us | 31.9% |
| hard_shrink bwd | 2.1071ms | 1.9058ms | 10.6% | 1.3594ms | 964.05us | 41.0% |
| hard_sigmoid fwd | 1.3994ms | 1.3115ms | 6.7% | 922.11us | 670.52us | 37.5% |
| hard_sigmoid bwd | 2.0665ms | 1.9060ms | 8.4% | 1.2720ms | 961.98us | 32.2% |
| swish fwd | 1.4200ms | 1.3148ms | 8.0% | 1.0397ms | 680.82us | 52.7% |
| swish bwd | 2.1792ms | 1.9072ms | 14.3% | 1.9386ms | 972.83us | 99.3% |
| hard_swish fwd | 1.4099ms | 1.3140ms | 7.3% | 954.43us | 678.90us | 40.6% |
| hard_swish bwd | 2.0657ms | 1.9070ms | 8.3% | 1.3536ms | 971.23us | 39.4% |
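To make the "unified implementation" concrete, the sketch below shows the general pattern the reviewed snippets follow: each activation is a small device functor that reads its arguments from a pointer, promotes them to a higher-precision compute type (MPType), and is applied by a generic elementwise kernel. This is a simplified, self-contained illustration only; `MPTypeTrait`, `CudaLog1pFunctor`, and `ElementwiseKernel` here are stand-ins, not Paddle's actual utilities (the PR presumably routes its functors through Paddle's vectorized elementwise launcher).

```cpp
#include <cuda_fp16.h>
#include <math.h>

// Stand-in for a mixed-precision trait: half-precision inputs are computed
// in float, everything else in its own type.
template <typename T> struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<__half> { using Type = float; };

// Forward functor in the style of the reviewed snippets:
// read args[0], compute in MPType, cast back to T.
template <typename T>
struct CudaLog1pFunctor {
  using MPType = typename MPTypeTrait<T>::Type;
  // log1p(x) = log(1 + x)
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    MPType one = static_cast<MPType>(1.0f);
    return static_cast<T>(log(one + x));
  }
};

// Simplified elementwise launcher: one thread per element; a production
// kernel would also use vectorized loads/stores.
template <typename T, typename Functor>
__global__ void ElementwiseKernel(const T* x, T* out, int n, Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T arg = x[i];
    out[i] = func(&arg);
  }
}
```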

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

zhupengyang previously approved these changes Apr 27, 2021

@zhupengyang zhupengyang left a comment


LGTM

paddle/fluid/operators/activation_op.cu (outdated, resolved)
paddle/fluid/operators/activation_op.cu (outdated, resolved)
paddle/fluid/operators/activation_op.cu (outdated, resolved)
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
MPType t = static_cast<MPType>(threshold);
MPType temp = (x > -t && x < t) ? x : (x <= -t ? -t : t);
Contributor

Same as above: this form needs 2 or 3 comparisons. Consider something like:

MPType min = x < t ? x : t;
MPType max = min > -t ? min : -t;

Though the performance probably won't differ much.
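For illustration, a minimal sketch of the reviewer's two-comparison form, assuming (as the original ternary suggests) that the goal is to clip x to [-threshold, threshold]; the helper name is hypothetical:

```cpp
// Hypothetical helper illustrating the suggestion: clip x to [-t, t]
// with exactly two comparisons instead of the nested ternary above.
template <typename MPType>
__device__ __forceinline__ MPType ClipToThreshold(MPType x, MPType t) {
  MPType minv = x < t ? x : t;          // min(x, t)
  MPType maxv = minv > -t ? minv : -t;  // max(min(x, t), -t)
  return maxv;                          // equals clamp(x, -t, t) for t >= 0
}
```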

Contributor Author

Fixed.

__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType out = static_cast<MPType>(args[1]);
MPType t = static_cast<MPType>(threshold);
Contributor

So this operator also requires threshold > 0, right? I'm not sure whether the op implementation checks this; maybe add a comment to explain it?

Contributor Author

Added a comment. But it looks like the API layer does not restrict the sign of threshold; I think a check should be added there as well, otherwise the result is undefined when threshold is negative.
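Not part of this PR; just a standalone sketch of the kind of positivity check being discussed. In Paddle itself such a check would more likely go through the existing attribute-validation machinery at the Python API or op-definition level; the helper name below is hypothetical:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical standalone guard, not Paddle's actual check: reject
// non-positive thresholds so conditions of the form
// (out > 0 && out < threshold) keep a non-empty true branch.
inline void CheckThresholdPositive(float threshold) {
  if (threshold <= 0.0f) {
    throw std::invalid_argument("threshold must be positive, got " +
                                std::to_string(threshold));
  }
}
```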

paddle/fluid/operators/activation_op.cu (outdated, resolved)
paddle/fluid/operators/activation_op.cu (outdated, resolved)
paddle/fluid/operators/activation_op.cu (resolved)
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
temp_a_neg * temp_x_pos * (one + a * exp(x))));
Contributor

I didn't expect this to get so complicated...
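For reference, the textbook ELU gradient for the common case alpha > 0 is much simpler; the extra terms in the snippet above additionally cover the alpha <= 0 branches. A minimal sketch of that simple case (helper name hypothetical):

```cpp
// Standard ELU derivative, valid for alpha > 0 only:
//   elu(x) = x                    if x > 0
//          = alpha * (exp(x) - 1) if x <= 0
//   d elu / dx = 1                if x > 0
//              = alpha * exp(x)   if x <= 0
template <typename MPType>
__device__ __forceinline__ MPType EluGradSimple(MPType dout, MPType x,
                                                MPType alpha) {
  return x > static_cast<MPType>(0) ? dout : dout * alpha * exp(x);
}
```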

// dx = 0, when x <= -3
// dout , when x >= 3
// dout * (x / 3 + 0.5), otherwise
// Inputs: args[0], the input dout
Contributor

The formula uses threshold, offset, and scale. It seems these parameters are not exposed on the Python side, but the op uses them.

Contributor Author

Updated the formula and added a comment noting the default values of these variables.
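For illustration, a hedged sketch of the hard_swish backward formula quoted above, written out with the usual default attributes threshold = 6, scale = 6, offset = 3 (which match the -3 / 3 breakpoints and the x / 3 + 0.5 factor in the comment); the function name is hypothetical:

```cpp
// Sketch of the piecewise hard_swish derivative with default attributes:
//   dx = 0                                  when x + offset <= 0         (x <= -3)
//   dx = dout                               when x + offset >= threshold (x >=  3)
//   dx = dout * (2x + offset) / scale
//      = dout * (x / 3 + 0.5)               otherwise
template <typename MPType>
__device__ __forceinline__ MPType HardSwishGradSketch(MPType dout, MPType x) {
  const MPType threshold = static_cast<MPType>(6);
  const MPType scale = static_cast<MPType>(6);
  const MPType offset = static_cast<MPType>(3);
  if (x + offset <= static_cast<MPType>(0)) return static_cast<MPType>(0);
  if (x + offset >= threshold) return dout;
  return dout * (static_cast<MPType>(2) * x + offset) / scale;
}
```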

@ZzSean ZzSean force-pushed the activation_op_impl branch from a25f611 to 1fdb9dc on April 30, 2021 08:50
Contributor

@Xreki Xreki left a comment

LGTM and great work~

return {{"threshold", &threshold}};
}

// hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x
Contributor

There is a typo here (hadrshrink), and threshold should also be greater than 0.

}

// hard_sigmoid(x) = 0, when x <= -3
// 1, when x >= 3
Contributor

In the code the comparison is against 0 and 1, while the formula in the comment compares against -3 and 3, i.e. slope=6, offset=-3.
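For context, a hedged sketch of a hard_sigmoid forward, assuming the common parameterization out = clamp(slope * x + offset, 0, 1). The code in the diff compares the intermediate value against 0 and 1; the -3 / 3 breakpoints in the comment are simply the x values at which that clamp saturates for a particular slope/offset pair (for example, slope = 1/6 with offset = 0.5 gives exactly x = -3 and x = 3). The function name is hypothetical:

```cpp
// Sketch of hard_sigmoid with out = clamp(slope * x + offset, 0, 1).
template <typename MPType>
__device__ __forceinline__ MPType HardSigmoidSketch(MPType x, MPType slope,
                                                    MPType offset) {
  MPType temp = slope * x + offset;
  MPType zero = static_cast<MPType>(0);
  MPType one = static_cast<MPType>(1);
  return temp <= zero ? zero : (temp >= one ? one : temp);
}
```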

@Xreki Xreki merged commit b2160e7 into PaddlePaddle:develop May 7, 2021
@ZzSean ZzSean deleted the activation_op_impl branch May 7, 2021 05:51