Flash attention v2 forward #10484
Conversation
…nto flash_attn_v2 for updates.
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
const auto num_heads = query->shape()->At(1);
// const auto max_seqlen_batch_q = query->shape()->At(2);
remove?
const int64_t& seed = 0) const {
const auto og_size = query->shape()->At(3);
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
The shapes of q, k, and v should also be checked here for consistency in batch size, head size, and num_heads.
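For instance, a minimal sketch using OneFlow's CHECK_EQ_OR_RETURN, with variable names taken from the surrounding snippet (illustrative only, not the PR's code):

CHECK_EQ_OR_RETURN(key->shape()->At(0), batch_size);    // k batch size
CHECK_EQ_OR_RETURN(value->shape()->At(0), batch_size);  // v batch size
CHECK_EQ_OR_RETURN(key->shape()->At(1), num_heads);     // k num_heads
CHECK_EQ_OR_RETURN(value->shape()->At(1), num_heads);   // v num_heads
CHECK_EQ_OR_RETURN(key->shape()->At(3), og_size);       // k head size
CHECK_EQ_OR_RETURN(value->shape()->At(3), og_size);     // v head size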
const auto num_heads = query->shape()->At(1);
// const auto max_seqlen_batch_q = query->shape()->At(2);
const auto max_seqlen_batch_k = key->shape()->At(2);
const auto max_seqlen_batch_v = value->shape()->At(2);
Are the seqlens of k and v guaranteed to be equal? If they must be, that should also be checked.
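Something along these lines would do, again a sketch using CHECK_EQ_OR_RETURN rather than the PR's actual code:

CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v)
    << "key and value must have the same sequence length";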
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
auto q_padded = JUST(pad_last_dim<8>(query));
If the last dim size is already divisible by 8, can this padding be skipped? It probably needs an if/else.
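One possible shape of that, as a sketch on top of the existing pad_last_dim<8> helper, with og_size being the original head dim from the snippet above:

std::shared_ptr<one::Tensor> q_padded = query;
std::shared_ptr<one::Tensor> k_padded = key;
std::shared_ptr<one::Tensor> v_padded = value;
if (og_size % 8 != 0) {
  // Only pad when the head dim is not already a multiple of 8.
  q_padded = JUST(pad_last_dim<8>(query));
  k_padded = JUST(pad_last_dim<8>(key));
  v_padded = JUST(pad_last_dim<8>(value));
}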
JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {q_, k_, v_}, ctx));

auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));
return JUST(functional::Slice(output_padded, {0, 0, 0, 0},
This Slice should also be made conditional on whether padding was applied.
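Roughly like this (a sketch that reuses the og_size check from above; the argument order is assumed to be functional::Slice(x, start, stop, step, enable_view_slice)):

auto output = JUST(functional::Transpose(output_, {0, 2, 1, 3}));
if (og_size % 8 != 0) {
  // Cut the padded head dim back to its original size.
  output = JUST(functional::Slice(output, {0, 0, 0, 0},
                                  {batch_size, seqlen_q, num_heads, og_size},
                                  {1, 1, 1, 1}, /*enable_view_slice=*/false));
}
return output;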
} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Overall I think this could be moved into nn_functor.cpp.
} // namespace oneflow

#endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_
Add a trailing newline here.
return Maybe<void>::Ok();
}

} // namespace oneflow
Every file flagged like this should also get a trailing newline.
const Shape& q_shape = ctx->InputShape("query", 0);
const Shape& k_shape = ctx->InputShape("key", 0);
const Shape& v_shape = ctx->InputShape("value", 0);
If an alibi input is provided, its shape should also be checked: it should equal (num_heads) or (batch_size, num_heads).
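A hedged sketch of that check inside the shape-inference function; the optional input name "alibi_slopes" is an assumption:

if (ctx->has_input("alibi_slopes", 0)) {
  const Shape& alibi_shape = ctx->InputShape("alibi_slopes", 0);
  const int64_t batch_size = q_shape.At(0);
  const int64_t num_heads = q_shape.At(1);
  // Accept either (num_heads) or (batch_size, num_heads).
  CHECK_OR_RETURN(alibi_shape.NumAxes() == 1 || alibi_shape.NumAxes() == 2);
  CHECK_EQ_OR_RETURN(alibi_shape.At(alibi_shape.NumAxes() - 1), num_heads);
  if (alibi_shape.NumAxes() == 2) { CHECK_EQ_OR_RETURN(alibi_shape.At(0), batch_size); }
}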
#endif
}

// void set_params_dgrad(Flash_bwd_params &params,
Delete this part; later, pull the backward pieces out into a separate scaled_dot_product_attention_kernel_grad.cu and put everything backward-related there.
void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size,
                      int num_heads) {
  params.alibi_slopes_ptr = nullptr;
Why is this set directly to nullptr here?
A small tip: for code that is temporarily bypassed just to debug or to get things running, it's best to add a check and a log, e.g.:

// TODO(Author): need to support alibi params
if (alibi_slopes != nullptr) {
  // throw / log...
}

If someone later assumes this function already supports alibi, passes the argument, and gets wrong results, this makes the situation much clearer. And when support is added later, a global search for TODO(Author): quickly locates the code that needs changing.
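Fleshed out a bit, the guarded version could look like this (a sketch only, using a plain glog CHECK; adapt to whatever error handling the kernel uses):

void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes, int batch_size,
                      int num_heads) {
  // TODO(Author): need to support alibi params; only nullptr is accepted for now.
  CHECK(alibi_slopes == nullptr) << "alibi_slopes is not supported by this kernel yet.";
  params.alibi_slopes_ptr = nullptr;
}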
const int arch = cuda_stream->cuda_arch() / 10;
const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90);
CHECK(is_supported_arch);
Some log info should be added stating that only arch 80, 86, 89, and 90 are supported.
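For example (a sketch; glog's CHECK already accepts a streamed message):

CHECK(is_supported_arch) << "scaled dot product flash attention only supports sm_80, sm_86, "
                            "sm_89 and sm_90, but got sm_" << arch;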
}

Maybe<void> ScaledDotProductFlashAttentionOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder()
- The SBP should also take the optional alibi input into account.
- For model parallelism, the split should be on the num_heads dimension, so that SBP signature needs to be added (see the sketch below).
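A sketch of what that signature could look like; the argument names and the num_heads axis (index 1, given the B x H x S x D layout) are assumptions:

ctx->NewBuilder()
    .Split(user_op::OpArg("query", 0), 1)
    .Split(user_op::OpArg("key", 0), 1)
    .Split(user_op::OpArg("value", 0), 1)
    .Split(user_op::OpArg("alibi_slopes", 0), 0)  // per-head slopes, if present
    .Split(user_op::OpArg("out", 0), 1)
    .Build();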
The OneFlow build should inject a CUDA_VERSION macro; it can be used so that this is only compiled on CUDA 11.6 and above, for example: #if CUDA_VERSION >= 11000
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/docs/Oneflow-Inc/oneflow/pr/10484/
Speed stats:
Force-pushed from 3e028a3 to ffe3f74.
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/docs/Oneflow-Inc/oneflow/pr/10484/
Speed stats:
@@ -0,0 +1,39 @@
include(ExternalProject)
This third-party library should be moved under the external directory later; as a rule, newly added third-party dependencies all go under external.
Integrated the flash attention v2 forward op.