-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add other 15 activation ops #32622
add other 15 activation ops #32622
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
39644c4
to
92761cf
Compare
92761cf
to
85104d6
Compare
__device__ __forceinline__ T operator()(const T* args) const { | ||
MPType x = static_cast<MPType>(args[0]); | ||
MPType t = static_cast<MPType>(threshold); | ||
MPType temp = (x > -t && x < t) ? x : (x <= -t ? -t : t); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,这种写法会需要2次、3次比较。可以看看:
MPType min = x < t ? x : t;
MPType max = min > -t ? min : -t;
不过可能性能也没什么差别。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
__device__ __forceinline__ T operator()(const T* args) const { | ||
MPType dout = static_cast<MPType>(args[0]); | ||
MPType out = static_cast<MPType>(args[1]); | ||
MPType t = static_cast<MPType>(threshold); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
所以这个算子也是要求threshold > 0
的吧?不知道OP实现里面有没有检查,或者加个注释说明下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加了注释,但是我看在api层面好像没有对threshold的正负做限制,感觉也需要加一下,不然为负值时,结果会不确定。
MPType temp_x_neg = static_cast<MPType>(x <= zero); | ||
return static_cast<T>( | ||
dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) + | ||
temp_a_neg * temp_x_pos * (one + a * exp(x)))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
居然这么复杂了。。。
// dx = 0, when x <= -3 | ||
// dout , when x >= 3 | ||
// dout * (x / 3 + 0.5), otherwise | ||
// Inputs: args[0], the input dout |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
公式里面使用threshold
、offset
、scale
。貌似python端没把这个几个参数暴露出来,但是op用了这几个参数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改公式,并注释几个变量的默认值
a25f611
to
1fdb9dc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM and great work~
return {{"threshold", &threshold}}; | ||
} | ||
|
||
// hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有个typo,另外threshold也应大于0。
} | ||
|
||
// hard_sigmoid(x) = 0, when x <= -3 | ||
// 1, when x >= 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码里面是和0、1比较,公式里面是和-3、3比较,是slope=6、offset=-3
?
PR types
Performance optimization
PR changes
OPs
Describe
Unify the implementation of activation operation
本次提交共修改激活算子16个,包括其前向和反向,其中
log 类算子 3 个: log1p, log2, log10
relu 类算子 5 个:brelu, soft_relu, relu6, thresholded_relu, elu
其他 8 个 :stanh, soft_plus, soft_sign, tanh_shrink, hard_shrink, hard_sigmoid, swish, hard_swish
每种类型算子的性能提升近似,因此选取每个类别中的一个算子作为示例进行描述,如下表:
case配置:[16, 128, 257, 257]