[CINN]Add qkv unpack attn #64439
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
… add_qkv_unpack_attn
LGTM
```diff
@@ -2514,6 +2514,16 @@
   backward : put_along_axis_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface

+- op : qkv_unpack_mha
```
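The hunk only shows the first line of the new entry (the header indicates ten added lines). For orientation, entries in this YAML follow a fixed schema, as the neighboring put_along_axis fields suggest; the field values below are illustrative assumptions, not the PR's actual registration:

```yaml
# Illustrative sketch only: the args/output/infer_meta/kernel values
# for qkv_unpack_mha are assumptions, not taken from the PR.
- op : qkv_unpack_mha
  args : (Tensor q, Tensor k, Tensor v)
  output : Tensor(out)
  infer_meta :
    func : QkvUnpackMhaInferMeta   # hypothetical InferMeta function name
  kernel :
    func : qkv_unpack_mha
    data_type : q
```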
This looks like a fusion op; should it be moved into fused_ops.yaml?
TODO: move this operator into fusion_ops.yaml.
* add fast infer attention
* remove useless code
* fix ROCm compile bug
* polish code
* fix conflict
* remove duplicate
PR Category
CINN
PR Types
Others
Description
pcard-76996
Add an attention op that takes unpacked q/k/v inputs:
q = [1, 1, head * head_dim]
k = [1, seq_len, head * head_dim]
v = [1, seq_len, head * head_dim]
Currently only the bs = 1 case is supported; support will be extended incrementally in follow-up work.
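For reference, here is a minimal numpy sketch of the computation such a qkv-unpack attention performs for the shapes above. The function name, the num_heads argument, and the absence of masking are assumptions for illustration; this is not the PR's CINN/CUDA kernel.

```python
import numpy as np

def qkv_unpack_attention(q, k, v, num_heads):
    # q: [1, 1, num_heads * head_dim]        (single query token, bs = 1)
    # k: [1, seq_len, num_heads * head_dim]
    # v: [1, seq_len, num_heads * head_dim]
    bs, q_len, hidden = q.shape
    _, seq_len, _ = k.shape
    head_dim = hidden // num_heads

    # Split the packed hidden dim into heads: [bs, heads, len, head_dim].
    def split_heads(x, length):
        return x.reshape(bs, length, num_heads, head_dim).transpose(0, 2, 1, 3)

    qh = split_heads(q, q_len)
    kh = split_heads(k, seq_len)
    vh = split_heads(v, seq_len)

    # Scaled dot-product attention per head.
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # [bs, heads, q_len, seq_len]
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)                  # softmax over seq_len
    out = probs @ vh                                            # [bs, heads, q_len, head_dim]

    # Merge heads back to [bs, q_len, heads * head_dim].
    return out.transpose(0, 2, 1, 3).reshape(bs, q_len, hidden)

# Example with hypothetical sizes: 8 heads of dim 64, seq_len = 128.
q = np.random.randn(1, 1, 8 * 64).astype("float32")
k = np.random.randn(1, 128, 8 * 64).astype("float32")
v = np.random.randn(1, 128, 8 * 64).astype("float32")
out = qkv_unpack_attention(q, k, v, num_heads=8)  # -> shape (1, 1, 512)
```

The single-token query shape ([1, 1, head * head_dim]) matches the decode step of autoregressive inference, which is presumably why the op is described as "fast infer attention" in the commit list.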