
Flash attention v2 forward #10484

Merged: 27 commits merged into master from flash_attention_v2_forward on Apr 18, 2024

Conversation

cccddd77 (Contributor)

Integrated the flash attn v2 forward operator.

const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
const auto num_heads = query->shape()->At(1);
// const auto max_seqlen_batch_q = query->shape()->At(2);
Contributor

remove?

const int64_t& seed = 0) const {
const auto og_size = query->shape()->At(3);
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(2);
Contributor

The shapes of q, k, and v should also be checked here for consistency in batch size, head size, and num_heads.
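A minimal sketch of such a check inside the functor, assuming the usual CHECK_EQ_OR_RETURN macro and the batch_size / num_heads / og_size locals already computed above (the messages are illustrative):

// Sketch: verify key/value agree with query on batch size, num_heads and head size.
CHECK_EQ_OR_RETURN(key->shape()->At(0), batch_size) << "key batch size must match query.";
CHECK_EQ_OR_RETURN(value->shape()->At(0), batch_size) << "value batch size must match query.";
CHECK_EQ_OR_RETURN(key->shape()->At(1), num_heads) << "key num_heads must match query.";
CHECK_EQ_OR_RETURN(value->shape()->At(1), num_heads) << "value num_heads must match query.";
CHECK_EQ_OR_RETURN(key->shape()->At(3), og_size) << "key head size must match query.";
CHECK_EQ_OR_RETURN(value->shape()->At(3), og_size) << "value head size must match query.";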

const auto num_heads = query->shape()->At(1);
// const auto max_seqlen_batch_q = query->shape()->At(2);
const auto max_seqlen_batch_k = key->shape()->At(2);
const auto max_seqlen_batch_v = value->shape()->At(2);
Contributor

Is the seqlen of k always equal to that of v? If so, it should also be checked.
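A one-line check in the same spirit, reusing the max_seqlen_batch_k / max_seqlen_batch_v values computed above:

// Sketch: key and value must share the same sequence length.
CHECK_EQ_OR_RETURN(max_seqlen_batch_k, max_seqlen_batch_v)
    << "key and value must have the same sequence length.";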

// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
auto q_padded = JUST(pad_last_dim<8>(query));
Contributor

If the last dim size is already divisible by 8, can this padding be skipped? An if/else is probably needed.
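A sketch of the suggested branch, reusing pad_last_dim<8> and og_size from the surrounding code; the control flow is the illustrative part:

// Sketch: only pad the last dimension up to a multiple of 8 when it is not one already.
const bool is_padded = (og_size % 8 != 0);
std::shared_ptr<one::Tensor> q_padded = query;
std::shared_ptr<one::Tensor> k_padded = key;
std::shared_ptr<one::Tensor> v_padded = value;
if (is_padded) {
  q_padded = JUST(pad_last_dim<8>(query));
  k_padded = JUST(pad_last_dim<8>(key));
  v_padded = JUST(pad_last_dim<8>(value));
}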

JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, {q_, k_, v_}, ctx));

auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));
return JUST(functional::Slice(output_padded, {0, 0, 0, 0},
Contributor

The slice should also be handled according to whether padding was applied.
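Continuing the sketch above, the final slice would only be needed when padding was actually applied. The is_padded flag and the stop/step arguments are assumptions about how the original (truncated) Slice call continues, following the B x S x H x D layout after the transpose:

// Sketch: only strip the padded tail of the head dimension when padding was applied.
auto output_padded = JUST(functional::Transpose(output_, {0, 2, 1, 3}));
if (!is_padded) { return output_padded; }
return JUST(functional::Slice(output_padded, {0, 0, 0, 0},
                              {batch_size, seqlen_q, num_heads, og_size},
                              {1, 1, 1, 1}, /*enable_view_slice=*/false));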


} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Contributor

Overall, I think this could be placed in nn_functor.cpp.

@MARD1NO MARD1NO mentioned this pull request Apr 12, 2024

} // namespace oneflow

#endif // ONEFLOW_USER_KERNELS_FLASH_ATTENTION_KERNEL_H_
Contributor

Add a trailing newline.

return Maybe<void>::Ok();
}

} // namespace oneflow
Contributor

Add a trailing newline everywhere flagged like this.

const Shape& q_shape = ctx->InputShape("query", 0);
const Shape& k_shape = ctx->InputShape("key", 0);
const Shape& v_shape = ctx->InputShape("value", 0);

Contributor

If an alibi input is provided, its shape should be checked to equal either (num_heads) or (batch_size, num_heads).
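A sketch of that check inside the op's shape-inference function; the optional input name "alibi_slopes" is an assumption (taken from the kernel-side alibi_slopes_ parameter):

// Sketch: alibi slopes must have shape (num_heads) or (batch_size, num_heads).
if (ctx->has_input("alibi_slopes", 0)) {
  const Shape& alibi_shape = ctx->InputShape("alibi_slopes", 0);
  const int64_t batch_size = q_shape.At(0);
  const int64_t num_heads = q_shape.At(1);
  const bool is_per_head = alibi_shape.NumAxes() == 1 && alibi_shape.At(0) == num_heads;
  const bool is_per_batch_head = alibi_shape.NumAxes() == 2 && alibi_shape.At(0) == batch_size
                                 && alibi_shape.At(1) == num_heads;
  CHECK_OR_RETURN(is_per_head || is_per_batch_head)
      << "alibi_slopes must have shape (num_heads) or (batch_size, num_heads).";
}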

#endif
}

// void set_params_dgrad(Flash_bwd_params &params,
Contributor

Delete this part; later split out a separate scaled_dot_product_attention_kernel_grad.cu and put all the backward code there.


void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size,
int num_heads) {
params.alibi_slopes_ptr = nullptr;
Contributor

Why is this set directly to nullptr here?

Contributor

A small tip: for code that is temporarily bypassed just to debug or to get things running, it's best to add a check or a log, for example:

// TODO(Author): Need to support the ALibi params.
if (alibi_slopes != nullptr) {
    // throw / log ...
}

If someone later assumes this function already supports alibi, passes the argument in, and only then finds the result is wrong, such a guard makes the limitation explicit. And when support is added later, a global search for TODO(Author): quickly locates the code that needs to change.
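A concrete sketch of such a guard in set_params_alibi; the LOG(FATAL) handling is one option, and whether to throw, log, or CHECK is up to the authors:

// Sketch: fail loudly instead of silently ignoring alibi slopes.
void set_params_alibi(Flash_fwd_params& params, const Tensor* alibi_slopes_, int batch_size,
                      int num_heads) {
  // TODO(Author): Need to support the ALibi params.
  if (alibi_slopes_ != nullptr) {
    LOG(FATAL) << "alibi_slopes is not yet supported by the flash attention v2 forward kernel.";
  }
  params.alibi_slopes_ptr = nullptr;
}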


const int arch = cuda_stream->cuda_arch() / 10;
const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90);
CHECK(is_supported_arch);
Contributor

Should add some log info stating that only arch 80, 86, 89, and 90 are supported.
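A sketch of the suggested message, assuming the glog-style CHECK macro that accepts a streamed message:

// Sketch: report which architectures are supported instead of failing with a bare CHECK.
CHECK(is_supported_arch) << "flash attention v2 forward only supports compute architectures "
                         << "sm80, sm86, sm89 and sm90, but got sm" << arch << ".";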

}

Maybe<void> ScaledDotProductFlashAttentionOp::GetSbp(user_op::SbpContext* ctx) {
ctx->NewBuilder()
Contributor

  1. The SBP signatures should also cover the optional alibi input.
  2. For model parallelism, the num_heads dimension should be split, so that SBP signature needs to be added as well (see the sketch below).
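A sketch of the additional signature, assuming the usual user_op SBP builder API; the "out" output name is an assumption, and whatever outputs the op actually declares would need the same split:

// Sketch: split the num_heads dimension (axis 1 of B x H x S x D tensors) for model parallelism.
ctx->NewBuilder()
    .Split(user_op::OpArg("query", 0), 1)
    .Split(user_op::OpArg("key", 0), 1)
    .Split(user_op::OpArg("value", 0), 1)
    .Split(user_op::OpArg("out", 0), 1)
    .Build();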

@MARD1NO (Contributor) commented Apr 12, 2024

The OneFlow build should inject a CUDA_VERSION macro; it can be used so that this code only compiles with CUDA 11.6 and above, for example: #if CUDA_VERSION >= 11000
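A sketch of such a guard around the kernel source; 11060 is the standard encoding for CUDA 11.6 (major * 1000 + minor * 10), and the actual cutoff is up to the authors:

// Sketch: only compile the flash attention v2 forward kernels with a new-enough CUDA toolkit.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11060  // CUDA 11.6
// ... kernel and launch code ...
#endif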

Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@cccddd77 cccddd77 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot April 15, 2024 09:55
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.7ms (= 4368.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.6ms (= 5758.1ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.32 (= 57.6ms / 43.7ms)

OneFlow resnet50 time: 26.3ms (= 2626.2ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.1ms (= 3706.3ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.41 (= 37.1ms / 26.3ms)

OneFlow resnet50 time: 19.2ms (= 3835.8ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.4ms (= 7076.2ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.84 (= 35.4ms / 19.2ms)

OneFlow resnet50 time: 17.5ms (= 3509.5ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 32.1ms (= 6424.5ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.83 (= 32.1ms / 17.5ms)

OneFlow resnet50 time: 16.4ms (= 3285.1ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 30.6ms (= 6119.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.86 (= 30.6ms / 16.4ms)

OneFlow swin dataloader time: 0.200s (= 39.952s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.734s / 200, num_workers=1)
Relative speed: 0.644 (= 0.129s / 0.200s)

OneFlow swin dataloader time: 0.055s (= 10.931s / 200, num_workers=4)
PyTorch swin dataloader time: 0.032s (= 6.489s / 200, num_workers=4)
Relative speed: 0.594 (= 0.032s / 0.055s)

OneFlow swin dataloader time: 0.030s (= 5.982s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.382s / 200, num_workers=8)
Relative speed: 0.565 (= 0.017s / 0.030s)

❌ OneFlow resnet50 time: 49.4ms (= 4938.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.9ms (= 6491.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.31 (= 64.9ms / 49.4ms)

OneFlow resnet50 time: 36.0ms (= 3597.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 45.5ms (= 4547.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.26 (= 45.5ms / 36.0ms)

OneFlow resnet50 time: 27.6ms (= 5518.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.1ms (= 8216.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.49 (= 41.1ms / 27.6ms)

OneFlow resnet50 time: 25.2ms (= 5036.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 40.2ms (= 8043.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.60 (= 40.2ms / 25.2ms)

OneFlow resnet50 time: 24.7ms (= 4943.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 37.3ms (= 7467.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.51 (= 37.3ms / 24.7ms)

@cccddd77 cccddd77 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot April 16, 2024 08:14
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@cccddd77 cccddd77 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot April 17, 2024 07:15
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.6ms (= 4355.4ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.7ms (= 5766.5ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.32 (= 57.7ms / 43.6ms)

OneFlow resnet50 time: 26.1ms (= 2611.3ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.6ms (= 3761.9ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.44 (= 37.6ms / 26.1ms)

OneFlow resnet50 time: 18.9ms (= 3787.8ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 37.3ms (= 7461.5ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.97 (= 37.3ms / 18.9ms)

OneFlow resnet50 time: 16.6ms (= 3329.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 33.7ms (= 6735.5ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 2.02 (= 33.7ms / 16.6ms)

OneFlow resnet50 time: 17.6ms (= 3513.1ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.1ms (= 5813.2ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.65 (= 29.1ms / 17.6ms)

OneFlow swin dataloader time: 0.201s (= 40.120s / 200, num_workers=1)
PyTorch swin dataloader time: 0.128s (= 25.580s / 200, num_workers=1)
Relative speed: 0.638 (= 0.128s / 0.201s)

OneFlow swin dataloader time: 0.055s (= 10.999s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.693s / 200, num_workers=4)
Relative speed: 0.608 (= 0.033s / 0.055s)

OneFlow swin dataloader time: 0.030s (= 6.027s / 200, num_workers=8)
PyTorch swin dataloader time: 0.016s (= 3.255s / 200, num_workers=8)
Relative speed: 0.540 (= 0.016s / 0.030s)

❌ OneFlow resnet50 time: 49.0ms (= 4903.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.1ms (= 6410.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.31 (= 64.1ms / 49.0ms)

OneFlow resnet50 time: 36.7ms (= 3673.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 46.9ms (= 4687.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.28 (= 46.9ms / 36.7ms)

OneFlow resnet50 time: 27.5ms (= 5498.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.7ms (= 7946.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.45 (= 39.7ms / 27.5ms)

OneFlow resnet50 time: 25.3ms (= 5059.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 40.6ms (= 8120.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.61 (= 40.6ms / 25.3ms)

OneFlow resnet50 time: 24.6ms (= 4921.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.5ms (= 7699.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.56 (= 38.5ms / 24.6ms)

@MARD1NO MARD1NO merged commit 44ad994 into master Apr 18, 2024
20 checks passed
@MARD1NO MARD1NO deleted the flash_attention_v2_forward branch April 18, 2024 07:30
@@ -0,0 +1,39 @@
include(ExternalProject)
Contributor

This third-party library should be moved under the external directory later; as a rule, newly added third-party dependencies go under external.
