From 44b0d075db02affa7eb006c2980fa509d821f91a Mon Sep 17 00:00:00 2001
From: Qianfeng
Date: Mon, 4 Mar 2024 18:58:33 +0800
Subject: [PATCH] Fast Multi-head Attention support on AMD ROCM (#978)

* add option to build a standalone runner for splitk decoder; debugging numerics in reduction
* fix a few bugs
* fix an indexing bug
* stash changes
* Add benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark mqa/gqa performance on ck-tiled fmha
* Synchronize with latest update in composable_kernel_tiled feature/fmha-pad-support branch
* Tiny fix in benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
* Synchronize with latest update in composable_kernel_tiled and make all unit_tests pass
* Switch to new branch for composable_kernel_tiled submodule
* Add bfp16 instances for ck-tiled inference
* Update to test and benchmark scripts to include bfloat16
* Tiny update to ck_tiled kernel
* Change to benchmark_mem_eff_attn_mqa_gqa_ck_tiled benchmark cases
* stash changes
* Use Async pipeline for no M/N0K1 padding cases
* Add CK_FMHA_FWD_FAST_EXP2 to building
* Add Triton FA2 forward op
* Add Triton Flash Attention 2 to benchmarks
* Synchronize with latest third_party/composable_kernel and remove the inner_product bhalf_t overloading in ck_attention_forward_decoder.h
* stash split attention testing wip
* Synchronize with latest third_party/composable_kernel again
* Synchronize with latest third_party/composable_kernel_tiled
* Change to make ck decoder buildable with both ck tiled or non-tiled fmha kernel
* Change to make ck decoder buildable with both ck tiled or non-tiled fmha kernel
* fix gqa for split-k=1
* Skip backward tests, fix import
* fix the mask for decoding; row max and lse are computed correctly; debugging must go on
* make libtorch split-1 decoder implementation pass numerical correctness
* Disable CK kernel for large shapes, better catch OOMs
* Actually remove submodule composable_kernel_tiled from the branch
* Change the domain for the repo of composable_kernel submodule to ROCm
* Update to validate_inputs() in common.py to support 4d mqa/gqa
* synchronize test_mem_eff_attention_ck.py with test_mem_eff_attention.py
* Tiny update in benchmark_mem_eff_attn_decoder_ck.py
* Synchronize benchmark_mem_eff_attention_ck.py with benchmark_mem_eff_attention.py
* Remove benchmark_mem_eff_attn_decoder_ck_tiled.py
* Support for Generic Attention Mask Coordinate
* Add ck.FwOp and ck.BwOp to dispatched operations
* Add ck.FwOp and ck.BwOp to ALL_FW_OPS and ALL_BW_OPS
* Update in tests/readme_test_on_rocm.txt
* Add ckF and ck_decoder to benchmark_mem_eff_attn_decoder.py
* Synchronize with the latest ck-tiled commits
* Add is_ck_tiled_used() c++ extension interface for judging if ck-tiled is used
* Remove composable_kernel_tiled submodule
* inner_product removed from splitk kernel code
* remove some commented out debug code
* comment out debug code calling libtorch instead of hip implementation
* remove commented out old and incorrect code fragments
* add python version override to cmakelists
* add conversion from Argument struct to string; fix split1 test crash -- fyi device guard needs to be declared to avoid segfaults in the kernel
* add f32 support in the python op
* refactor out input generation in cpp standalone
* set loop unrolls to 1 in order to avoid index errors (will need to be fixed later for perf)
* fix output splits allocation
* fix bug in split attention: sumexp needs timestep bounds in each split
* clang-format-10
* Enable support of attn-bias types with LocalAttention
* Enable support of attn-bias types with LocalAttention
* Synchronize submodule composable_kernel to the latest commits
* Make the efficient_attention_forward_ck() C++ interface consistent with the updating of xformers/ops/fmha API
* Tiny fix in ck.py to make test_backward pass
* some refactorings for standalone tests
* cleanup testing
* Make the efficient_attention_forward_ck() C++ interface consistent with the updating of xformers/ops/fmha API
* Tiny fix in ck.py to make test_backward pass
* fix split1 attention csrc test
* Enable support of flexible head-dim size (but <= 128) for ck-tiled fmha forward
* Use Async pipeline when no padding is used
* implement general split-k split-attention in libtorch, use for testing
* fix split-max and split-sumexp shapes for split attention in libtorch
* implement generic reduce split attention with libtorch
* implement testing split reduce hip vs libtorch; tbd debug split-k=2 numerical mismatch in this test
* refactor repetitive testing code
* address code review: rearrange loops
* address code review: add comment about number of iterations per split
* address code review: remove comments
* address code review: possibly eliminate a bug by using correct timestep range for scaling sumexp in smem
* address code review: add todo
* address code review: shift LDS access by tt_low to avoid smem overbooking
* address code review: simplify reduction loops in split attention
* Tiny update in ck-tiled forward kernel
* address code review: merge for loops
* address code review: simplify coefficient pick
* fix runtime error message in testing code
* fix split reduce test
* address code review: fix smem offsets
* remove redundant comment
* address code review: initialize split attention workspace as empty
* address code review: rename local vars
* address code review: remove unused _rand_seqlens
* address code review: cleanup python tests
* remove redundant new_max local var
* address code review: rename seq_acc
* re-enable loop unroll; adjust tests to handle splits with size divisible by block size; handle empty splits correctly
* test a wider range of split-k in cpp tests; fix torch implementation one more time to handle empty splits
* Synchronize with ck-tiled update to support head-dim-256 and LSE storing
* Add definition of FMHA_FWD_HEADDIM_SWITCH
* Split the ck-tiled inference instances based on head-dim sizes to improve compiling
* Setting k0n1_need_padding according to pipeline kQLoadOnce implementation
* Add fmha forward c++ extension for ck-tiled
* Set SUPPORTED_MAX_K=256 in ck.py
* fix index in split-k attention
* fix index in softmax reduce and complete fixing wavefronts per block optimization
* clang-format-10
* Fix v_dram_transposed transpose transform in the kernel
* Skip triton_splitk for test_forward in test_mem_eff_attention.py
* cleanup commented dead code
* enable ck split-k in benchmark_attn_decoding
* add rocm_ci workflow
* move scipy import from file level to under the function, similar to _vec_binom_test; saves a few keystrokes when setting up environment
* Add including of math_v2.hpp to ck_attention_forward_decoder_splitk.h
* move forward_splitk to ck_splitk; make dispatch aware of ck_splitk and ck_decoder
* Synchronize to latest ck-tiled and update accordingly
* fix benchmark_attn_decoding
* Remove third_party/composable_kernel_tiled
* [Fix] use kK0BlockLength for HeadDim256 padding judging
* Tiny type change for custom_mask_type in param class
* Change to use ROCm repo for ck-tiled submodule
* Remove tests/test_forward_ck_tiled.py
* Update to test_mqa_forward_ck_tiled.py to use common create_attn_bias method
* Add ck-tiled checking in test_mqa_forward_ck_tiled.py
* rearrange smem access in softmax reduction
* Add test_decoder and test_splitk_decoder for ROCM into test_mem_eff_attention.py
* Add ref_attention_splitk and its test to tests/test_mem_eff_attention.py
* Rename test_mem_eff_attention_ck.py as discarded
* Add test_mqa_forward and ref_attention_mqa (for BMHK format mqa/gqa verification) into test_mem_eff_attention.py
* Rename test_mqa_forward_ck_tiled.py as discarded
* Remove CK specific script benchmark_mem_eff_attn_decoder_ck.py
* Refine benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
* Rename benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark_mem_eff_attention_mqa.py
* Remove the runtime_error when using logsumexp in attention_forward_generic_ck_tiled.cpp
* Add ck-tiled checking in ck.py
* Remove CK-specific benchmark scripts
* Don't require is_cpu_tensor for seqstart_q/seqstart_k/seqlen_k in attention_forward_generic_ck_tiled
* Remove seqlen_cpu from _PaddedSeqLenInfo in attn_bias.py
* Change the branch for composable_kernel_tiled submodule and update to latest
* Remove the use of seqlen_cpu in BwOp of ck.py
* Remove the use of seqlen_cpu in BwOp of ck.py
* Align .clang_format with main branch and re-format c++ files
* Synchronize to latest ck-tiled commit
* Add checking of IS_CK_TILED into some testing scripts
* Update to test_mem_eff_attention.py and ck.py
* Building xformers using ck-tiled as default
* ensure ck_decoder does not dispatch
* Add disable_on_rocm on some test scripts
* Update to test_mem_eff_attention.py
* apply isort
* apply black
* fix flake8 suggestions
* add license headers and reapply black
* Tiny update to rocm_ci.yml
* Add conditional compiling for cuda-depending codes in ROCM
* Update to benchmark scripts
* Rename the one script file
* Revert "Add conditional compiling for cuda-depending codes in ROCM". This reverts commit 12fb41c2460909285102426ca9ab52162725d64b.
* Update to scripts
* Change and add readme for tests and benchmarks
* Remove the stuff for supporting old ck
* Remove old composable_kernel from submodule list
* Remove folder third_party/composable_kernel
* Rename the folder
* Remove unused script file
* apply black
* pacify mypy
* fix clang-format
* reapply black
* fix lints
* make test_splitk_reference run on cpu
* add ck modules to docs
* try fixing nvidia build by re-including sparse24 cpp folder into extension sources
* update cutlass to upstream commit
* update flash-attention to upstream commit
* simplify setup.py
* remove duplicate run_batched_infer_causalmask_attnbias_dispatched
* add hip version and pytorch hip arch list to xformers build info
* fix build
* patch around the unhappy path in get_hip_version
* skip test_grad_checkpointing for triton_splitk since it doesn't have bwop
* re-enable test_mqa_forward since ck tiled is the current implementation
* make skip test_wrong_alignment more generic
* reapply black
* simplify test_decoder
* put python version check inside triton_splitk op
* fix logic
* cleanup python3.9 checks in tests
* cleanup test_attentions
* cleanup test_checkpoint as test running on cpu does not depend on gpu platform
* fix lints
* try fixing win build by conditional import of triton in triton op
* re-enable test_triton_layernorm as it passes
* re-enable test_triton_blocksparse as it passes
* cleanup test_sparse_tensors
* cleanup test_custom_ops
* reapply black
* cleanup test_core_attention
* benchmark ck ops on rocm only
* fix mypy
* fix lint: black
* fix lints: mypy
* Rename HDim/headdim to MaxK/maxk
* Move some header files to ck examples for later reuse
* Replace the qs_ks_vs pipeline with the qr_ks_vs pipeline when HeadDim is 256, for better performance
* rm test_ck_7
* fix lints
* unskip test_unsupported_alignment
* move out test_splitk_reference
* add license header to file created in prev commit
* roll back fmha/common.py ...
so users are forced to provide rank-5 inputs for mqa/gqa * fix lint * remove unused ref_attention_mqa * resolve error in triton_splitk on rocm > Triton does not support if expressions (ternary operators) with dynamic conditions, use if statements instead * disable partial attention tests on rocm --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: Grigory Sizov --- .github/workflows/rocm_ci.yml | 71 + .gitignore | 10 + .gitmodules | 4 + docs/source/components/ops.rst | 14 +- requirements-test.txt | 2 +- setup.py | 82 ++ tests/readme_test_on_rocm.txt | 13 + tests/test_attentions.py | 7 + tests/test_checkpoint.py | 15 +- tests/test_core_attention.py | 5 +- tests/test_custom_ops.py | 9 +- tests/test_mem_eff_attention.py | 110 +- tests/test_sparse_tensors.py | 4 +- tests/test_splitk_reference.py | 223 ++++ tests/test_swiglu.py | 5 + third_party/composable_kernel_tiled | 1 + xformers/_cpp_lib.py | 4 + xformers/benchmarks/LRA/run_tasks.py | 16 +- .../benchmarks/benchmark_attn_decoding.py | 46 +- .../benchmark_blocksparse_transformers.py | 4 +- xformers/benchmarks/benchmark_core.py | 11 +- xformers/benchmarks/benchmark_indexing.py | 2 +- .../benchmarks/benchmark_mem_eff_attention.py | 19 +- .../benchmark_mem_eff_attention_mqa.py | 262 ++++ .../benchmark_mem_eff_attn_decoder.py | 8 +- .../benchmarks/benchmark_nystrom_utils.py | 4 +- xformers/benchmarks/benchmark_sddmm.py | 15 +- xformers/benchmarks/benchmark_swiglu.py | 9 +- xformers/benchmarks/benchmark_transformer.py | 7 +- .../benchmarks/readme_benchmark_on_rocm.txt | 17 + xformers/benchmarks/utils.py | 27 +- xformers/csrc/attention/attention.cpp | 40 +- .../csrc/attention/hip_fmha/CMakeLists.txt | 120 ++ .../hip_fmha/attention_forward_decoder.cpp | 333 +++++ .../attention_forward_generic_ck_tiled.cpp | 421 ++++++ .../hip_fmha/attention_forward_splitk.cpp | 1165 +++++++++++++++++ .../hip_fmha/ck_attention_forward_decoder.h | 493 +++++++ .../ck_attention_forward_decoder_splitk.h | 693 ++++++++++ .../csrc/attention/hip_fmha/ck_fmha_test.cpp | 27 + .../csrc/attention/hip_fmha/ck_fmha_util.h | 158 +++ .../attention/hip_fmha/ck_tiled_bool_switch.h | 9 + .../hip_fmha/ck_tiled_fmha_batched_forward.h | 232 ++++ .../ck_tiled_fmha_batched_forward_bp16.cpp | 71 + .../ck_tiled_fmha_batched_forward_fp16.cpp | 71 + .../hip_fmha/ck_tiled_fmha_batched_infer.h | 232 ++++ .../ck_tiled_fmha_batched_infer_bp16.cpp | 71 + .../ck_tiled_fmha_batched_infer_fp16.cpp | 71 + .../hip_fmha/ck_tiled_fmha_definitions.h | 115 ++ .../hip_fmha/ck_tiled_fmha_grouped_forward.h | 199 +++ .../ck_tiled_fmha_grouped_forward_bp16.cpp | 71 + .../ck_tiled_fmha_grouped_forward_fp16.cpp | 71 + .../hip_fmha/ck_tiled_fmha_grouped_infer.h | 198 +++ .../ck_tiled_fmha_grouped_infer_bp16.cpp | 71 + .../ck_tiled_fmha_grouped_infer_fp16.cpp | 71 + .../attention/hip_fmha/ck_tiled_fmha_params.h | 215 +++ .../hip_fmha/ck_tiled_headdim_switch.h | 28 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + 
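The forward ops added by this patch are reachable through the regular xformers fmha API, either picked by the dispatcher on ROCm or requested explicitly via the `op=` argument. Below is a minimal usage sketch, assuming a ROCm build of xformers with these extensions compiled in; the shapes and dtype are illustrative only and are not taken from the patch.

# Usage sketch (illustrative, not part of the patch): run the CK-tiled forward op on ROCm.
import torch
from xformers.ops import fmha

B, M, H, K = 2, 1024, 8, 64  # batch, sequence length, heads, head dim (ck-tiled supports K <= 256)
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Explicitly select the CK forward op introduced by this patch; on a ROCm build
# the fmha dispatcher can also choose it automatically.
out = fmha.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)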
...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + 
..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...bp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_128.cpp | 15 + ...p16_no_causalmask_no_attnbias_maxk_256.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_32.cpp | 15 + ...fp16_no_causalmask_no_attnbias_maxk_64.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_128.cpp | 15 + ...6_no_causalmask_with_attnbias_maxk_256.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_32.cpp | 15 + ...16_no_causalmask_with_attnbias_maxk_64.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_128.cpp | 15 + ...6_with_causalmask_no_attnbias_maxk_256.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_32.cpp | 15 + ...16_with_causalmask_no_attnbias_maxk_64.cpp | 15 + ...with_causalmask_with_attnbias_maxk_128.cpp | 15 + ...with_causalmask_with_attnbias_maxk_256.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_32.cpp | 15 + ..._with_causalmask_with_attnbias_maxk_64.cpp | 15 + xformers/info.py | 1 + xformers/ops/__init__.py | 4 + xformers/ops/common.py | 6 +- xformers/ops/fmha/__init__.py | 25 +- xformers/ops/fmha/attn_bias.py | 3 +- xformers/ops/fmha/ck.py | 514 ++++++++ xformers/ops/fmha/ck_decoder.py | 139 ++ xformers/ops/fmha/ck_splitk.py | 208 +++ xformers/ops/fmha/common.py | 8 +- xformers/ops/fmha/dispatch.py | 50 +- xformers/ops/fmha/triton.py | 703 ++++++++-- xformers/ops/fmha/triton_splitk.py | 23 +- 196 files changed, 9652 insertions(+), 224 deletions(-) create mode 100644 .github/workflows/rocm_ci.yml create mode 100644 tests/readme_test_on_rocm.txt create mode 100644 tests/test_splitk_reference.py create mode 160000 third_party/composable_kernel_tiled create mode 100644 xformers/benchmarks/benchmark_mem_eff_attention_mqa.py create mode 100644 xformers/benchmarks/readme_benchmark_on_rocm.txt create mode 100644 
xformers/csrc/attention/hip_fmha/CMakeLists.txt create mode 100644 xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp create mode 100644 xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp create mode 100644 xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_fmha_test.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_fmha_util.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_bool_switch.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_definitions.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h create mode 100644 xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 
xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp create 
mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 
xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp create 
mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp create mode 100644 
xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp create mode 100644 xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp create mode 100644 xformers/ops/fmha/ck.py create mode 100644 xformers/ops/fmha/ck_decoder.py create mode 100644 xformers/ops/fmha/ck_splitk.py diff --git a/.github/workflows/rocm_ci.yml b/.github/workflows/rocm_ci.yml new file mode 100644 index 0000000000..f2593d53af --- /dev/null +++ b/.github/workflows/rocm_ci.yml @@ -0,0 +1,71 @@ +name: ROCM_CI + +on: + pull_request: + types: [labeled, synchronize, reopened] + +jobs: + build: + if: contains(github.event.label.name, 'rocm') + runs-on: rocm + + steps: + - uses: actions/checkout@v2 + - name: Get CPU info on Ubuntu + if: contains(runner.os, 'linux') + run: | + cat /proc/cpuinfo + - name: Get env vars + run: | + echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW + echo HOME = $HOME + echo PWD = $PWD + echo GITHUB_ACTION = $GITHUB_ACTION + echo GITHUB_ACTIONS = $GITHUB_ACTIONS + echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY + echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME + echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH + echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE + echo GITHUB_SHA = $GITHUB_SHA + echo GITHUB_REF = $GITHUB_REF + + export GIT_BRANCH=${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}} + echo GIT_BRANCH = $GIT_BRANCH + + export ROCM_PATH=/opt/rocm + echo ROCM_PATH = $ROCM_PATH + + export MAX_JOBS=64 + echo MAX_JOBS = $MAX_JOBS + + hipcc --version + rocm-smi + rocminfo | grep "gfx" + + - name: Build XFormers + run: | + git clone --recursive -b $GIT_BRANCH $GITHUB_REPOSITORY + docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G -v $PWD/xformers:/xformers rocm/pytorch-nightly:latest + + pip3 install --upgrade pip + pip3 uninstall -y xformers + MAX_JOBS=$MAX_JOBS pip3 install -e /xformers --verbose + pip3 install scipy==1.10 + + python3 -c "import 
torch; print(torch.__version__)" + python3 -m xformers.info + + - name: Run python tests + run: | + pytest -rpfs /xformers/tests/test_mem_eff_attention.py | tee test_mem_eff_attention.log + + - name: Archive logs + uses: actions/upload-artifact@v3 + with: + name: test results + path: test_mem_eff_attention_ck.log + + - name: Process test results + run: | + echo "Processing test results TBD" + diff --git a/.gitignore b/.gitignore index 38b453363b..8c6455c1b7 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,13 @@ outputs xformers/_flash_attn xformers/version.py xformers/cpp_lib.json + +## temporary files +xformers/csrc/attention/hip_fmha/*.cu +xformers/csrc/attention/hip_fmha/*.hip +xformers/csrc/attention/hip_fmha/*_hip.h +xformers/csrc/attention/hip_fmha/instances/*.cu +xformers/csrc/attention/hip_fmha/instances/*.hip +xformers/csrc/attention/hip_fmha/instances_tiled/*.cu +xformers/csrc/attention/hip_fmha/instances_tiled/*.hip + diff --git a/.gitmodules b/.gitmodules index b15bd78f63..6358114101 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,7 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +[submodule "third_party/composable_kernel_tiled"] + path = third_party/composable_kernel_tiled + url = https://github.com/ROCm/composable_kernel.git + branch = ck_tile/dev diff --git a/docs/source/components/ops.rst b/docs/source/components/ops.rst index 5f98fdcb52..09dc0d25cd 100644 --- a/docs/source/components/ops.rst +++ b/docs/source/components/ops.rst @@ -22,13 +22,25 @@ Available implementations :member-order: bysource .. automodule:: xformers.ops.fmha.triton - :members: FwOp, BwOp + :members: FwOp :member-order: bysource .. automodule:: xformers.ops.fmha.small_k :members: FwOp, BwOp :member-order: bysource +.. automodule:: xformers.ops.fmha.ck + :members: FwOp, BwOp + :member-order: bysource + +.. automodule:: xformers.ops.fmha.ck_decoder + :members: FwOp + :member-order: bysource + +.. 
automodule:: xformers.ops.fmha.ck_splitk + :members: FwOp + :member-order: bysource + Attention biases ~~~~~~~~~~~~~~~~~~~~ diff --git a/requirements-test.txt b/requirements-test.txt index 0f460733f7..50bbd93a41 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -25,7 +25,7 @@ hydra-core >= 1.1 # Dependency for Mixture of Experts fairscale >= 0.4.5 -scipy +scipy >= 1.7 # Dependency for fused layers, optional cmake diff --git a/setup.py b/setup.py index 5ce78b9f34..ed5fa52b61 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,23 @@ def get_cuda_version(cuda_dir) -> int: return bare_metal_major * 100 + bare_metal_minor +def get_hip_version(rocm_dir) -> str: + hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") + try: + raw_output = subprocess.check_output( + [hipcc_bin, "--version"], universal_newlines=True + ) + except Exception as e: + print( + f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" + ) + return None + for line in raw_output.split("\n"): + if "HIP version" in line: + return line.split()[-1] + return None + + def get_flash_attention_extensions(cuda_version: int, extra_compile_args): # XXX: Not supported on windows for cuda<12 # https://github.com/Dao-AILab/flash-attention/issues/345 @@ -223,11 +240,27 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args): ] +def rename_cpp_cu(cpp_files): + for entry in cpp_files: + shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") + + def get_extensions(): extensions_dir = os.path.join("xformers", "csrc") sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True) source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True) + source_hip = glob.glob( + os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"), + recursive=True, + ) + source_hip_generated = glob.glob( + os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cu"), + recursive=True, + ) + # avoid the temporary .cu files generated under xformers/csrc/attention/hip_fmha + source_cuda = list(set(source_cuda) - set(source_hip_generated)) + sources = list(set(sources) - set(source_hip)) sputnik_dir = os.path.join(this_dir, "third_party", "sputnik") cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include") @@ -253,6 +286,7 @@ def get_extensions(): include_dirs = [extensions_dir] ext_modules = [] cuda_version = None + hip_version = None flash_version = "0.0.0" if ( @@ -294,6 +328,7 @@ def get_extensions(): flash_extensions = get_flash_attention_extensions( cuda_version=cuda_version, extra_compile_args=extra_compile_args ) + if flash_extensions: flash_version = get_flash_version() ext_modules += flash_extensions @@ -306,6 +341,51 @@ def get_extensions(): "--ptxas-options=-O2", "--ptxas-options=-allow-expensive-optimizations=true", ] + elif torch.cuda.is_available() and torch.version.hip: + rename_cpp_cu(source_hip) + rocm_home = os.getenv("ROCM_PATH") + hip_version = get_hip_version(rocm_home) + + source_hip_cu = [] + for ff in source_hip: + source_hip_cu += [ff.replace(".cpp", ".cu")] + + extension = CUDAExtension + sources += source_hip_cu + include_dirs += [ + Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha" + ] + + include_dirs += [ + Path(this_dir) + / "third_party" + / "composable_kernel_tiled" + / "example" + / "91_tile_program" + / "xformers_fmha" + ] + + include_dirs += [ + Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include" + ] + + generator_flag = [] + + cc_flag = 
["-DBUILD_PYTHON_PACKAGE"] + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": [ + "-O3", + "-std=c++17", + f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-DCK_FMHA_FWD_FAST_EXP2=1", + "-fgpu-flush-denormals-to-zero", + ] + + generator_flag + + cc_flag, + } ext_modules.append( extension( @@ -320,6 +400,7 @@ def get_extensions(): return ext_modules, { "version": { "cuda": cuda_version, + "hip": hip_version, "torch": torch.__version__, "python": platform.python_version(), "flash": flash_version, @@ -328,6 +409,7 @@ def get_extensions(): k: os.environ.get(k) for k in [ "TORCH_CUDA_ARCH_LIST", + "PYTORCH_ROCM_ARCH", "XFORMERS_BUILD_TYPE", "XFORMERS_ENABLE_DEBUG_ASSERTIONS", "NVCC_FLAGS", diff --git a/tests/readme_test_on_rocm.txt b/tests/readme_test_on_rocm.txt new file mode 100644 index 0000000000..c21fd0d587 --- /dev/null +++ b/tests/readme_test_on_rocm.txt @@ -0,0 +1,13 @@ + + 1. #> pip install -e ./ + + 2. verify testing for generic fmha inference on ROCM + + #> pytest tests/test_mem_eff_attention.py::test_forward + + 3. verify testing for decoder fmha inference on ROCM + + #> pytest tests/test_mem_eff_attention.py::test_decoder + #> pytest tests/test_mem_eff_attention.py::test_splitk_decoder + + diff --git a/tests/test_attentions.py b/tests/test_attentions.py index cf70bbea74..2bdbb2d1ff 100644 --- a/tests/test_attentions.py +++ b/tests/test_attentions.py @@ -107,6 +107,13 @@ def test_order_invariance( causal: bool, device: torch.device, ): + if ( + torch.version.hip + and device == torch.device("cuda") + and attention_name == "local" + ): + # Backend calls into Sputnik library which isn't built on ROCm + device = torch.device("cpu") torch.manual_seed(42) torch.cuda.manual_seed_all(42) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d648bc8d39..d3a831ce48 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -111,7 +111,11 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode): "op", [ xformers.ops.MemoryEfficientAttentionFlashAttentionOp, - xformers.ops.MemoryEfficientAttentionCutlassOp, + ( + xformers.ops.MemoryEfficientAttentionCutlassOp + if torch.version.cuda + else xformers.ops.MemoryEfficientAttentionCkOp + ), ], ) def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast, op): @@ -121,6 +125,15 @@ def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast, ): pytest.skip("skipping operator not supported in this arch") + if ( + op is xformers.ops.MemoryEfficientAttentionFlashAttentionOp + and torch.version.hip + ): + pytest.skip("FlashAttentionOp is not supported on ROCM!") + + if op is xformers.ops.MemoryEfficientAttentionCkOp: + pytest.skip("Gradience is currently not supported by ck-tiled!") + class Attn(nn.Module): def forward(self, x): out = xformers.ops.memory_efficient_attention(x, x, x, op=op) diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py index 7c3673bf1e..9fd7464930 100644 --- a/tests/test_core_attention.py +++ b/tests/test_core_attention.py @@ -31,7 +31,9 @@ def fn_and_catch_oor(*args, **kwargs): return fn_and_catch_oor -_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +_devices = ( + ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"] +) def test_core_attention(): @@ -144,6 +146,7 @@ def test_amp_attention_sparsecs(device): @pytest.mark.skipif( not _is_blocksparse_available, 
reason="Blocksparse is not available" ) +@pytest.mark.skipif(not torch.version.cuda, reason="Sparse ops not supported on ROCm") @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("data_type", [torch.float16, torch.float32]) @catch_oor diff --git a/tests/test_custom_ops.py b/tests/test_custom_ops.py index bef8b41021..7e8a78593e 100644 --- a/tests/test_custom_ops.py +++ b/tests/test_custom_ops.py @@ -16,8 +16,13 @@ _sparse_bmm, ) -cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +cuda_only = pytest.mark.skipif( + not torch.cuda.is_available() or not torch.version.cuda, reason="requires CUDA" +) + +_devices = ( + ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"] +) def _baseline_matmul_with_sparse_mask( diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 8cf38eb86c..0a06217a77 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -26,6 +26,13 @@ torch.backends.cuda.matmul.allow_tf32 = False cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +rocm_only = pytest.mark.skipif( + not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM" +) +disable_on_rocm = pytest.mark.skipif( + not not torch.version.hip, reason="could not be done on ROCM" +) + compute_capability = (0, 0) if torch.cuda.is_available(): compute_capability = torch.cuda.get_device_capability("cuda") @@ -460,6 +467,7 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs) k, kv, ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv + if packed and not (k == kv and q_len == kv_len): pytest.skip( f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`" @@ -565,6 +573,10 @@ def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): k, kv, ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv + + if op is fmha.ck.FwOp: + pytest.skip("logsumexp is not yet supported by ck-tiled fmha!") + query, key, value, attn_bias = create_tensors( *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK" ) @@ -967,6 +979,7 @@ def test_dropout_backward_cutlass(dt, q_len, kv_len, batch_size, k, p): @cuda_only +@disable_on_rocm @pytest.mark.parametrize("k_len", [32]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("kv_len", [3 * 32]) @@ -1063,6 +1076,7 @@ def test_cuda_streams( ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv if device != "cuda": pytest.skip("Not CUDA") + bias_type = None opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = [ op, @@ -1202,6 +1216,13 @@ def test_grad_checkpointing( k, kv, ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv + if op is fmha.triton.FwOp: + pytest.skip("Triton Flash Attention 2 doesn't support backward pass yet") + if op is fmha.triton_splitk.FwOp: + pytest.skip("Triton Flash Decoding doesn't support backward pass yet") + if op is fmha.ck.FwOp: + pytest.skip("ck-tiled FMHA doesn't supported backward pass yet") + bias_type = None opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = ( op, @@ -1273,6 +1294,7 @@ def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]): q = torch.empty([1, 1, 32, 4], device="cuda", dtype=torch.float16).permute( 0, 3, 1, 2 ) + try: fmha.memory_efficient_attention(q, q, q, op=(op, None)) except ValueError as e: @@ -1288,6 +1310,7 @@ def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]): ) def test_unsupported_stride_alignment(op: 
Type[fmha.AttentionFwOpBase]): q = torch.empty([1, 2, 1, 33], device="cuda", dtype=torch.float16)[:, :, :, :32] + try: fmha.memory_efficient_attention(q, q, q, op=(op, None)) except ValueError as e: @@ -1547,7 +1570,7 @@ def _kv_heads_label(kv_heads: Optional[int]) -> str: @pytest.mark.parametrize( "op", [ - fmha.decoder.FwOp, + fmha.decoder.FwOp if torch.version.cuda else fmha.ck_decoder.FwOp, ], ) @pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label) @@ -1563,6 +1586,7 @@ def test_decoder( dtype: str, dequant: bool = False, num_queries: int = 1, + d: int = 128, ) -> None: # kv_heads = 1: multiquery # kv_heads = None: neither MQA nor GQA @@ -1575,7 +1599,6 @@ def test_decoder( raise pytest.skip("dequant needs triton updates") dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype] torch.manual_seed(1) - d = 128 if kv_heads is not None and kv_heads > 1: k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d) q_shape: Tuple[int, ...] = ( @@ -1608,6 +1631,9 @@ def test_decoder( k = k[..., :1, :].expand(k_shape) v = v[..., :1, :].expand(k_shape) + if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)): + pytest.skip("; ".join(skip_reasons)) + attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=[num_queries] * bsz, kv_seqlen=k_seqlen, @@ -1632,14 +1658,13 @@ def dequant_cache(x): k = dequant_cache(k) v = dequant_cache(v) - cutlass_output = fmha.memory_efficient_attention_forward( - q, k, v, attn_bias, op=fmha.cutlass.FwOp - ) + ref_output = ref_attention(q, k, v, attn_bias) + assert_allclose( - decoder_output, - cutlass_output, - atol=fmha.cutlass.FwOp.ERROR_ATOL[dtype_] * 4, - rtol=fmha.cutlass.FwOp.ERROR_RTOL[dtype_], + decoder_output.to(ref_output.dtype), + ref_output, + atol=op.ERROR_ATOL[dtype_] * 4, + rtol=op.ERROR_RTOL[dtype_], ) @@ -1683,6 +1708,36 @@ def test_triton_splitk_decoder( ) +@rocm_only +@pytest.mark.parametrize( + "op", [fmha.ck_splitk.FwOp_S1, fmha.ck_splitk.FwOp_S2, fmha.ck_splitk.FwOp_S4] +) +@pytest.mark.parametrize("dtype", ["f32"]) +@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label) +@pytest.mark.parametrize("n_heads", [16]) +@pytest.mark.parametrize("d", [128, 256]) +@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1), (32, 1), (4096, 8)]) +def test_ck_splitk_decoder( + op, + kv_heads: Optional[int], + n_heads: int, + padding: int, + bsz: int, + dtype: str, + d: int, +) -> None: + # no quantized impl compared to cuda + test_decoder( + op, + kv_heads=kv_heads, + n_heads=n_heads, + padding=padding, + bsz=bsz, + dtype=dtype, + d=d, + ) + + @sm80_or_better_only @pytest.mark.parametrize( "op", @@ -1736,6 +1791,9 @@ def test_attn_bias_blockdiag_doc() -> None: from xformers.ops import fmha + if torch.version.hip: + pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") + K = 16 dtype = torch.float16 device = "cuda" @@ -1788,6 +1846,7 @@ def test_f16_biasf32(self) -> None: with pytest.raises((ValueError, RuntimeError)): fmha.memory_efficient_attention(q, k, v, attn_bias=bias) + @disable_on_rocm def test_f32_biasf16(self) -> None: q, k, v, bias = self.create_tensors(torch.float32) fmha.memory_efficient_attention(q, k, v, attn_bias=bias) @@ -1797,7 +1856,12 @@ def test_f32_biasf16(self) -> None: @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) def test_wrong_alignment(self, dtype) -> None: - op = fmha.cutlass.FwOp + op = fmha.cutlass.FwOp if torch.version.cuda else fmha.ck.FwOp + if dtype not in 
op.SUPPORTED_DTYPES: + pytest.skip( + f"{dtype=} is not supported by {op.__module__}.{op.__qualname__}" + ) + q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=5) try: fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=(op, None)) @@ -1849,6 +1913,7 @@ def test_permuted_attn_bias(self) -> None: @cuda_only +@disable_on_rocm @pytest.mark.parametrize("dtype_str", ["f32", "f16", "bf16"]) @pytest.mark.parametrize( "sm_shmem", @@ -2004,8 +2069,10 @@ def test_forward_gqa_one_group(opFW): @sm80_or_better_only +@disable_on_rocm def test_flash_gqa_wrong_strides() -> None: op = (fmha.flash.FwOp, None) + device = "cuda" B, Mq, Mkv, G, H, K = 3, 1, 512, 2, 8, 128 q = torch.empty((B, Mq, G, H, K), dtype=torch.float16, device=device) @@ -2043,6 +2110,7 @@ def _dispatches_to_flash_decoding(q, kv): ) +@disable_on_rocm def test_dispatch_decoding_bmhk() -> None: assert not _dispatches_to_splitK( torch.empty([1, 8, 1, 128]), torch.empty([1, 2048, 1, 128]) @@ -2065,6 +2133,7 @@ def test_dispatch_decoding_bmhk() -> None: ), "Should not use SplitK if B is big" +@disable_on_rocm def test_dispatch_decoding_bmghk() -> None: assert not _dispatches_to_splitK( torch.empty([1, 8, 1, 1, 128]), torch.empty([1, 2048, 1, 1, 128]) @@ -2132,7 +2201,9 @@ def test_forward_splitk( @cuda_only @pytest.mark.parametrize( - "op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp], ids=lambda op: op.NAME + "op", + [fmha.triton_splitk.FwOp, fmha.flash.FwOp, fmha.ck.FwOp], + ids=lambda op: op.NAME, ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=str) @pytest.mark.parametrize( @@ -2150,8 +2221,8 @@ def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K): k = k.expand(-1, -1, H, -1) v = v.expand(-1, -1, H, -1) - if not op.supports(fmha.Inputs(q, k, v)): - pytest.skip("not supported") + if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)): + pytest.skip("; ".join(skip_reasons)) out = fmha.memory_efficient_attention_forward(q, k, v, op=op) ref = ref_attention(q, k, v) assert_allclose( @@ -2172,6 +2243,9 @@ def test_empty_tensors_empty_query( ) opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] + if torch.version.hip: + pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") + query = query[:, :0] query.requires_grad_(True) key.requires_grad_(True) @@ -2194,6 +2268,9 @@ def test_empty_tensors_empty_kv( ) opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] + if torch.version.hip: + pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") + key = key[:, :0] value = value[:, :0] query.requires_grad_(True) @@ -2216,6 +2293,9 @@ def test_empty_tensors_empty_b( ) opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] + if torch.version.hip: + pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") + query, key, value = query[:0], key[:0], value[:0] query.requires_grad_(True) key.requires_grad_(True) @@ -2238,6 +2318,7 @@ def test_local_attn_bias() -> None: @cuda_only +@disable_on_rocm @pytest.mark.parametrize("cc", [60, 70, 80]) @pytest.mark.parametrize("maxK", [32, 64, 128, 256]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @@ -2607,6 +2688,7 @@ def paged_attention_run_inner( torch.testing.assert_close(y_swapped, y_packed) +@disable_on_rocm @sm80_or_better_only @pytest.mark.parametrize( "op", @@ -2667,7 +2749,7 @@ def test_merge_attentions_nobias( assert lse is None -@sm80_or_better_only +@disable_on_rocm @sm80_or_better_only @pytest.mark.parametrize( "dtype,op", diff --git a/tests/test_sparse_tensors.py 
b/tests/test_sparse_tensors.py index 2834987385..641f2ffc70 100644 --- a/tests/test_sparse_tensors.py +++ b/tests/test_sparse_tensors.py @@ -12,7 +12,9 @@ from xformers.sparse import BlockSparseTensor, SparseCSRTensor cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -_devices = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"] +_devices = ( + ["cpu", "cuda:0"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"] +) _tensor_types = [BlockSparseTensor, SparseCSRTensor] diff --git a/tests/test_splitk_reference.py b/tests/test_splitk_reference.py new file mode 100644 index 0000000000..b62379cd4c --- /dev/null +++ b/tests/test_splitk_reference.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import pytest +import torch + +import xformers.ops +from xformers.ops import fmha + +from .test_mem_eff_attention import ref_attention +from .utils import assert_allclose + +torch.backends.cuda.matmul.allow_tf32 = False + + +def ref_attention_splitk_bmhk( + q, k, v, attn_bias, scale=None, split_k=None, dtype=None +) -> torch.Tensor: + assert q.ndim == 4 + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + if isinstance(attn_bias, xformers.ops.AttentionBias): + attn_bias = attn_bias.materialize( + (q.shape[0], q.shape[2], q.shape[1], k.shape[1]), + device=q.device, + dtype=torch.float32, + ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]]) + out = ref_attention_splitk( + T(q), T(k), T(v), attn_bias, scale=scale, split_k=split_k, dtype=dtype + ) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +def ref_attention_splitk( + q, k, v, attn_bias, scale=None, split_k=2, dtype=None +) -> torch.Tensor: + if q.ndim == 5: + + def attn_bias_group(group: int): + if isinstance(attn_bias, torch.Tensor): + return attn_bias[:, group] + if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias): + return fmha.attn_bias.LowerTriangularMaskWithTensorBias( + attn_bias._bias[:, group] + ) + return attn_bias + + return torch.stack( + [ + ref_attention_splitk_bmhk( + q[:, :, g], + k[:, :, g], + v[:, :, g], + attn_bias=attn_bias_group(g), + split_k=split_k, + dtype=dtype, + ) + for g in range(q.shape[2]) + ], + dim=2, + ) + + if q.ndim == 4: + return ref_attention_splitk_bmhk( + q, k, v, attn_bias=attn_bias, split_k=split_k, dtype=dtype + ) + assert q.ndim == 3 + if dtype is None: + dtype = torch.float32 + q = q.to(dtype=dtype) + k = k.to(dtype=dtype) + v = v.to(dtype=dtype) + + if scale is None: + scale = q.shape[-1] ** -0.5 + assert not q.isnan().any() + q = q * scale + assert not q.isnan().any() + + if attn_bias is not None: + if isinstance(attn_bias, xformers.ops.AttentionBias): + # Always create in B,H,Mq,Mk format + attn_bias_tensor = attn_bias.materialize( + (q.shape[0], 1, q.shape[1], k.shape[1]), + device=q.device, + dtype=torch.float32, + ) + else: + attn_bias_tensor = attn_bias + if attn_bias_tensor.ndim == 4: + assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1] + attn_bias_tensor = attn_bias_tensor.reshape( + [-1, *attn_bias_tensor.shape[2:]] + ) + + split_size = k.size(-2) // split_k + split_config = {"dim": -2, "split_size_or_sections": split_size} + k_split = torch.split(k, 
**split_config) + v_split = torch.split(v, **split_config) + attn_bias_split = torch.split( + attn_bias_tensor, dim=-1, split_size_or_sections=split_size + ) + + def compute_attention_split(q_whole, k_slice, v_slice, attn_bias_slice): + p_slice = q_whole @ k_slice.transpose(-2, -1) + p_slice += attn_bias_slice + row_max = torch.max(p_slice, dim=-1, keepdim=True).values + p_slice_scaled = p_slice - row_max + p_slice_scaled[p_slice_scaled.isnan()] = float("-inf") + s = torch.exp(p_slice_scaled) + row_sumexp = torch.sum(s, dim=-1, keepdim=True) + attn_slice = s @ v_slice + return { + "attn_slice": attn_slice, + "row_max": row_max, + "row_sumexp": row_sumexp, + } + + splits = list(zip(k_split, v_split, attn_bias_split)) + + slices = list(map(lambda s: compute_attention_split(q, s[0], s[1], s[2]), splits)) + out = torch.zeros_like(q) + + # reduce out over split-k slices + + global_max = torch.zeros_like(slices[0]["row_max"]).fill_(float("-inf")) + global_sumexp = torch.zeros_like(slices[0]["row_sumexp"]) + + for s in slices: + local_out = s["attn_slice"] + local_max = s["row_max"] + local_sumexp = s["row_sumexp"] + + log_alpha = -torch.abs(local_max - global_max) + alpha = torch.exp(log_alpha) + alpha.nan_to_num_(1.0) + + pick_new = local_max < global_max + new_coef = torch.where(pick_new, alpha, 1.0) + curr_coef = torch.where(pick_new, 1.0, alpha) + + out = out * curr_coef + local_out * new_coef + global_sumexp = global_sumexp * curr_coef + local_sumexp * new_coef + global_max = torch.max(local_max, global_max) + out /= global_sumexp + return out + + +def _kv_heads_label(kv_heads: Optional[int]) -> str: + if kv_heads is None: + return "" + if kv_heads == 1: + return "mq" + return f"gqa{kv_heads}" + + +@pytest.mark.parametrize("dtype", ["f32"]) +@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label) +@pytest.mark.parametrize("n_heads", [16]) +@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1)]) +@pytest.mark.parametrize("split_k", [1, 2, 4]) +@pytest.mark.parametrize("device", ["cpu"]) +def test_splitk_reference( + kv_heads: int, + n_heads: int, + padding: int, + bsz: int, + dtype: str, + device: str, + split_k: int, +): + dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype] + torch.manual_seed(1) + d = 256 + num_queries = 1 + if kv_heads is not None and kv_heads > 1: + k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d) + q_shape: Tuple[int, ...] 
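
For reference, a minimal sketch of the split-k reduction implemented above: each split contributes a partial output together with its row max and row sum-of-exponentials, and the partials can be rescaled to a common max and summed to recover exactly what a single-pass softmax would produce. The helper names below are hypothetical, and the rescaling is written in the standard symmetric form rather than the pick_new/curr_coef form used in ref_attention_splitk; the two are numerically equivalent.

import torch

def merge_splits(out1, max1, sumexp1, out2, max2, sumexp2):
    # rescale both partials to a common row max before summing
    new_max = torch.maximum(max1, max2)
    a1, a2 = torch.exp(max1 - new_max), torch.exp(max2 - new_max)
    return out1 * a1 + out2 * a2, new_max, sumexp1 * a1 + sumexp2 * a2

torch.manual_seed(0)
q = torch.randn(1, 1, 16)  # [B, Mq, K]
k = torch.randn(1, 8, 16)  # [B, Mkv, K]
v = torch.randn(1, 8, 16)
p = (q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1)  # attention logits [B, Mq, Mkv]

def partial_softmax(p_slice, v_slice):
    # per-split partial: unnormalized output, row max, row sum of exponentials
    m = p_slice.max(dim=-1, keepdim=True).values
    s = torch.exp(p_slice - m)
    return s @ v_slice, m, s.sum(dim=-1, keepdim=True)

o1, m1, s1 = partial_softmax(p[..., :4], v[:, :4])
o2, m2, s2 = partial_softmax(p[..., 4:], v[:, 4:])
out, _, sumexp = merge_splits(o1, m1, s1, o2, m2, s2)
ref = torch.softmax(p, dim=-1) @ v
assert torch.allclose(out / sumexp, ref, atol=1e-5)
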
= ( + 1, + bsz * num_queries, + kv_heads, + n_heads, + d, + ) + else: + k_shape = (1, bsz * padding, n_heads, d) + q_shape = (1, bsz * num_queries, n_heads, d) + + k = torch.rand(k_shape, dtype=dtype_, device=device) + k_seqlen = torch.randint(1, padding + 1, (bsz,)).tolist() + v = torch.rand_like(k) + q = torch.rand(q_shape, dtype=dtype_, device=device) + causal_diagonal = torch.tensor( # TODO: make unnecessary + [i - 1 for i in k_seqlen], dtype=torch.int32, device=device + ) + + if kv_heads is not None: + k = k[..., :1, :].expand(k_shape) + v = v[..., :1, :].expand(k_shape) + + attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[1] * bsz, + kv_seqlen=k_seqlen, + causal_diagonal=causal_diagonal, + kv_padding=padding, + ) + ref_out = ref_attention(q, k, v, attn_bias) + splitk_out = ref_attention_splitk(q, k, v, attn_bias, None, split_k=split_k) + assert_allclose( + ref_out, + splitk_out, + atol=fmha.ck.FwOp.ERROR_ATOL[dtype_], + rtol=fmha.ck.FwOp.ERROR_RTOL[dtype_], + ) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index d5e7795361..2dd79152c2 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -29,6 +29,10 @@ torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+" ) +disable_on_rocm = pytest.mark.skipif( + not not torch.version.hip, reason="could not be done on ROCM" +) + def assert_allclose( # The output of the tested function @@ -135,6 +139,7 @@ def create_module_cached(**kwargs) -> xsw.SwiGLU: return xsw.SwiGLU(**kwargs) +@disable_on_rocm @pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"]) @pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops]) @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes]) diff --git a/third_party/composable_kernel_tiled b/third_party/composable_kernel_tiled new file mode 160000 index 0000000000..b344343273 --- /dev/null +++ b/third_party/composable_kernel_tiled @@ -0,0 +1 @@ +Subproject commit b344343273cf6731ba0a47e061629890a8014af5 diff --git a/xformers/_cpp_lib.py b/xformers/_cpp_lib.py index 4eb6fd9814..d5d0117005 100644 --- a/xformers/_cpp_lib.py +++ b/xformers/_cpp_lib.py @@ -27,6 +27,10 @@ class _BuildInfo: def cuda_version(self) -> Optional[int]: return self.metadata["version"]["cuda"] + @property + def hip_version(self) -> Optional[int]: + return self.metadata["version"]["hip"] + @property def torch_version(self) -> str: return self.metadata["version"]["torch"] diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py index e9d1f72843..41c5fbe55e 100644 --- a/xformers/benchmarks/LRA/run_tasks.py +++ b/xformers/benchmarks/LRA/run_tasks.py @@ -53,9 +53,11 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: model = cast( pl.LightningModule, - ModelForSCDual(config[f"{task}"], attention_name) - if task == Task.Retrieval - else ModelForSC(config[f"{task}"], attention_name), + ( + ModelForSCDual(config[f"{task}"], attention_name) + if task == Task.Retrieval + else ModelForSC(config[f"{task}"], attention_name) + ), ) logging.info(model) @@ -252,9 +254,11 @@ def benchmark(args): trainer = pl.Trainer( accelerator="gpu", - strategy=DDPStrategy(find_unused_parameters=args.debug) - if not args.skip_train - else None, + strategy=( + DDPStrategy(find_unused_parameters=args.debug) + if not args.skip_train + else None + ), accumulate_grad_batches=config_training["gradient_accumulation"], callbacks=[progress_bar, checkpoint_callback], detect_anomaly=args.debug, diff --git 
a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py index 2300e233fe..ed457757fd 100644 --- a/xformers/benchmarks/benchmark_attn_decoding.py +++ b/xformers/benchmarks/benchmark_attn_decoding.py @@ -3,8 +3,8 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. - -from typing import Any +import sys +from typing import Any, Dict, Type import torch from torch.utils import benchmark @@ -17,11 +17,9 @@ CASES = [ - dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128) - for i in range(8, 18) -] + [ - dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128) + dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=hkv, K=128) for i in range(8, 18) + for hkv in (1, 2) ] @@ -96,13 +94,28 @@ def __init__( self.v = self.v[:, :, 0] def fw(self) -> None: - xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP) + try: + xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP) + except (RuntimeError, ValueError) as e: + print(f"Runtime error: {e}") + + +class AttentionDecodingCK(AttentionDecodingFlashDecoding): + OP = xops.fmha.ck.FwOp + + +class AttentionDecodingCKDecoder(AttentionDecodingFlashDecoding): + OP = xops.fmha.ck_decoder.FwOp class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding): OP = xops.fmha.triton_splitk.FwOp +class AttentionDecodingCKSplitKV(AttentionDecodingFlashDecoding): + OP = xops.fmha.ck_splitk.FwOp + + class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding): def fw(self) -> None: B, Mq, Mkv, Hq, Hkv, K = self.shapes @@ -114,12 +127,25 @@ def fw(self) -> None: return attn @ v -BENCHMARKS = { +BENCHMARKS: Dict[str, Type[AttentionDecodingFlashDecoding]] = { "pytorch": AttentionDecodingPyTorchRepeat, - "flash-decoding": AttentionDecodingFlashDecoding, - "triton_splitK": AttentionDecodingSplitKV, } +if torch.version.cuda: + BENCHMARKS["flash-decoding"] = AttentionDecodingFlashDecoding + +if torch.version.hip: + BENCHMARKS.update( + { + "ck": AttentionDecodingCK, + "ck-decoder": AttentionDecodingCKDecoder, + "ck_splitK": AttentionDecodingCKSplitKV, + } + ) + + +if (sys.version_info.major, sys.version_info.minor) >= (3, 9): + BENCHMARKS["triton_splitK"] = AttentionDecodingSplitKV try: import flash_attn diff --git a/xformers/benchmarks/benchmark_blocksparse_transformers.py b/xformers/benchmarks/benchmark_blocksparse_transformers.py index f9cb72a15c..3cdd9a3692 100644 --- a/xformers/benchmarks/benchmark_blocksparse_transformers.py +++ b/xformers/benchmarks/benchmark_blocksparse_transformers.py @@ -60,7 +60,7 @@ def get_mask(MaskGenType, config, config_setter=[]): # Get the mask mask_generator = MaskGenType(mask_config) - for (key, value) in config_setter: + for key, value in config_setter: mask_generator.set_config_attr(key, value) if not mask_generator.is_valid_config(): return None @@ -73,7 +73,7 @@ def densify_mask(mask, config): seq_length = config.seq_length block_size = config.block_size dense_mask = torch.zeros(num_heads, seq_length, seq_length) - for (h, i, j) in zip(*mask.nonzero(as_tuple=True)): + for h, i, j in zip(*mask.nonzero(as_tuple=True)): dense_mask[ h, i * block_size : (i + 1) * block_size, diff --git a/xformers/benchmarks/benchmark_core.py b/xformers/benchmarks/benchmark_core.py index 97cdefa09a..2a4d675605 100644 --- a/xformers/benchmarks/benchmark_core.py +++ b/xformers/benchmarks/benchmark_core.py @@ -252,7 +252,10 @@ def bench_bmm(): compare.print() -bench_sddmm() 
-bench_matmul_with_mask() -bench_softmax() -bench_bmm() +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + bench_sddmm() + bench_matmul_with_mask() + bench_softmax() + bench_bmm() diff --git a/xformers/benchmarks/benchmark_indexing.py b/xformers/benchmarks/benchmark_indexing.py index ed1e71001f..353b9dba7d 100644 --- a/xformers/benchmarks/benchmark_indexing.py +++ b/xformers/benchmarks/benchmark_indexing.py @@ -111,7 +111,7 @@ def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None: indices = [] sources = [] - for (B, seqlen) in batches: + for B, seqlen in batches: index = [i for i in range(B)] random.Random(B).shuffle(index) indices.append( diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py index 1f8affe51b..bbeb222648 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attention.py +++ b/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -10,11 +10,11 @@ import torch from torch.utils import benchmark -from utils import benchmark_main_helper import xformers.ops import xformers.ops.fmha as fmha from xformers.attn_bias_utils import create_attn_bias +from xformers.benchmarks.utils import benchmark_main_helper torch.backends.cuda.matmul.allow_tf32 = False @@ -102,12 +102,23 @@ def T(t): ), ] + +class TritonFlashAttentionFwAutotuned(xformers.ops.fmha.triton.FwOp): + AUTOTUNE = True + + OPS = [ (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp), (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp), - # TODO: Triton is not stable: it can trigger Illegal Memory Accesses - # and its performance varies a lot between runs. - # (xformers.ops.fmha.triton.FwOp, xformers.ops.fmha.triton.BwOp), + (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp), + ( + TritonFlashAttentionFwAutotuned, + ( + xformers.ops.fmha.cutlass.BwOp + if torch.version.cuda + else xformers.ops.fmha.ck.BwOp + ), + ), ] diff --git a/xformers/benchmarks/benchmark_mem_eff_attention_mqa.py b/xformers/benchmarks/benchmark_mem_eff_attention_mqa.py new file mode 100644 index 0000000000..4e4c47e380 --- /dev/null +++ b/xformers/benchmarks/benchmark_mem_eff_attention_mqa.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import itertools +import random +from functools import partial + +import torch +from torch.utils import benchmark + +import xformers.ops +import xformers.ops.fmha as fmha +from xformers.attn_bias_utils import create_attn_bias +from xformers.benchmarks.utils import benchmark_main_helper + +torch.backends.cuda.matmul.allow_tf32 = False + + +# this interface assumes the tensor is in BMHK, but q and k/v might has different number of heads +def ref_attention_mqa( + q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None, dtype=None +): + if q.ndim == 4: + B, M, Hq, K = q.shape + _, N, Hkv, Kv = v.shape + nhead_ratio_qk = Hq // Hkv + + def attn_bias_head(head: int): + if isinstance(attn_bias, torch.Tensor): + assert attn_bias.ndim == 4 + _, H, _, _ = attn_bias.shape + assert H == Hq + bias_bghmn = attn_bias.reshape(B, Hkv, nhead_ratio_qk, M, N) + return bias_bghmn[:, :, head] + if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias): + assert attn_bias._bias.ndim == 4 + _, H, _, _ = attn_bias._bias.shape + assert H == Hq + bias_bghmn = attn_bias._bias.reshape(B, Hkv, nhead_ratio_qk, M, N) + + return fmha.attn_bias.LowerTriangularMaskWithTensorBias( + bias_bghmn[:, :, head] + ) + return attn_bias + + q_bmghk = q.reshape((B, M, Hkv, nhead_ratio_qk, K)) + + return torch.stack( + [ + ref_attention_bmhk( + q_bmghk[:, :, :, h], k, v, attn_bias=attn_bias_head(h), dtype=dtype + ) + for h in range(q_bmghk.shape[3]) + ], + dim=3, + ).reshape((B, M, Hq, Kv)) + + assert q.ndim == 3 + if dtype is None: + dtype = torch.float32 + q = q.to(dtype=dtype) + k = k.to(dtype=dtype) + v = v.to(dtype=dtype) + + scale = scale if scale is not None else (q.shape[-1] ** -0.5) + q = q * scale + + attn = q @ k.transpose(-2, -1) + if attn_bias is not None: + if isinstance(attn_bias, xformers.ops.AttentionBias): + # Always create in B,H,Mq,Mk format + attn_bias_tensor = attn_bias.materialize( + (q.shape[0], 1, q.shape[1], k.shape[1]), + device=q.device, + dtype=dtype, + ) + else: + attn_bias_tensor = attn_bias.to(dtype=dtype) + if attn_bias_tensor.ndim == 4: + assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1] + attn_bias_tensor = attn_bias_tensor.reshape( + [-1, *attn_bias_tensor.shape[2:]] + ) + attn = attn + attn_bias_tensor + attn = attn.softmax(-1) + if drop_mask is not None: + attn = attn * (drop_mask / (1 - p)) + return attn @ v + + +# ref_attention_bmhk is completely the same as used by test_forward_ck_tiled.py +def ref_attention_bmhk(q, k, v, attn_bias, scale=None, dtype=None) -> torch.Tensor: + assert q.ndim == 4 + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + if isinstance(attn_bias, xformers.ops.AttentionBias): + attn_bias = attn_bias.materialize( + (q.shape[0], q.shape[2], q.shape[1], k.shape[1]), + device=q.device, + dtype=torch.float32, + ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]]) + out = ref_attention_mqa(T(q), T(k), T(v), attn_bias, scale=scale, dtype=dtype) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +min_run_time = 0.5 +device = torch.device("cuda") + +NUM_THREADS = [1] if device.type == "cuda" else [1, 40] +SHAPES = [ + (1, 512, 512, 64, 8, 128), + (1, 1024, 1024, 64, 8, 128), + (1, 2048, 2048, 64, 8, 128), + (1, 4096, 4096, 64, 8, 128), + (1, 8192, 8192, 64, 8, 128), + (1, 16384, 16384, 64, 8, 128), + (1, 1024, 1024, 64, 8, 64), + (1, 1024, 1024, 8, 1, 64), + (1, 1024, 1024, 4, 4, 64), + # 
*sorted(itertools.product([1, 2], [2048, 4096], [2048, 4096], [4, 8], [1, 2], [128])), + # *sorted( + # itertools.product([16], [128, 512], [512, 1024], [16], [2, 4], [64, 128]) + # ), +] + +OPS = [ + xformers.ops.fmha.ck.FwOp, + xformers.ops.fmha.flash.FwOp, + # TODO: Triton is not stable: it can trigger Illegal Memory Accesses + # and its performance varies a lot between runs. + # + # xformers.ops.fmha.triton.FwOp, +] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape=SHAPES, + num_threads=NUM_THREADS, + dropout_p=[0.0], + attn_bias_type=[type(None)], + dtype=[torch.half, torch.bfloat16], + ) +) + +# Add more cases with some variations +for c in CASES.copy(): + c = c.copy() + c.update( + random.Random(str(c["shape"])).choice( + [ + {"attn_bias_type": torch.Tensor}, + {"attn_bias_type": xformers.ops.LowerTriangularMask}, + ] + ) + ) + CASES.append(c) + + +def create_tensors(shape, dtype, requires_grad=False): + B, M, N, Hq, Hkv, K = shape + q = torch.rand( + [B, M, Hq, K], device=device, dtype=dtype, requires_grad=requires_grad + ) + k = torch.rand( + [B, N, Hkv, K], device=device, dtype=dtype, requires_grad=requires_grad + ) + v = torch.rand( + [B, N, Hkv, K], device=device, dtype=dtype, requires_grad=requires_grad + ) + return q, k, v + + +def mem_eff_attention_fw(shape, num_threads: int, attn_bias_type, dropout_p, dtype): + B, M, N, Hq, Hkv, K = shape + nhead_ratio_qk = Hq // Hkv + q, k, v = create_tensors(shape, dtype) + bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=Hq, + num_heads_groups=nhead_ratio_qk, + q_len=M, + kv_len=N, + device=device, + dtype=dtype, + requires_grad=False, + fmt="BMHK", + op=fmha.ck.FwOp, # only required as a refer op by create_attn_bias + ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + + dtype_str = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float: "f32", + }[dtype] + sub_label = ( + f"{dtype_str} {B}-{M}-{N}-{Hq}-{Hkv}-{K}, p={dropout_p}, " + f"BiasT={attn_bias_type.__name__}" + ) + + has_run = False + for fw_op in OPS: + if not fw_op.supports(inp): + continue + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": dropout_p, + "fn": partial( + xformers.ops.memory_efficient_attention_forward, op=fw_op + ), + }, + label=f"attention (attn_bias={attn_bias_type})", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + has_run = True + + if not has_run: + return + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": dropout_p, + "fn": ref_attention_mqa, + }, + label=f"attention (attn_bias={attn_bias_type})", + description="eager", + sub_label=sub_label, + num_threads=num_threads, + ) + + +benchmark_main_helper(mem_eff_attention_fw, CASES, min_run_time=min_run_time) diff --git a/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py b/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py index 7f1b4ceaa4..67698c87c4 100644 --- a/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py +++ b/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py @@ -59,8 +59,12 @@ def T(t): NUM_THREADS = [1] if device.type == "cuda" else [1, 40] OPS = [ - xformers.ops.fmha.cutlass.FwOp, - xformers.ops.fmha.decoder.FwOp, + xformers.ops.fmha.cutlass.FwOp if torch.version.cuda 
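
For reference, a minimal standalone sketch of the head grouping performed by ref_attention_mqa above: the Hq query heads are reshaped into Hkv groups of Hq // Hkv heads, and each group attends against the single key/value head of its group. The shapes and names here are chosen for illustration only.

import torch

B, M, N, Hq, Hkv, K = 2, 4, 6, 8, 2, 16
nhead_ratio_qk = Hq // Hkv  # query heads per key/value head

q = torch.randn(B, M, Hq, K)
k = torch.randn(B, N, Hkv, K)
v = torch.randn(B, N, Hkv, K)

# group query heads: [B, M, Hq, K] -> [B, M, Hkv, Hq // Hkv, K]
q_bmghk = q.reshape(B, M, Hkv, nhead_ratio_qk, K)

outs = []
for g in range(nhead_ratio_qk):
    # each slice q_bmghk[:, :, :, g] has Hkv heads, matching k and v
    attn = torch.einsum("bmhk,bnhk->bhmn", q_bmghk[:, :, :, g] * K**-0.5, k)
    outs.append(torch.einsum("bhmn,bnhk->bmhk", attn.softmax(-1), v))
out = torch.stack(outs, dim=3).reshape(B, M, Hq, K)  # back to [B, M, Hq, K]
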
else xformers.ops.fmha.ck.FwOp, + ( + xformers.ops.fmha.decoder.FwOp + if torch.version.cuda + else xformers.ops.fmha.ck_decoder.FwOp + ), ] KV_SHAPES = [ diff --git a/xformers/benchmarks/benchmark_nystrom_utils.py b/xformers/benchmarks/benchmark_nystrom_utils.py index 6f4b38c846..c85b034568 100644 --- a/xformers/benchmarks/benchmark_nystrom_utils.py +++ b/xformers/benchmarks/benchmark_nystrom_utils.py @@ -93,7 +93,9 @@ def iterative_pinv_analysis( break -if __name__ == "__main__": +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: iterative_pinv_analysis() bench_inverse(iterative_pinv) bench_inverse(torch.linalg.pinv) diff --git a/xformers/benchmarks/benchmark_sddmm.py b/xformers/benchmarks/benchmark_sddmm.py index 693e4a6236..536fc5ef8e 100644 --- a/xformers/benchmarks/benchmark_sddmm.py +++ b/xformers/benchmarks/benchmark_sddmm.py @@ -109,9 +109,12 @@ def bench_sddmm(configs): results = [] -print("Swin Transformer") -results += bench_sddmm(swin_t_config) -print("ViT") -results += bench_sddmm(vit_config) -print("Basic cases") -results += bench_sddmm(basic_config) +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + print("Swin Transformer") + results += bench_sddmm(swin_t_config) + print("ViT") + results += bench_sddmm(vit_config) + print("Basic cases") + results += bench_sddmm(basic_config) diff --git a/xformers/benchmarks/benchmark_swiglu.py b/xformers/benchmarks/benchmark_swiglu.py index ffa413a954..b283673347 100644 --- a/xformers/benchmarks/benchmark_swiglu.py +++ b/xformers/benchmarks/benchmark_swiglu.py @@ -11,9 +11,9 @@ import torch from torch.utils import benchmark -from utils import benchmark_main_helper import xformers.ops.swiglu_op as xsw +from xformers.benchmarks.utils import benchmark_main_helper min_run_time = 0.5 device = torch.device("cuda") @@ -156,5 +156,8 @@ def benchmark_swiglu_bw(shape, dtype, bias: bool): ) -benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time) -benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time) +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time) + benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time) diff --git a/xformers/benchmarks/benchmark_transformer.py b/xformers/benchmarks/benchmark_transformer.py index 3f9930cbb5..4346af9c19 100644 --- a/xformers/benchmarks/benchmark_transformer.py +++ b/xformers/benchmarks/benchmark_transformer.py @@ -15,9 +15,9 @@ from timm.models.vision_transformer import Attention as TimmAttention from timm.models.vision_transformer import Block as TimmBlock from torch.utils import benchmark -from utils import benchmark_main_helper import xformers.ops as xops +from xformers.benchmarks.utils import benchmark_main_helper def replace_module(module: nn.Module, replace_class, factory): @@ -153,4 +153,7 @@ def benchmark_transformer(model_info, dtype) -> Iterator[benchmark.Timer]: ) -benchmark_main_helper(benchmark_transformer, CASES) +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + benchmark_main_helper(benchmark_transformer, CASES) diff --git a/xformers/benchmarks/readme_benchmark_on_rocm.txt b/xformers/benchmarks/readme_benchmark_on_rocm.txt new file mode 100644 index 0000000000..9ae61f5294 --- /dev/null +++ b/xformers/benchmarks/readme_benchmark_on_rocm.txt @@ -0,0 +1,17 @@ + + + 1. #> pip install -e ./ + + 2. 
Benchmark for generic fmha inference on ROCM + + #> python xformers/benchmarks/benchmark_mem_eff_attention.py + + 3. Benchmark for decoder fmha inference on ROCM + + #> python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py + + 4. Other Benchmarks for fmha inference on ROCM + + #> python xformers/benchmarks/benchmark_attn_decoding.py + #> python xformers/benchmarks/benchmark_mem_eff_attention_mqa.py + diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index c3834ad5ab..ef508661ac 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -263,9 +263,9 @@ def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any data.append( ( { - META_ALGORITHM: row["algorithm"] - if row["algorithm"] != "" - else None, + META_ALGORITHM: ( + row["algorithm"] if row["algorithm"] != "" else None + ), }, measurement, ) @@ -282,9 +282,11 @@ def _benchmark_results_to_csv( "label": r.task_spec.label, "num_threads": r.task_spec.num_threads, "algorithm": metadata.get(META_ALGORITHM, ""), - "description": r.task_spec.description - if r.task_spec.description in BASELINE_DESCRIPTIONS - else "", + "description": ( + r.task_spec.description + if r.task_spec.description in BASELINE_DESCRIPTIONS + else "" + ), "runtime_us": int(1000 * 1000 * r.mean), "mem_use_mb": r.mem_use, } @@ -478,6 +480,7 @@ def benchmark_run_and_compare( .replace(" ", "_") .replace("-", "_") .replace(".", "_") + .replace("/", "_") ) except (RuntimeError, AssertionError): # No GPU env = "cpu" @@ -519,7 +522,7 @@ def benchmark_run_and_compare( # pbar.write(f"Skipped (NotImplementedError)") continue except RuntimeError as e: - if "CUDA out of memory" not in str(e): + if not _is_oom_error(e): raise if not quiet: pbar.write("Skipped (OOM)") @@ -564,7 +567,7 @@ def benchmark_run_and_compare( memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin measurement.mem_use = memory except RuntimeError as e: - if "CUDA out of memory" not in str(e): + if not _is_oom_error(e): raise if not quiet: pbar.write("Skipped (OOM)") @@ -573,7 +576,7 @@ def benchmark_run_and_compare( if not quiet: pbar.write(f"{name}: memory used: {memory} MB") except RuntimeError as e: - if "CUDA out of memory" not in str(e): + if not _is_oom_error(e): raise if not quiet: pbar.write("Skipped (OOM)") @@ -615,6 +618,12 @@ def matches_current(r): ) +def _is_oom_error(e): + return isinstance( + e, (torch.cuda.OutOfMemoryError, triton.runtime.autotuner.OutOfResources) + ) + + def _fail_if_regressions( results: List[Any], reference: List[Any], atol_s: float, rtol: float ) -> None: diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index 6b025d6763..36a9675e72 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -8,18 +8,48 @@ #include TORCH_LIBRARY_FRAGMENT(xformers, m) { +#if !defined(USE_ROCM) m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); + "xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, " + "bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, bool compute_logsumexp, int custom_mask_type, float? 
scale, Tensor? seqlen_k, int? window_size) -> (Tensor, Tensor, int, int)")); + "xformers::efficient_attention_forward_cutlass(Tensor query, Tensor key, Tensor value, " + "Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, float " + "dropout_p, bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, " + "int? window_size) -> (Tensor, Tensor, int, int)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_forward_decoder(Tensor query, Tensor key, Tensor value, Tensor seq_positions, float scale) -> Tensor")); + "xformers::efficient_attention_forward_decoder(Tensor query, Tensor " + "key, Tensor value, Tensor seq_positions, float scale) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward_small_k(Tensor grad_out, Tensor query, Tensor key, " + "Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, " + "int rng_offset) -> (Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( - "xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, int custom_mask_type, float? scale, int num_splits_key, int? window_size) -> (Tensor, Tensor, Tensor, Tensor)")); + "xformers::efficient_attention_backward_cutlass(Tensor grad_out, Tensor query, Tensor key, " + "Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_q, " + "int max_seqlen_k, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int " + "rng_offset, int custom_mask_type, float? scale, int num_splits_key, int? window_size) -> " + "(Tensor, Tensor, Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::_temp_dropout(Tensor out, float p) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( "xformers::_cutlass_rand_uniform(float p, Tensor out) -> Tensor")); +#endif +#if defined(USE_ROCM) + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_forward_ck(Tensor query, " + "Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, " + "Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, " + "bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size) -> (Tensor, Tensor, int, int)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_forward_decoder_ck(Tensor query, " + "Tensor key, Tensor value, Tensor? seq_positions, float scale) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_forward_decoder_splitk_ck(Tensor query, Tensor key, " + " Tensor value, Tensor? seq_positions, float scale, int split_k) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::efficient_attention_backward_ck(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, Tensor? seqlen_k, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, int custom_mask_type, float? 
scale) -> (Tensor, Tensor, Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::_ck_rand_uniform(float p, Tensor out) -> Tensor")); +#endif } diff --git a/xformers/csrc/attention/hip_fmha/CMakeLists.txt b/xformers/csrc/attention/hip_fmha/CMakeLists.txt new file mode 100644 index 0000000000..2bf65f305b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/CMakeLists.txt @@ -0,0 +1,120 @@ +cmake_minimum_required(VERSION 3.26) + +project(FMHADecoderMain LANGUAGES CXX HIP) + +message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER} (need hipcc)") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_FLAGS "-Wall") +set(CMAKE_CXX_FLAGS_DEBUG "-g -O0") +set(CMAKE_VERBOSE_MAKEFILE on) + +set(py_version 3.9) + +set(exe_name attention_forward_decoder_main) +set(splitk_exe_name attention_forward_splitk_decoder_main) +set(project_root_dir /xformers) +set(xformers_csrc ${project_root_dir}/xformers/csrc) +set(sources ${xformers_csrc}/attention/hip_fmha/attention_forward_decoder.hip) +set(splitk_sources ${xformers_csrc}/attention/hip_fmha/attention_forward_splitk.hip) +set(ck_include ${project_root_dir}/third_party/composable_kernel/include/) +set(torch_include /opt/conda/envs/py_${py_version}/lib/python${py_version}/site-packages/torch/include) + +set_source_files_properties(${sources} ${splitk_sources} PROPERTIES LANGUAGE HIP) +add_executable(${exe_name} ${sources}) +add_executable(${splitk_exe_name} ${splitk_sources}) + +find_package(HIP REQUIRED) +find_package(ROCM REQUIRED PATHS /opt/rocm) +include(ROCMInstallTargets) + +message("HIP_VERSION: ${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}.${HIP_VERSION_PATCH}") + +set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(${exe_name} ${splitk_exe_name} PROPERTIES HIP_ARCHITECTURES ${GPU_TARGETS}) + +target_compile_options(${exe_name} PUBLIC + -fno-gpu-rdc + $<$: + --save-temps + > +) + +target_compile_options(${splitk_exe_name} PUBLIC + -fno-gpu-rdc + $<$: + --save-temps + -g + -O0 + > +) + +target_include_directories(${exe_name} PUBLIC + ${ck_include} # ck includes + ${torch_include} # aten includes + ${torch_include}/torch/csrc/api/include # torch includes +) + +target_include_directories(${splitk_exe_name} PUBLIC + ${ck_include} # ck includes + ${torch_include} # aten includes + ${torch_include}/torch/csrc/api/include # torch includes +) + +target_link_directories(${exe_name} PUBLIC + /opt/conda/envs/py_${py_version}/lib/python${py_version}/site-packages/torch/lib # c10, torch + /opt/rocm/hip/lib +) + +target_link_directories(${splitk_exe_name} PUBLIC + /opt/conda/envs/py_${py_version}/lib/python${py_version}/site-packages/torch/lib # c10, torch + /opt/rocm/hip/lib +) + +target_link_libraries(${exe_name} PUBLIC + c10 + c10_hip + torch + torch_hip + torch_cpu + amdhip64 +) + + +target_link_libraries(${splitk_exe_name} PUBLIC + c10 + c10_hip + torch + torch_hip + torch_cpu + amdhip64 +) + +target_compile_definitions(${exe_name} PUBLIC + ATTN_FWD_DECODER_MAIN=1 + GLIBCXX_USE_CXX11_ABI=1 + __HIP_PLATFORM_HCC__=1 + USE_ROCM=1 +) + +target_compile_definitions(${splitk_exe_name} PUBLIC + ATTN_FWD_SPLITK_DECODER_MAIN=1 + GLIBCXX_USE_CXX11_ABI=1 + __HIP_PLATFORM_HCC__=1 + USE_ROCM=1 +) + +include(CMakePrintHelpers) +cmake_print_properties(TARGETS ${exe_name} ${splitk_exe_name} PROPERTIES + LINK_LIBRARIES + LINK_DIRECTORIES + INCLUDE_DIRECTORIES + 
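
For reference, a minimal sketch of calling the newly registered ROCm forward op directly through the dispatcher, following the schema declared above. This assumes a ROCm build of xformers (so the op is actually registered) and that custom_mask_type=0 means "no mask", as in the existing cutlass path; in normal use the op is reached via xformers.ops.memory_efficient_attention rather than called by hand.

import torch

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

out, lse, seed, offset = torch.ops.xformers.efficient_attention_forward_ck(
    q, k, v,
    None,        # attn_bias
    None, None,  # seqstart_q / seqstart_k (grouped, batch=1 mode only)
    None,        # max_seqlen_q
    0.0,         # dropout_p (dropout is not implemented in ck-tiled yet)
    False,       # compute_logsumexp
    0,           # custom_mask_type (assumed: 0 = no mask)
    None,        # scale (None -> 1/sqrt(K))
    None,        # seqlen_k
    None,        # window_size
)
assert out.shape == (2, 128, 8, 64)
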
COMPILE_DEFINITIONS + COMPILE_OPTIONS + SOURCES + HIP_ARCHITECTURES) + +rocm_install(TARGETS ${exe_name} ${splitk_exe_name}) \ No newline at end of file diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp new file mode 100644 index 0000000000..786dfec0b5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/attention_forward_decoder.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include "ck_attention_forward_decoder.h" + +namespace { +constexpr int32_t kThreadsPerWavefront = 64; +constexpr int32_t kWavefrontsPerBlock = 16; +constexpr int32_t K_MAX = 4 * kThreadsPerWavefront; +} // namespace + +namespace { + +template +struct c10_to_data_t; +template <> +struct c10_to_data_t { + using type = float; +}; + +template <> +struct c10_to_data_t { + using type = ck::half_t; +}; + +template <> +struct c10_to_data_t { + using type = ck::bhalf_t; +}; +} // namespace + +namespace { + +#define AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_SWITCH_3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +template < + int32_t ThreadsPerWavefront, + int32_t WavefrontsPerBlock, + int32_t KV_M_MAX = 8192, + int32_t K_MAX = 256> +at::Tensor& efficient_attention_forward_decoder_ck_out_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + at::Tensor& O) { + static_assert(4 * ThreadsPerWavefront == K_MAX, ""); + static_assert(WavefrontsPerBlock <= ThreadsPerWavefront, ""); + + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + + TORCH_CHECK(!seq_kv_lens || seq_kv_lens->is_cuda()); + + TORCH_CHECK(cache_K.size(1) <= KV_M_MAX); + TORCH_CHECK(cache_K.size(4) <= K_MAX); + + constexpr auto rank = 5; + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + + TORCH_CHECK(B <= 1024); + TORCH_CHECK(M <= 1024); + TORCH_CHECK(H <= 1024); + + dim3 blocks(B * H * M * G); + dim3 threads(ThreadsPerWavefront, WavefrontsPerBlock); + + int32_t smem_softmax = KV_M_MAX * sizeof(float) + threads.y * sizeof(float); + int32_t smem_output = K_MAX * sizeof(float) * + threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) + const size_t lds_bytes = max(smem_softmax, smem_output); + auto stream = at::cuda::getCurrentHIPStream().stream(); + + AT_DISPATCH_SWITCH_3( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float, + XQ.scalar_type(), + "efficient_attention_forward_decoder_ck", + [&] { + using ck_data_t = c10_to_data_t::type; + using device_op_t = + ck::tensor_operation::device::FMHADecoderSeqlen1DeviceOp; + auto op = device_op_t{}; + + auto XQ_acc = + XQ.packed_accessor32(); + auto K_acc = + cache_K.packed_accessor64(); + auto V_acc = + 
cache_V.packed_accessor64(); + auto O_acc = + O.packed_accessor32(); + auto seq_acc = seq_kv_lens + ? seq_kv_lens + ->packed_accessor32() + .data() + : nullptr; + auto arg = device_op_t::Argument( + reinterpret_cast(XQ_acc.data()), + reinterpret_cast(K_acc.data()), + reinterpret_cast(V_acc.data()), + reinterpret_cast(O_acc.data()), + seq_acc, + XQ_acc.stride(0), + XQ_acc.stride(1), + XQ_acc.stride(2), + XQ_acc.stride(3), + K_acc.stride(0), + K_acc.stride(1), + K_acc.stride(2), + K_acc.stride(3), + XQ_acc.size(1), + XQ_acc.size(2), + XQ_acc.size(3), + XQ_acc.size(4), + K_acc.size(1), + K_acc.size(3) == 1, + qk_scale, + blocks, + threads, + lds_bytes); + + auto invoker = device_op_t::Invoker{}; + (void)invoker.Run(arg, {stream}); + }); + + return O; +} + +#undef AT_DISPATCH_CASE_3 +#undef AT_DISPATCH_SWITCH_3 + +template +at::Tensor efficient_attention_forward_decoder_ck_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale) { + auto O = at::empty_like(XQ); + efficient_attention_forward_decoder_ck_out_impl< + ThreadsPerWavefront, + WavefrontsPerBlock>(XQ, cache_K, cache_V, seq_kv_lens, qk_scale, O); + return O; +} + +at::Tensor efficient_attention_forward_decoder_ck( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale) { + return efficient_attention_forward_decoder_ck_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>(XQ, cache_K, cache_V, seq_kv_lens, qk_scale); +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_decoder_ck"), + TORCH_FN(efficient_attention_forward_decoder_ck)); +} + +#ifdef ATTN_FWD_DECODER_MAIN + +#include + +// clang-format off + +/* + +(1) hipify + > pip install -e /xformers + + For obtaining all the library paths needed for compilation below, add `--verbose`. + For efficient utilization of CPU cores for compilation use MAX_JOBS env variable. + +(2) compile + > mkdir build + > cd build + > cmake /xformers/xformers/csrc/attention/hip_fmha/ \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_BUILD_TYPE=Debug \ + -D GPU_TARGETS="native" + > make + +(3a) run correctness check + > ./attention_forward_decoder_main + +(3b) run specific input shape + > ./attention_forward_decoder_main n_keys padding batch_size n_heads is_multiquery dtype n_wavefronts_per_block +*/ + +// clang-format on + +static void do_correctness_check() { + const int32_t D = 4 * kThreadsPerWavefront; + const int32_t B = 1; + const int32_t H = 4; + const int32_t G = 1; + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .layout(torch::kStrided) + .device(torch::kCUDA, 1) + .requires_grad(false); + auto int_options = options.dtype(torch::kInt); + auto XQ = at::randn({B, 1, G, H, D}, options); + auto K = at::randn({B, 4096, G, H, D}, options); + auto V = at::randn({B, 4096, G, H, D}, options); + auto seq = at::randint(63, 128, {B}, int_options); + double qk_scale = 1. 
/ sqrt(D); + + auto result = efficient_attention_forward_decoder_ck_impl<64, 1>( + XQ, K, V, seq, qk_scale); + auto gold_result = efficient_attention_forward_decoder_ck_impl<64, 2>( + XQ, K, V, seq, qk_scale); + auto mask = at::isclose( + result, gold_result, /*atol*/ 1e-3, /*rtol*/ 1e-5, /*equal_nan*/ false); + auto percent_match = at::sum(mask.to(torch::kFloat32)) / mask.numel(); + printf( + "Mismatched elements percentage: %.2f\n", + 1. - percent_match.item()); +} + +int main(int argc, char** argv) { + if (argc == 1) { + do_correctness_check(); + } else { + const auto args = std::vector(argv + 1, argv + argc); + if (args.size() != 7) { + std::cout + << "Usage: ./a.out n_keys padding batch_size n_heads is_multiquery dtype " + "n_wavefronts_per_block" + << std::endl; + return 0; + } + const int32_t n_keys = std::stoi(args[0]); + const int32_t padding = std::stoi(args[1]); + const int32_t batch_size = std::stoi(args[2]); + const int32_t n_heads = std::stoi(args[3]); + const int32_t n_groups = 1; + const int32_t multiquery = (args[4] == "mq"); + const auto dtype = (args[5] == "f32") ? torch::kFloat32 + : (args[5] == "f16") ? torch::kFloat16 + : torch::kBFloat16; + const int32_t n_wavefronts_per_block = std::stoi(args[6]); + + const int32_t dim_per_head = 4 * kThreadsPerWavefront; + + const auto options = torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(torch::kCUDA, 1) + .requires_grad(false); + + const auto int_options = options.dtype(torch::kInt); + const auto Q = + at::rand({batch_size, 1, n_groups, n_heads, dim_per_head}, options); + const auto K = multiquery + ? at::rand({batch_size, padding, n_groups, 1, dim_per_head}, options) + .expand({batch_size, padding, n_groups, n_heads, dim_per_head}) + : at::rand( + {batch_size, padding, n_groups, n_heads, dim_per_head}, options); + const auto V = at::rand_like(K); + auto O = at::empty_like(Q); + + const auto seq = at::randint(1, n_keys, {batch_size}, int_options); + const double qk_scale = 1. / sqrt(dim_per_head); + auto call_ptr = decltype(&efficient_attention_forward_decoder_ck_out_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>){}; + +#define SWITCH_CASE_SET_CALLPTR(n) \ + case (n): \ + call_ptr = &efficient_attention_forward_decoder_ck_out_impl< \ + kThreadsPerWavefront, \ + (n)>; \ + break; + + switch (n_wavefronts_per_block) { + SWITCH_CASE_SET_CALLPTR(1); + SWITCH_CASE_SET_CALLPTR(2); + SWITCH_CASE_SET_CALLPTR(4); + SWITCH_CASE_SET_CALLPTR(8); + SWITCH_CASE_SET_CALLPTR(16); + + default: + call_ptr = nullptr; + break; + } +#undef SWITCH_CASE_SET_CALLPTR + + if (call_ptr) { + call_ptr(Q, K, V, seq, qk_scale, O); + } else { + std::cout << "Warning: no kernel was found for wavefronts_per_block=" + << n_wavefronts_per_block << std::endl; + } + } + return 0; +} + +#endif // MAIN \ No newline at end of file diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp new file mode 100644 index 0000000000..a56b87f737 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_fmha_util.h" +#include "ck_tiled_fmha_params.h" + +extern void batched_forward_fp16( + BatchedForwardParams& param, + hipStream_t stream); +extern void batched_forward_bp16( + BatchedForwardParams& param, + hipStream_t stream); +extern void grouped_forward_fp16( + GroupedForwardParams& param, + hipStream_t stream); +extern void grouped_forward_bp16( + GroupedForwardParams& param, + hipStream_t stream); + +extern void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream); +extern void batched_infer_bp16(BatchedForwardParams& param, hipStream_t stream); +extern void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream); +extern void grouped_infer_bp16(GroupedForwardParams& param, hipStream_t stream); + +namespace { + +/* + There are 2 modes for using this function. + (Mode BMHK) With all the heads having the same seqlen + (Mode 1MHK) `batch=1` with all tokens across batches concatenated +*/ +std::tuple +efficient_attention_forward_ck( + const at::Tensor& query, // [b, seqlen, num_heads_q, K] + const at::Tensor& key, // [b, seqlen, num_heads_kv, K] + const at::Tensor& value, // [b, seqlen, num_heads_kv, Kv] + const c10::optional& bias, // [b, num_heads_q, seqlen, seqlen] + // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the + // position of the first query token for batch $b + const c10::optional& seqstart_q, + // (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the + // position of the first key token for batch $b + const c10::optional& seqstart_k, + // (Mode 1MHK only) Maximum sequence length across batches + const c10::optional max_seqlen_q_, + double dropout_p, // attention matrix dropout probability + bool compute_logsumexp, + int64_t custom_mask_type, + c10::optional scale, + const c10::optional& seqlen_k, + const c10::optional window_size) { + TORCH_CHECK(query.dim() == 4); + TORCH_CHECK(key.dim() == 4); + TORCH_CHECK(value.dim() == 4); + + // Batch sizes + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // Sequence length + TORCH_CHECK(key.size(1) == value.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) % key.size(2) == 0); + TORCH_CHECK(key.size(2) == value.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + + TORCH_CHECK(query.scalar_type() == key.scalar_type()); + TORCH_CHECK(query.scalar_type() == value.scalar_type()); + + TORCH_CHECK(seqstart_q.has_value() == seqstart_k.has_value()); + if (seqstart_q.has_value()) { + TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1); + TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0)); + TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); + TORCH_CHECK(max_seqlen_q_.has_value()); + }; + + // last dim is contiguous, device is kCUDA + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + // at::cuda::CUDAGuard device_guard(query.device()); + hipStream_t stream = at::cuda::getCurrentHIPStream().stream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t Hq = query.size(-2); + int64_t Hkv = key.size(-2); + int64_t K = query.size(-1); + int64_t Kv = value.size(-1); + + auto opts = query.options(); + + 
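
For reference, a minimal sketch of the two input modes described in the comment above, exercised through the public Python API; this assumes a working GPU build of xformers, and all tensors are in BMHK layout.

import torch
import xformers.ops as xops
from xformers.ops import fmha

device, dtype = "cuda", torch.float16

# Mode BMHK: every batch entry has the same sequence length.
q = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
k = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
v = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
out = xops.memory_efficient_attention(q, k, v)

# Mode 1MHK: batch=1, variable-length sequences concatenated along the token
# dimension and described by a block-diagonal attention bias.
seqlens = [37, 91]
qc = torch.randn(1, sum(seqlens), 8, 64, device=device, dtype=dtype)
bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
out_var = xops.memory_efficient_attention(qc, qc, qc, attn_bias=bias)
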
at::Tensor logsumexp; + + at::Tensor out = at::empty({B, M, Hq, Kv}, opts); + + const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; + int64_t philox_seed; + int64_t philox_offset; + + if (use_dropout) { + /* + at::PhiloxCudaState rng_engine_inputs; + at::CUDAGeneratorImpl* gen = + at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + rng_engine_inputs = gen->philox_cuda_state(B * Hq * M * N); + + const auto seeds = at::cuda::philox::unpack(rng_engine_inputs); + + philox_seed = std::get<0>(seeds); + philox_offset = std::get<1>(seeds); + */ + throw std::runtime_error( + "drop-out is currently not implemented by ck-tiled!"); + } + + auto set_batched_forward_params = [&](BatchedForwardParams& p) { + p.B = B; + p.M = M; + p.N = N; + p.Hq = Hq; + p.Hkv = Hkv; + p.K = K; + p.Kv = Kv; + + if (scale.has_value()) { + p.scale = float(*scale); + } else { + p.scale = float(1.0 / std::sqrt(float(K))); + } + + p.q_ptr = query.data_ptr(); + p.k_ptr = key.data_ptr(); + p.v_ptr = value.data_ptr(); + p.out_ptr = out.data_ptr(); + + p.q_strides = { + static_cast(query.stride(0)), + static_cast(query.stride(1)), + static_cast(query.stride(2)), + static_cast(query.stride(3))}; + p.k_strides = { + static_cast(key.stride(0)), + static_cast(key.stride(1)), + static_cast(key.stride(2)), + static_cast(key.stride(3))}; + p.v_strides = { + static_cast(value.stride(0)), + static_cast(value.stride(1)), + static_cast(value.stride(2)), + static_cast(value.stride(3))}; + p.out_strides = { + static_cast(out.stride(0)), + static_cast(out.stride(1)), + static_cast(out.stride(2)), + static_cast(out.stride(3))}; + + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + TORCH_CHECK(bias->scalar_type() == query.scalar_type()); + + p.has_attn_bias = true; + p.attn_bias_ptr = bias->data_ptr(); + + const at::Tensor bias_4d_view = get_bias_4d_view(*bias, B, Hq, M, N); + p.attn_bias_strides = { + static_cast(bias_4d_view.stride(0)), + static_cast(bias_4d_view.stride(1)), + static_cast(bias_4d_view.stride(2)), + static_cast(bias_4d_view.stride(3))}; + } else + p.has_attn_bias = false; + + p.custom_mask_type = custom_mask_type; + p.window_size = + window_size.has_value() ? (*window_size > 0 ? 
*window_size : 0) : 0; + + p.use_dropout = use_dropout; + p.philox_seed = philox_seed; + p.philox_offset = philox_offset; + p.compute_logsumexp = compute_logsumexp; + + // the following parameters are only used by training forward + if (p.use_dropout) { + // p.dropout_prob = static_cast(dropout_p); + throw std::runtime_error( + "drop-out is currently not implemented by ck-tiled!"); + } else + p.dropout_prob = 0.0f; + + if (p.compute_logsumexp) { + logsumexp = at::empty({B, Hq, M}, opts.dtype(at::kFloat)); + p.logsumexp_ptr = logsumexp.data_ptr(); + } else + p.logsumexp_ptr = nullptr; + }; + + auto set_grouped_forward_params = [&](GroupedForwardParams& p) { + p.num_batches = seqstart_q->size(0) - 1; + p.M = M; + p.N = N; + p.Hq = Hq; + p.Hkv = Hkv; + p.K = K; + p.Kv = Kv; + + if (scale.has_value()) { + p.scale = float(*scale); + } else { + p.scale = float(1.0 / std::sqrt(float(K))); + } + + p.q_ptr = query.data_ptr(); + p.k_ptr = key.data_ptr(); + p.v_ptr = value.data_ptr(); + p.out_ptr = out.data_ptr(); + + p.q_strides = { + static_cast(query.stride(1)), + static_cast(query.stride(2)), + static_cast(query.stride(3))}; + p.k_strides = { + static_cast(key.stride(1)), + static_cast(key.stride(2)), + static_cast(key.stride(3))}; + p.v_strides = { + static_cast(value.stride(1)), + static_cast(value.stride(2)), + static_cast(value.stride(3))}; + p.out_strides = { + static_cast(out.stride(1)), + static_cast(out.stride(2)), + static_cast(out.stride(3))}; + + if (bias.has_value()) { + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias)); + TORCH_CHECK(bias->scalar_type() == query.scalar_type()); + + p.has_attn_bias = true; + p.attn_bias_ptr = bias->data_ptr(); + + const at::Tensor bias_4d_view = get_bias_4d_view(*bias, B, Hq, M, N); + p.attn_bias_strides = { + static_cast(bias_4d_view.stride(0)), + static_cast(bias_4d_view.stride(1)), + static_cast(bias_4d_view.stride(2)), + static_cast(bias_4d_view.stride(3))}; + } else + p.has_attn_bias = false; + + p.custom_mask_type = custom_mask_type; + p.window_size = + window_size.has_value() ? (*window_size > 0 ? 
*window_size : 0) : 0; + + // max_seqlen_q is used to create logsumexp tensor + p.max_seqlen_q = *max_seqlen_q_; + + // interesting: the tensors have to be defined here, moving to more local + // scope will cause issue + at::Tensor dev_seqstart_q; + at::Tensor dev_seqstart_k; + at::Tensor dev_seqlen_k; + + if (seqstart_q->is_cpu()) { + dev_seqstart_q = at::empty({p.num_batches + 1}, opts.dtype(at::kInt)); + p.seqstart_q_dev_ptr = dev_seqstart_q.data_ptr(); + HIP_CALL_CHECK(hipMemcpyAsync( + p.seqstart_q_dev_ptr, + seqstart_q->data_ptr(), + (p.num_batches + 1) * sizeof(int), + hipMemcpyHostToDevice, + stream)); + } else + p.seqstart_q_dev_ptr = seqstart_q->data_ptr(); + + if (seqstart_k->is_cpu()) { + dev_seqstart_k = at::empty({p.num_batches + 1}, opts.dtype(at::kInt)); + + p.seqstart_k_dev_ptr = dev_seqstart_k.data_ptr(); + HIP_CALL_CHECK(hipMemcpyAsync( + p.seqstart_k_dev_ptr, + seqstart_k->data_ptr(), + (p.num_batches + 1) * sizeof(int), + hipMemcpyHostToDevice, + stream)); + } else + p.seqstart_k_dev_ptr = seqstart_k->data_ptr(); + + if (seqlen_k.has_value()) { + TORCH_CHECK(seqlen_k->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(seqlen_k->dim() == 1); + TORCH_CHECK(seqlen_k->size(0) == p.num_batches) + + if (seqlen_k->is_cpu()) { + dev_seqlen_k = at::empty({p.num_batches}, opts.dtype(at::kInt)); + + p.seqlen_k_dev_ptr = dev_seqlen_k.data_ptr(); + HIP_CALL_CHECK(hipMemcpyAsync( + p.seqlen_k_dev_ptr, + seqlen_k->data_ptr(), + p.num_batches * sizeof(int), + hipMemcpyHostToDevice, + stream)); + } else + p.seqlen_k_dev_ptr = seqlen_k->data_ptr(); + } else + p.seqlen_k_dev_ptr = nullptr; + + p.use_dropout = use_dropout; + p.philox_seed = philox_seed; + p.philox_offset = philox_offset; + p.compute_logsumexp = compute_logsumexp; + + // the following parameters are only used by training forward + if (p.use_dropout) { + // p.dropout_prob = static_cast(dropout_p); + throw std::runtime_error( + "drop-out is currently not implemented by ck-tiled!"); + } else + p.dropout_prob = 0.0f; + + if (p.compute_logsumexp) { + logsumexp = at::empty( + {p.num_batches, Hq, p.max_seqlen_q}, opts.dtype(at::kFloat)); + p.logsumexp_ptr = logsumexp.data_ptr(); + } else + p.logsumexp_ptr = nullptr; + }; + + auto inDataType = query.scalar_type(); + + if (!seqstart_q.has_value()) { // input is batched + BatchedForwardParams batched_forward_params; + + set_batched_forward_params(batched_forward_params); + + if (!batched_forward_params.use_dropout && + !batched_forward_params.compute_logsumexp) { + if (inDataType == at::ScalarType::Half) { + batched_infer_fp16(batched_forward_params, stream); + } else if (inDataType == at::ScalarType::BFloat16) { + batched_infer_bp16(batched_forward_params, stream); + } else + throw std::runtime_error("input data-type is not supported!"); + } else { + if (inDataType == at::ScalarType::Half) { + batched_forward_fp16(batched_forward_params, stream); + } else if (inDataType == at::ScalarType::BFloat16) { + batched_forward_bp16(batched_forward_params, stream); + } else + throw std::runtime_error("input data-type is not supported!"); + + throw std::runtime_error( + "drop-out and compuate logsumexp currently not implemented by ck-tiled!"); + }; + } else { // input is grouped + GroupedForwardParams grouped_forward_params; + + set_grouped_forward_params(grouped_forward_params); + + if (!grouped_forward_params.use_dropout && + !grouped_forward_params.compute_logsumexp) { + if (inDataType == at::ScalarType::Half) { + grouped_infer_fp16(grouped_forward_params, stream); + } else if 
(inDataType == at::ScalarType::BFloat16) { + grouped_infer_bp16(grouped_forward_params, stream); + } else + throw std::runtime_error("input data-type is not supported!"); + } else { + if (inDataType == at::ScalarType::Half) { + grouped_forward_fp16(grouped_forward_params, stream); + } else if (inDataType == at::ScalarType::BFloat16) { + grouped_forward_bp16(grouped_forward_params, stream); + } else + throw std::runtime_error("input data-type is not supported!"); + + throw std::runtime_error( + "drop-out and compuate logsumexp currently not implemented by ck-tiled!"); + }; + }; + + return std::make_tuple(out, logsumexp, philox_seed, philox_offset); +} + +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"), + TORCH_FN(efficient_attention_forward_ck)); +} diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp new file mode 100644 index 0000000000..06fbbe0f69 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/attention_forward_splitk.cpp @@ -0,0 +1,1165 @@ +#include +#include +#include +#include +#include + +#include "ck_attention_forward_decoder_splitk.h" + +namespace { +constexpr int32_t kThreadsPerWavefront = 64; +constexpr int32_t kWavefrontsPerBlock = 16; +constexpr int32_t K_MAX = 4 * kThreadsPerWavefront; +} // namespace + +namespace { + +template +struct c10_to_data_t; +template <> +struct c10_to_data_t { + using type = float; +}; + +template <> +struct c10_to_data_t { + using type = ck::half_t; +}; + +template <> +struct c10_to_data_t { + using type = ck::bhalf_t; +}; +} // namespace + +#define AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_SWITCH_3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +namespace { + +template < + int32_t ThreadsPerWavefront, + int32_t WavefrontsPerBlock, + int32_t KV_M_MAX = 8192, + int32_t K_MAX = 256> +at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + int64_t split_k, + at::Tensor& split_max, + at::Tensor& split_sumexp, + at::Tensor& split_O, + at::Tensor& O) { + static_assert(4 * ThreadsPerWavefront == K_MAX, ""); + static_assert(WavefrontsPerBlock <= ThreadsPerWavefront, ""); + + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + + TORCH_CHECK(!seq_kv_lens || seq_kv_lens->is_cuda()); + + TORCH_CHECK(cache_K.size(1) / split_k <= KV_M_MAX); + TORCH_CHECK(cache_K.size(4) <= K_MAX); + + constexpr auto rank = 5; + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + + TORCH_CHECK(B <= 1024); + TORCH_CHECK(M <= 1024); + TORCH_CHECK(H <= 1024); + + dim3 blocks(B * H * M * G, split_k); + dim3 threads(ThreadsPerWavefront, WavefrontsPerBlock); + + int32_t smem_softmax = KV_M_MAX * sizeof(float) + threads.y * sizeof(float); + int32_t smem_output = K_MAX * sizeof(float) * + threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) + const size_t lds_bytes = max(smem_softmax, smem_output); + auto stream = at::cuda::getCurrentHIPStream().stream(); + + AT_DISPATCH_SWITCH_3( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float, + XQ.scalar_type(), + "efficient_attention_forward_decoder_splitk_ck", + [&] { + using ck_data_t = c10_to_data_t::type; + using device_op_t = + ck::tensor_operation::device::FMHADecoderSplitKDeviceOp; + auto op = device_op_t{}; + + auto XQ_acc = + XQ.packed_accessor32(); + auto K_acc = + cache_K.packed_accessor64(); + auto V_acc = + cache_V.packed_accessor64(); + auto split_O_acc = + split_O + .packed_accessor32(); + auto O_acc = + O.packed_accessor32(); + auto seq_acc_ptr = seq_kv_lens + ? 
seq_kv_lens + ->packed_accessor32() + .data() + : nullptr; + auto split_max_acc = + split_max.packed_accessor32(); + auto split_sumexp_acc = + split_sumexp + .packed_accessor32(); + auto arg = device_op_t::Argument( + reinterpret_cast(XQ_acc.data()), + reinterpret_cast(K_acc.data()), + reinterpret_cast(V_acc.data()), + reinterpret_cast(O_acc.data()), + reinterpret_cast(split_O_acc.data()), + split_max_acc.data(), + split_sumexp_acc.data(), + seq_acc_ptr, + XQ_acc.stride(0), + XQ_acc.stride(1), + XQ_acc.stride(2), + XQ_acc.stride(3), + K_acc.stride(0), + K_acc.stride(1), + K_acc.stride(2), + K_acc.stride(3), + split_O_acc.stride(0), + XQ_acc.size(1), + XQ_acc.size(2), + XQ_acc.size(3), + XQ_acc.size(4), + K_acc.size(1), + K_acc.size(3) == 1, + qk_scale, + split_k, + blocks, + threads, + lds_bytes); + + auto invoker = device_op_t::Invoker{}; + (void)invoker.Run(arg, {stream}); + }); + + return O; +} + +template +at::Tensor efficient_attention_forward_decoder_splitk_ck_impl( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + int64_t split_k) { + auto O = at::empty_like(XQ); + constexpr auto rank = 5; + + TORCH_CHECK(XQ.dim() == rank); + TORCH_CHECK(cache_K.dim() == rank); + TORCH_CHECK(cache_V.dim() == rank); + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + auto K = XQ.size(4); + + auto O_splits = at::empty({split_k, B, M, G, H, K}, XQ.options()); + auto split_max = + at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)); + auto split_sumexp = at::empty_like(split_max); + + efficient_attention_forward_decoder_splitk_ck_out_impl< + ThreadsPerWavefront, + WavefrontsPerBlock>( + XQ, + cache_K, + cache_V, + seq_kv_lens, + qk_scale, + split_k, + split_max, + split_sumexp, + O_splits, + O); + + return O; +} + +at::Tensor efficient_attention_forward_decoder_splitk_ck( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + int64_t split_k) { + return efficient_attention_forward_decoder_splitk_ck_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>( + XQ, cache_K, cache_V, seq_kv_lens, qk_scale, split_k); +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME( + "xformers::efficient_attention_forward_decoder_splitk_ck"), + TORCH_FN(efficient_attention_forward_decoder_splitk_ck)); +} + +#ifdef ATTN_FWD_SPLITK_DECODER_MAIN + +#include + +// clang-format off + +/* + +(1) hipify + > pip install -e /xformers + + For obtaining the executed build commands, add `--verbose`. + For efficient utilization of CPU cores for compilation use MAX_JOBS env variable. 
+ +(2) compile + > mkdir build + > cd build + > cmake /xformers/xformers/csrc/attention/hip_fmha/ \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_BUILD_TYPE=Debug \ + -D GPU_TARGETS="native" + > make + +(3a) run correctness check + > ./attention_forward_splitk_decoder_main + +(3b) run specific input shape + > ./attention_forward_splitk_decoder_main n_keys padding batch_size n_heads is_multiquery dtype n_wavefronts_per_block +*/ + +// clang-format on + +static std::tuple split_attention_torch( + const at::Tensor& Q, + const at::Tensor& K, + const at::Tensor& V, + const at::Tensor& k_seqlens, + const int32_t split_k, + const int32_t block_size) { + auto Q_scaled = at::div(Q, sqrt(Q.size(-1))); + + std::vector O_splits; + std::vector m_splits; + std::vector l_splits; + + for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { + std::vector O_batch; + std::vector m_batch; + std::vector l_batch; + + for (size_t b = 0; b < k_seqlens.numel(); ++b) { + auto seqlen = k_seqlens[b].item(); + const int64_t t_low = + split_idx * (seqlen / split_k / block_size) * block_size; + const int64_t t_high = (split_idx + 1 < split_k) + ? (1 + split_idx) * (seqlen / split_k / block_size) * block_size + : seqlen; + + const bool empty = t_low == t_high; + + auto S = at::einsum( + "mghk, nghk -> mghn", + {Q_scaled[b], + at::slice(K[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, + /* einsum eval path */ at::nullopt); + auto m = empty + ? at::empty_like(S) + : std::get<0>(at::max(S, /* dim */ -1, /* keepdim */ true)); + auto s = at::exp(at::sub(S, m)); + auto l = at::sum(s, /* dim */ -1, /* keepdim */ true); + auto O = at::einsum( + "mghn, nghk -> mghk", + {s, at::slice(V[b], /*dim*/ 0, /*start*/ t_low, /*end*/ t_high)}, + /* einsum eval path */ at::nullopt); + if (empty) { + m = at::empty_like(at::slice(O, -1, 0, 1)); + l = at::zeros_like(m); + m.fill_(ck::NumericLimits::Lowest()); + } + O_batch.push_back(O); + m_batch.push_back(m); + l_batch.push_back(l); + } + + auto O_cat = at::stack(O_batch); + auto m_cat = at::stack(m_batch); + auto l_cat = at::stack(l_batch); + + O_splits.push_back(O_cat); + m_splits.push_back(m_cat); + l_splits.push_back(l_cat); + } + + auto O_cat = at::stack(O_splits); + auto m_cat = at::transpose(at::stack(m_splits), 0, -1); + auto l_cat = at::transpose(at::stack(l_splits), 0, -1); + + return std::make_tuple(O_cat, m_cat, l_cat); +} + +static at::Tensor split_reduce_torch( + const at::Tensor& O_splits, + const at::Tensor& m_splits, + const at::Tensor& l_splits, + int32_t split_k) { + auto O = at::zeros_like(at::slice(O_splits, 0, 0, 1)); + auto global_max = + at::empty_like(at::slice(m_splits, -1, 0, 1)).fill_(-65535.); + auto global_sumexp = at::zeros_like(global_max); + + for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { + auto local_O = at::slice(O_splits, 0, split_idx, split_idx + 1); + auto local_max = at::slice(m_splits, -1, split_idx, split_idx + 1); + auto local_sumexp = at::slice(l_splits, -1, split_idx, split_idx + 1); + + auto log_alpha = at::neg(at::abs(at::sub(local_max, global_max))); + auto alpha = at::exp(log_alpha); + alpha.nan_to_num_(1.); + + auto pick_new = at::less(local_max, global_max); + auto pick_current_coef = at::where(pick_new, 1., alpha); + auto pick_new_coef = at::where(pick_new, alpha, 1.); + + O = at::add(at::mul(pick_current_coef, O), at::mul(pick_new_coef, local_O)); + global_sumexp = at::add( + at::mul(pick_current_coef, global_sumexp), + at::mul(pick_new_coef, local_sumexp)); + 
global_max = at::max(local_max, global_max); + } + + return at::div(O, global_sumexp); +} + +static at::Tensor efficient_attention_forward_decoder_splitk_torch( + const at::Tensor& XQ, // [B, 1, G, H, D] + const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D] + const at::Tensor& cache_V, // [B, KV_M_MAX, G, H or 1, D] + at::optional seq_kv_lens, // [B] + double qk_scale, + int32_t split_k, + int32_t block_size) { + auto [O_split, m, l] = split_attention_torch( + XQ, cache_K, cache_V, *seq_kv_lens, split_k, block_size); + auto O = split_reduce_torch(O_split, m, l, split_k); + return O.reshape_as(XQ); +} + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct FMHADecoderSplitAttentionDeviceOp : public BaseOperator { + using DeviceOp = FMHADecoderSplitAttentionDeviceOp; + struct Argument : public BaseArgument { + const scalar_t* __restrict__ XQ; + const scalar_t* __restrict__ cache_K; + const scalar_t* __restrict__ cache_V; + scalar_t* __restrict__ O; + scalar_t* __restrict__ split_O; + compute_t* __restrict__ split_max; + compute_t* __restrict__ split_sumexp; + const int32_t* __restrict__ seq_kv_lens; + const ptrdiff_t XQ_stride_b; + const ptrdiff_t XQ_stride_m; + const ptrdiff_t XQ_stride_g; + const ptrdiff_t XQ_stride_h; + const ptrdiff_t K_stride_b; + const ptrdiff_t K_stride_m; + const ptrdiff_t K_stride_g; + const ptrdiff_t K_stride_h; + const ptrdiff_t O_stride_split; + const int32_t Q_size_m; + const int32_t Q_size_g; + const int32_t Q_size_h; + const int32_t Q_size_k; + const int32_t K_size_m; + const bool multiquery; + const float qk_scale; + const int32_t split_k; + + const dim3 grid_dim; + const dim3 block_dim; + const size_t lds_bytes; + + Argument( + const scalar_t* __restrict__ XQ, + const scalar_t* __restrict__ cache_K, + const scalar_t* __restrict__ cache_V, + scalar_t* __restrict__ O, + scalar_t* __restrict__ split_O, + compute_t* __restrict__ split_max, + compute_t* __restrict__ split_sumexp, + const int32_t* __restrict__ seq_kv_lens, + const ptrdiff_t XQ_stride_b, + const ptrdiff_t XQ_stride_m, + const ptrdiff_t XQ_stride_g, + const ptrdiff_t XQ_stride_h, + const ptrdiff_t K_stride_b, + const ptrdiff_t K_stride_m, + const ptrdiff_t K_stride_g, + const ptrdiff_t K_stride_h, + const ptrdiff_t O_stride_split, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const int32_t K_size_m, + const bool multiquery, + const float qk_scale, + const int32_t split_k, + // launch params + const dim3 grid_dim, + const dim3 block_dim, + const size_t lds_bytes) + : XQ(XQ), + cache_K(cache_K), + cache_V(cache_V), + O(O), + split_O(split_O), + split_max(split_max), + split_sumexp(split_sumexp), + seq_kv_lens(seq_kv_lens), + XQ_stride_b(XQ_stride_b), + XQ_stride_m(XQ_stride_m), + XQ_stride_g(XQ_stride_g), + XQ_stride_h(XQ_stride_h), + K_stride_b(K_stride_b), + K_stride_m(K_stride_m), + K_stride_g(K_stride_g), + K_stride_h(K_stride_h), + O_stride_split(O_stride_split), + Q_size_m(Q_size_m), + Q_size_g(Q_size_g), + Q_size_h(Q_size_h), + Q_size_k(Q_size_k), + K_size_m(K_size_m), + multiquery(multiquery), + qk_scale(qk_scale), + split_k(split_k), + // launch params + grid_dim(grid_dim), + block_dim(block_dim), + lds_bytes(lds_bytes) {} + + std::string str() const { + std::ostringstream oss; + oss << "Argument { " << std::endl + << " XQ: " << XQ << std::endl + << " cache_K: " << cache_K << std::endl + << " cache_V: " << cache_V << std::endl + << " O: " << O << std::endl + << " split_O: " << split_O << 
std::endl + << " split_max: " << split_max << std::endl + << " split_sumexp: " << split_sumexp << std::endl + << " seq_kv_lens: " << seq_kv_lens << std::endl + << " XQ_stride_b: " << XQ_stride_b << std::endl + << " XQ_stride_m: " << XQ_stride_m << std::endl + << " XQ_stride_g: " << XQ_stride_g << std::endl + << " XQ_stride_h: " << XQ_stride_h << std::endl + << " K_stride_b: " << K_stride_b << std::endl + << " K_stride_m: " << K_stride_m << std::endl + << " K_stride_g: " << K_stride_g << std::endl + << " K_stride_h: " << K_stride_h << std::endl + << " O_stride_split: " << O_stride_split << std::endl + << " Q_size_m: " << Q_size_m << std::endl + << " Q_size_g: " << Q_size_g << std::endl + << " Q_size_h: " << Q_size_h << std::endl + << " Q_size_k: " << Q_size_k << std::endl + << " K_size_m: " << K_size_m << std::endl + << " multiquery: " << multiquery << std::endl + << " qk_scale: " << qk_scale << std::endl + << " split_k: " << split_k << std::endl + << std::endl + << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." + << grid_dim.z << std::endl + << " block_dim: " << block_dim.x << "." << block_dim.y << "." + << block_dim.z << std::endl + << " lds_bytes: " << lds_bytes << std::endl + << "}"; + return oss.str(); + } + }; + + struct Invoker : public BaseInvoker { + using Argument = DeviceOp::Argument; + float Run( + const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { + auto threads_per_wavefront = arg.block_dim.x; + auto Q_size_k_alignment_necessary = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.Q_size_k <= vec_size * threads_per_wavefront) { + Q_size_k_alignment_necessary = vec_size; + } + } + + if (!Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported Q_size_k"); + } + + if (arg.Q_size_k % Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported alignment for Q_size_k"); + } + + float split_attention_result = launch_and_time_kernel( + stream_config, + Q_size_k_alignment_necessary == 4 + ? efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 4> + : Q_size_k_alignment_necessary == 2 + ? efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 2> + : Q_size_k_alignment_necessary == 1 + ? 
efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 1> + : nullptr, + arg.grid_dim, + arg.block_dim, + arg.lds_bytes, + arg.XQ, + arg.cache_K, + arg.cache_V, + arg.split_O, + arg.split_max, + arg.split_sumexp, + arg.seq_kv_lens, + arg.XQ_stride_b, + arg.XQ_stride_m, + arg.XQ_stride_g, + arg.XQ_stride_h, + arg.K_stride_b, + arg.K_stride_m, + arg.K_stride_g, + arg.K_stride_h, + arg.O_stride_split, + arg.Q_size_m, + arg.Q_size_g, + arg.Q_size_h, + arg.Q_size_k, + arg.K_size_m, + arg.multiquery, + arg.qk_scale, + arg.split_k); + + return split_attention_result; + } + }; +}; + +template +struct FMHADecoderSplitReduceDeviceOp : public BaseOperator { + using DeviceOp = FMHADecoderSplitReduceDeviceOp; + struct Argument : public BaseArgument { + const scalar_t* __restrict__ split_O; + const compute_t* __restrict__ split_max; + const compute_t* __restrict__ split_sumexp; + scalar_t* __restrict__ O; + + const int32_t O_size_m; + const int32_t O_size_g; + const int32_t O_size_h; + const int32_t O_size_k; + + const ptrdiff_t O_stride_split; + const ptrdiff_t O_stride_b; + const ptrdiff_t O_stride_m; + const ptrdiff_t O_stride_g; + const ptrdiff_t O_stride_h; + + const int32_t split_k; + + const dim3 grid_dim; + const dim3 block_dim; + const size_t lds_bytes; + + Argument( + const scalar_t* __restrict__ split_O, + const compute_t* __restrict__ split_max, + const compute_t* __restrict__ split_sumexp, + scalar_t* __restrict__ O, + const int32_t O_size_m, + const int32_t O_size_g, + const int32_t O_size_h, + const int32_t O_size_k, + const ptrdiff_t O_stride_split, + const ptrdiff_t O_stride_b, + const ptrdiff_t O_stride_m, + const ptrdiff_t O_stride_g, + const ptrdiff_t O_stride_h, + const int32_t split_k, + // launch params + const dim3 grid_dim, + const dim3 block_dim, + const size_t lds_bytes) + : split_O(split_O), + split_max(split_max), + split_sumexp(split_sumexp), + O(O), + O_size_m(O_size_m), + O_size_g(O_size_g), + O_size_h(O_size_h), + O_size_k(O_size_k), + O_stride_split(O_stride_split), + O_stride_b(O_stride_b), + O_stride_m(O_stride_m), + O_stride_g(O_stride_g), + O_stride_h(O_stride_h), + split_k(split_k), + // launch params + grid_dim(grid_dim), + block_dim(block_dim), + lds_bytes(lds_bytes) {} + + std::string str() const { + std::ostringstream oss; + oss << "Argument { " << std::endl + << " O: " << O << std::endl + << " split_O: " << split_O << std::endl + << " split_max: " << split_max << std::endl + << " split_sumexp: " << split_sumexp << std::endl + << " O_stride_b: " << O_stride_b << std::endl + << " O_stride_m: " << O_stride_m << std::endl + << " O_stride_g: " << O_stride_g << std::endl + << " O_stride_h: " << O_stride_h << std::endl + << " O_stride_split: " << O_stride_split << std::endl + << " O_size_m: " << O_size_m << std::endl + << " O_size_g: " << O_size_g << std::endl + << " O_size_h: " << O_size_h << std::endl + << " O_size_k: " << O_size_k << std::endl + << " split_k: " << split_k << std::endl + << std::endl + << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." + << grid_dim.z << std::endl + << " block_dim: " << block_dim.x << "." << block_dim.y << "." 
+ << block_dim.z << std::endl + << " lds_bytes: " << lds_bytes << std::endl + << "}"; + return oss.str(); + } + }; + + struct Invoker : public BaseInvoker { + using Argument = DeviceOp::Argument; + float Run( + const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { + auto threads_per_wavefront = arg.block_dim.x; + auto O_size_k_alignment_necessary = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.O_size_k <= vec_size * threads_per_wavefront) { + O_size_k_alignment_necessary = vec_size; + } + } + + if (!O_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported O_size_k"); + } + + if (arg.O_size_k % O_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported alignment for O_size_k"); + } + + const dim3 reduce_gridsize = {arg.grid_dim.x}; + const dim3 reduce_blocksize = {arg.block_dim.x}; + constexpr int32_t reduce_lds_bytes = 0; + float reduce_result = launch_and_time_kernel( + stream_config, + O_size_k_alignment_necessary == 4 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 4> + : O_size_k_alignment_necessary == 2 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 2> + : O_size_k_alignment_necessary == 1 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 1> + : nullptr, + reduce_gridsize, + reduce_blocksize, + reduce_lds_bytes, + arg.split_O, + arg.split_max, + arg.split_sumexp, + arg.O, + arg.O_size_m, + arg.O_size_g, + arg.O_size_h, + arg.O_size_k, + arg.O_stride_split, + arg.O_stride_b, + arg.O_stride_m, + arg.O_stride_g, + arg.O_stride_h, + arg.split_k); + return reduce_result; + } + }; +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck + +static std::tuple split_attention_hip( + const at::Tensor& XQ, + const at::Tensor& K, + const at::Tensor& V, + const at::Tensor& seqlen, + const int32_t split_k, + const int32_t wavefronts_per_block) { + at::OptionalDeviceGuard guard(XQ.device()); + + auto B = XQ.size(0); + auto M = XQ.size(1); + auto G = XQ.size(2); + auto H = XQ.size(3); + auto D = XQ.size(4); + + double qk_scale = 1. 
/ sqrt(D); + + auto O = at::empty_like(XQ); + constexpr auto rank = 5; + auto split_O = at::zeros({split_k, B, M, G, H, D}, XQ.options()); + auto split_max = + at::empty({B, M, G, H, split_k}, XQ.options().dtype(at::kFloat)) + .fill_(ck::NumericLimits::Lowest()); + auto split_sumexp = at::zeros_like(split_max); + + dim3 blocks(B * H * M * G, split_k); + dim3 threads(kThreadsPerWavefront, wavefronts_per_block); + + constexpr int32_t KV_M_MAX = 8192; + constexpr int32_t K_MAX = 4 * kThreadsPerWavefront; + + int32_t smem_softmax = KV_M_MAX * sizeof(float) + threads.y * sizeof(float); + int32_t smem_output = K_MAX * sizeof(float) * + threads.y; // 4 * threadsPerBlock * sizeof(float) == sizeof(O[b][0][h][:]) + const size_t lds_bytes = max(smem_softmax, smem_output); + auto stream = at::cuda::getCurrentHIPStream().stream(); + + AT_DISPATCH_SWITCH_3( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float, + XQ.scalar_type(), + "efficient_attention_forward_decoder_split_attention_ck_test", + [&] { + using ck_data_t = c10_to_data_t::type; + using device_op_t = + ck::tensor_operation::device::FMHADecoderSplitAttentionDeviceOp< + ck_data_t>; + auto op = device_op_t{}; + + auto XQ_acc = + XQ.packed_accessor32(); + auto K_acc = + K.packed_accessor64(); + auto V_acc = + V.packed_accessor64(); + auto split_O_acc = + split_O + .packed_accessor32(); + auto O_acc = + O.packed_accessor32(); + auto seq_acc = + seqlen.packed_accessor32(); + auto split_max_acc = + split_max.packed_accessor32(); + auto split_sumexp_acc = + split_sumexp + .packed_accessor32(); + auto arg = device_op_t::Argument( + reinterpret_cast(XQ_acc.data()), + reinterpret_cast(K_acc.data()), + reinterpret_cast(V_acc.data()), + reinterpret_cast(O_acc.data()), + reinterpret_cast(split_O_acc.data()), + split_max_acc.data(), + split_sumexp_acc.data(), + seq_acc.data(), + XQ_acc.stride(0), + XQ_acc.stride(1), + XQ_acc.stride(2), + XQ_acc.stride(3), + K_acc.stride(0), + K_acc.stride(1), + K_acc.stride(2), + K_acc.stride(3), + split_O_acc.stride(0), + XQ_acc.size(1), + XQ_acc.size(2), + XQ_acc.size(3), + XQ_acc.size(4), + K_acc.size(1), + K_acc.size(3) == 1, + qk_scale, + split_k, + blocks, + threads, + lds_bytes); + + auto invoker = device_op_t::Invoker{}; + (void)invoker.Run(arg, {stream}); + }); + return std::make_tuple(split_O, split_max, split_sumexp); +} + +static at::Tensor split_reduce_hip( + const at::Tensor& split_O, + const at::Tensor& split_max, + const at::Tensor& split_sumexp, + const int32_t split_k) { + at::OptionalDeviceGuard guard(split_O.device()); + + auto B = split_O.size(1); + auto M = split_O.size(2); + auto G = split_O.size(3); + auto H = split_O.size(4); + auto D = split_O.size(5); + + TORCH_CHECK_EQ(split_k, split_O.size(0)); + TORCH_CHECK_EQ(split_k, split_max.size(-1)); + TORCH_CHECK_EQ(split_k, split_sumexp.size(-1)); + + constexpr auto rank = 5; + + TORCH_CHECK_EQ(split_O.dim(), 1 + rank); + TORCH_CHECK_EQ(split_max.dim(), rank); + TORCH_CHECK_EQ(split_sumexp.dim(), rank); + + auto O = at::zeros({B, M, G, H, D}, split_O.options()); + + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto lds_bytes = 0; + + dim3 blocks(B * H * M * G); + dim3 threads(kThreadsPerWavefront); + + AT_DISPATCH_SWITCH_3( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float, + O.scalar_type(), + "efficient_attention_forward_decoder_split_reduce_ck_test", + [&] { + using ck_data_t = c10_to_data_t::type; + using device_op_t = + ck::tensor_operation::device::FMHADecoderSplitReduceDeviceOp< + 
ck_data_t>; + auto op = device_op_t{}; + + auto split_O_acc = + split_O + .packed_accessor32(); + auto O_acc = + O.packed_accessor32(); + auto split_max_acc = + split_max.packed_accessor32(); + auto split_sumexp_acc = + split_sumexp + .packed_accessor32(); + auto arg = device_op_t::Argument( + reinterpret_cast(split_O_acc.data()), + split_max_acc.data(), + split_sumexp_acc.data(), + reinterpret_cast(O_acc.data()), + O_acc.size(1), + O_acc.size(2), + O_acc.size(3), + O_acc.size(4), + split_O_acc.stride(0), + O_acc.stride(0), + O_acc.stride(1), + O_acc.stride(2), + O_acc.stride(3), + split_k, + blocks, + threads, + lds_bytes); + + auto invoker = device_op_t::Invoker{}; + (void)invoker.Run(arg, {stream}); + }); + return O; +} + +std::tuple generate_inputs( + const int32_t padding, + const int32_t B, + const int32_t Hq, + const int32_t Hkv, + const decltype(torch::kFloat32) dtype = torch::kFloat32) { + const int32_t D = 4 * kThreadsPerWavefront; + const int32_t G = Hq / Hkv; + const int32_t num_queries = 1; + + at::manual_seed(1); + + auto options = torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(torch::kCUDA, 1) + .requires_grad(false); + auto int_options = options.dtype(torch::kInt); + auto XQ = at::randn({B, num_queries, G, Hq, D}, options); + auto K = (G == 1) ? at::randn({B, padding, G, Hkv, D}, options) + : at::randn({B, padding, G, 1, D}, options) + .expand({B, padding, G, Hq, D}); + auto V = at::randn_like(K); + auto seqlen = at::randint(num_queries, padding + 1, {B}, int_options); + + return std::make_tuple(XQ, K, V, seqlen); +} + +static float percent_mismatch(const at::Tensor& a, const at::Tensor& b) { + auto mask = + at::isclose(a, b, /*atol*/ 1e-3, /*rtol*/ 1e-5, /*equal_nan*/ false); + auto percent_match = at::sum(mask.to(torch::kFloat32)) / mask.numel(); + return 1. 
- percent_match.item(); +} + +static void test_split_attention( + int32_t padding, + int32_t batch_size, + int32_t Hq, + int32_t Hkv, + int32_t split_k) { + auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); + + auto [O_ref, m_ref, l_ref] = split_attention_torch( + XQ, K, V, seqlen, split_k, /* block_size */ kWavefrontsPerBlock * 16); + + auto [O_hip, m_hip, l_hip] = + split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); + + auto O_percent_mismatch = percent_mismatch(O_ref, O_hip); + auto m_percent_mismatch = percent_mismatch(m_ref, m_hip); + auto l_percent_mismatch = percent_mismatch(l_ref, l_hip); + + printf( + "[Test split attention] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched split_O " + "elements percentage: %.2f Mismatched split_max elements percentage: %.2f Mismatched " + "split_sumexp elements percentage: %.2f\n", + padding, + batch_size, + Hq, + Hkv, + split_k, + O_percent_mismatch, + m_percent_mismatch, + l_percent_mismatch); +} + +static void test_split_reduce( + int32_t padding, + int32_t batch_size, + int32_t Hq, + int32_t Hkv, + int32_t split_k) { + auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); + + auto [O_ref, m_ref, l_ref] = + split_attention_hip(XQ, K, V, seqlen, split_k, kWavefrontsPerBlock); + + auto O_torch = split_reduce_torch( + O_ref, m_ref.unsqueeze(0), l_ref.unsqueeze(0), split_k); + auto O_hip = split_reduce_hip(O_ref, m_ref, l_ref, split_k); + + auto hip_torch_mismatch = percent_mismatch(O_hip, O_torch); + printf( + "[Test split reduce] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched elements " + "percentage: %.2f \n", + padding, + batch_size, + Hq, + Hkv, + split_k, + hip_torch_mismatch); +} + +static void test_splitk_decoder_e2e_correctness( + int32_t padding, + int32_t batch_size, + int32_t Hq, + int32_t Hkv, + int32_t split_k) { + auto [XQ, K, V, seqlen] = generate_inputs(padding, batch_size, Hq, Hkv); + + double qk_scale = 1. 
/ sqrt(XQ.size(-1)); + + auto result = efficient_attention_forward_decoder_splitk_ck_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>(XQ, K, V, seqlen, qk_scale, split_k); + auto gold_result = efficient_attention_forward_decoder_splitk_torch( + XQ, K, V, seqlen, qk_scale, /* split_k */ 1, /* block_size */ 1); + auto e2e_mismatch = percent_mismatch(result, gold_result); + printf( + "[Test e2e split-k decoder] Padding=%d BS=%d Hq=%d Hkv=%d split_k=%d Mismatched " + "elements percentage: %.2f\n", + padding, + batch_size, + Hq, + Hkv, + split_k, + e2e_mismatch); +} + +int main(int argc, char** argv) { + if (argc == 1) { + for (auto padding : {32, 4096}) { + for (auto batch_size : {1, 8}) { + for (auto Hq : {16}) { + for (auto Hkv : {16}) { + for (auto split_k : {1, 2, 4, 8, 16}) { + test_splitk_decoder_e2e_correctness( + padding, batch_size, Hq, Hkv, split_k); + } + } + } + } + } + + for (auto padding : {32, 4096}) { + for (auto batch_size : {1, 8}) { + for (auto Hq : {16}) { + for (auto Hkv : {16}) { + for (auto split_k : {1, 2, 4, 8, 16}) { + test_split_attention(padding, batch_size, Hq, Hkv, split_k); + } + } + } + } + } + + for (auto padding : {32, 4096}) { + for (auto batch_size : {1, 8}) { + for (auto Hq : {16}) { + for (auto Hkv : {16}) { + for (auto split_k : {1, 2}) { + test_split_reduce(padding, batch_size, Hq, Hkv, split_k); + } + } + } + } + } + } else { + const auto args = std::vector(argv + 1, argv + argc); + if (args.size() != 6) { + std::cout << "Usage: ./a.out padding batch_size nq_heads nkv_heads dtype " + "n_wavefronts_per_block" + << std::endl; + return 0; + } + const int32_t padding = std::stoi(args[0]); + const int32_t batch_size = std::stoi(args[1]); + const int32_t nq_heads = std::stoi(args[2]); + const int32_t nkv_heads = std::stoi(args[3]); + const auto dtype = (args[4] == "f32") ? torch::kFloat32 + : (args[4] == "f16") ? torch::kFloat16 + : torch::kBFloat16; + const int32_t n_wavefronts_per_block = std::stoi(args[5]); + + auto [Q, K, V, seq] = + generate_inputs(padding, batch_size, nq_heads, nkv_heads, dtype); + auto O = at::empty_like(Q); + + constexpr auto splitk_dim = 0; + constexpr auto split_k = 1; + auto O_splits = at::stack(O, splitk_dim); + + auto split_max = at::empty( + {batch_size, padding, Q.size(2), Q.size(3), split_k}, + Q.options().dtype(at::kFloat)); + auto split_sumexp = at::empty_like(split_max); + + const double qk_scale = 1. 
/ sqrt(Q.size(-1)); + auto call_ptr = + decltype(&efficient_attention_forward_decoder_splitk_ck_out_impl< + kThreadsPerWavefront, + kWavefrontsPerBlock>){}; + +#define SWITCH_CASE_SET_CALLPTR(n) \ + case (n): \ + call_ptr = &efficient_attention_forward_decoder_splitk_ck_out_impl< \ + kThreadsPerWavefront, \ + (n)>; \ + break; + + switch (n_wavefronts_per_block) { + SWITCH_CASE_SET_CALLPTR(1); + SWITCH_CASE_SET_CALLPTR(2); + SWITCH_CASE_SET_CALLPTR(4); + SWITCH_CASE_SET_CALLPTR(8); + SWITCH_CASE_SET_CALLPTR(16); + + default: + call_ptr = nullptr; + break; + } +#undef SWITCH_CASE_SET_CALLPTR + + if (call_ptr) { + call_ptr( + Q, + K, + V, + seq, + qk_scale, + split_k, + split_max, + split_sumexp, + O_splits, + O); + } else { + std::cout << "Warning: no kernel was found for wavefronts_per_block=" + << n_wavefronts_per_block << std::endl; + } + } + return 0; +} + +#endif // MAIN + +#undef AT_DISPATCH_CASE_3 +#undef AT_DISPATCH_SWITCH_3 \ No newline at end of file diff --git a/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h b/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h new file mode 100644 index 0000000000..20b3b8979c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder.h @@ -0,0 +1,493 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ typename ck::vector_type::type scalar_scale_acc( + typename ck::vector_type::type acc, + typename ck::vector_type::type a, + float b) { + union { + decltype(acc) vec; + float arr[vec_size]; + } acc_u{acc}; + union { + decltype(a) vec; + data_t arr[vec_size]; + } a_u{a}; + +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + acc_u.arr[i] += ck::type_convert(a_u.arr[i]) * b; + } + + return acc_u.vec; +} + +template +float __device__ __forceinline__ wavefrontReduce(float val, F f) { +#pragma unroll + for (int32_t mask = n_threads_per_wavefront >> 1; mask > 0; mask >>= 1) { + val = f(__shfl_xor(val, mask, n_threads_per_wavefront), val); + } + return val; +} + +template +__forceinline__ __device__ void load_v( + const TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec* __restrict__ load_to) { + *load_to = *(reinterpret_cast(data_ptr) + vector_offset); +} + +template +__forceinline__ __device__ void store_v( + TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec value) { + *(reinterpret_cast(data_ptr) + vector_offset) = value; +} + +template < + typename scalar_t, + int32_t vec_size = 4, + int32_t n_loop_unroll = 16, + int32_t n_loop_unroll_tail = 2, + int32_t KV_M_MAX = 8192, + int32_t n_wavefronts_per_block = 16> +__global__ void efficient_attention_forward_decoder_ck_kernel( + const scalar_t* __restrict__ XQ, + const scalar_t* __restrict__ cache_K, + const scalar_t* __restrict__ cache_V, + scalar_t* __restrict__ O, + const int32_t* __restrict__ seq_kv_lens, + const ptrdiff_t XQ_stride_b, + const ptrdiff_t XQ_stride_m, + const ptrdiff_t XQ_stride_g, + const ptrdiff_t XQ_stride_h, + const ptrdiff_t K_stride_b, + const ptrdiff_t K_stride_m, + const ptrdiff_t K_stride_g, + const ptrdiff_t K_stride_h, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const int32_t K_size_m, + const bool multiquery, + const float qk_scale) { 
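+ // Kernel outline (descriptive summary of the code below): (1) each wavefront
+ // computes a strided subset of the Q·K[t] dot products and stages the scaled
+ // logits in shared memory, (2) the block cooperatively reduces the row max and
+ // the softmax denominator, (3) each wavefront accumulates its share of
+ // P[t] * V[t], and the per-wavefront partial outputs are summed before the
+ // single output row O[b][m][g][h][:] is written back.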
+ static_assert(n_loop_unroll_tail < n_loop_unroll, ""); + + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); + const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; + const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; + const int32_t h = blockIdx.x % Q_size_h; + + // Note: this is decoding case where we attend to current and all previous + // tokens. + const int32_t t_max = seq_kv_lens ? seq_kv_lens[b] : K_size_m; + + const int32_t lane_idx = threadIdx.x; + const int32_t wavefront_idx = threadIdx.y; + const int32_t threads_per_wavefront = blockDim.x; + const int32_t wavefronts_per_block = blockDim.y; + const int32_t threads_per_block = + threads_per_wavefront * wavefronts_per_block; + const int32_t thread_linear_idx = + lane_idx + wavefront_idx * threads_per_wavefront; + // const auto* q_ = &(XQ_acc[b][m][g][h][0]); + const auto XQO_base_offset = + b * XQ_stride_b + m * XQ_stride_m + g * XQ_stride_g + h * XQ_stride_h; + const auto* __restrict__ q_ = XQ + XQO_base_offset; + + const auto cache_KV_base_offset = b * K_stride_b + 0 * K_stride_m + + g * K_stride_g + (multiquery ? 0 : h * K_stride_h); + const auto* __restrict__ cache_K_base = cache_K + cache_KV_base_offset; + const auto* __restrict__ cache_V_base = cache_V + cache_KV_base_offset; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_t = float; + using compute_vec_t = typename ck::vector_type::type; + + const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; + + extern __shared__ __align__(16) compute_t smem[]; + + data_vec_t q_thread = 0; + // Load Q into registers in all wavefronts. + // Each thread handles `vec_size` D dimensions + if (lane_active_for_io) { + load_v(q_, lane_idx, &q_thread); + } + + compute_t max_qk_acc = ck::NumericLimits::Lowest(); + + // Compute S[0:t_max] = + // ``` + // for t in range(t_max): + // S[t] = dot(Q, K[t]) + // ``` + // Split the 0:t_max range across wavefronts in a block, + // unroll loads to expose more parallelism. + // Reduce the dot product with cross-lane operation; + // Q and K[t] are in the registers of threads in a single wavefront. 
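+ // Work distribution: each wavefront owns `n_loop_unroll` consecutive
+ // timesteps per iteration and the block advances by
+ // dtt = n_wavefronts_per_block * n_loop_unroll, so the fully unrolled main
+ // loop covers t in [0, t_max_unroll) and the remainder is handled by the
+ // bounds-checked tail loop further down.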
+ + data_vec_t k_loads[n_loop_unroll] = {}; + + constexpr auto dtt = n_wavefronts_per_block * n_loop_unroll; + const int32_t t_max_unroll = (t_max / dtt) * dtt; + + for (auto tt = wavefront_idx * n_loop_unroll; tt < t_max_unroll; tt += dtt) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + } + } + compute_t qk_accs[n_loop_unroll] = {}; +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + ck::inner_product( + q_thread, k_loads[ttt], qk_accs[ttt]); + qk_accs[ttt] *= qk_scale; + + qk_accs[ttt] = + wavefrontReduce(qk_accs[ttt], [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_accs[ttt], max_qk_acc); + } + if (lane_idx == 0) { + auto* __restrict__ smem_base = smem + tt; +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + smem_base[ttt] = qk_accs[ttt]; + } + } + } + + // NB: the length of the tail is <= (wavefronts_per_block * n_loop_unroll) + for (auto tt = t_max_unroll + wavefront_idx * n_loop_unroll_tail; tt < t_max; + tt += wavefronts_per_block * n_loop_unroll_tail) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + } + } + } +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + compute_t qk_acc = 0; + const int32_t t = tt + ttt; + if (t < t_max) { + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= qk_scale; + + qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (lane_idx == 0) { + smem[t] = qk_acc; + } + } + } + } + + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = max_qk_acc; + } + __syncthreads(); + if (lane_idx < wavefronts_per_block) { + max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); + } + // shared across all threads in block + max_qk_acc = + wavefrontReduce(max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); + + // each wavefront computes partial sum of exp. + compute_t softmax_denominator = 0.0f; + for (int32_t t = thread_linear_idx; t < t_max; t += threads_per_block) { + softmax_denominator += ck::math::exp(smem[t] - max_qk_acc); + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = softmax_denominator; + } + __syncthreads(); + + // now, compute sum of exp(x - max(x)) over all intermediate results. + softmax_denominator = 0.0; + if (lane_idx < wavefronts_per_block) { + softmax_denominator = smem[KV_M_MAX + lane_idx]; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + const compute_t softmax_scale_factor = 1. / softmax_denominator; + // now, compute the normalization across all threads. 
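+ // Overwrite the staged logits in shared memory with softmax probabilities,
+ // smem[t] = exp(S[t] - max) / sum_j exp(S[j] - max), ready for the P*V pass.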
+ for (int32_t t = thread_linear_idx; t < t_max; t += threads_per_block) { + smem[t] = ck::math::exp(smem[t] - max_qk_acc) * softmax_scale_factor; + } + __syncthreads(); + + // Split T across wavefronts in a block + // each wavefront compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + compute_t ps[n_loop_unroll] = {}; + compute_vec_t o_acc = 0; + if (lane_active_for_io) { + for (auto tt = wavefront_idx * n_loop_unroll; tt < t_max_unroll; + tt += dtt) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the V[b][t][g][h|0][:] row into registers, reusing K register + // storage + load_v( + cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t]; + } + +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + o_acc = + scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + for (auto tt = t_max_unroll + wavefront_idx * n_loop_unroll_tail; + tt < t_max; + tt += wavefronts_per_block * n_loop_unroll_tail) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the V[b][t][g][h|0][:] row into registers, reusing K register + // storage + load_v( + cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t]; + } + } + +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + o_acc = + scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + } + } + // now, each thread has partial sums. Write to smem and get accumulated + // results back. + __syncthreads(); + + // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) * threadsPerBlock + if (lane_active_for_io) { + store_v(&smem[0], thread_linear_idx, o_acc); + } + + __syncthreads(); + // sum up partial D rows from other wavefronts + if (wavefront_idx == 0 && lane_active_for_io) { + union { + compute_vec_t vec = 0; + compute_t arr[vec_size]; + } r; + for (int32_t w = 0; w < wavefronts_per_block; ++w) { + compute_vec_t partial_r; + load_v( + smem, w * threads_per_wavefront + lane_idx, &partial_r); + r.vec += partial_r; + } + // elementwise convert from compute_t result to data_t out to be written + union { + data_vec_t vec; + data_t arr[vec_size]; + } bf_r; +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + bf_r.arr[i] = ck::type_convert(r.arr[i]); + } + // write output row O[b][m][g][h][:] + data_t* __restrict__ o_ = O + XQO_base_offset; + store_v(o_, lane_idx, bf_r.vec); + } +} + +} // namespace + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct FMHADecoderSeqlen1DeviceOp : public BaseOperator { + using DeviceOp = FMHADecoderSeqlen1DeviceOp; + struct Argument : public BaseArgument { + const scalar_t* __restrict__ XQ; + const scalar_t* __restrict__ cache_K; + const scalar_t* __restrict__ cache_V; + scalar_t* __restrict__ O; + const int32_t* __restrict__ seq_kv_lens; + const ptrdiff_t XQ_stride_b; + const ptrdiff_t XQ_stride_m; + const ptrdiff_t XQ_stride_g; + const ptrdiff_t XQ_stride_h; + const ptrdiff_t K_stride_b; + const ptrdiff_t K_stride_m; + const ptrdiff_t K_stride_g; + const ptrdiff_t K_stride_h; + const int32_t Q_size_m; + const int32_t Q_size_g; + const int32_t Q_size_h; + const int32_t Q_size_k; + const int32_t K_size_m; + const bool multiquery; + const float qk_scale; + + const dim3 grid_dim; + const dim3 block_dim; + const size_t lds_bytes; + + 
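+ // Plain aggregate of the kernel arguments (pointers, strides, sizes) plus the
+ // launch configuration (grid/block dimensions and dynamic LDS size) that
+ // Invoker::Run forwards to the decoder kernel.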
Argument( + const scalar_t* __restrict__ XQ, + const scalar_t* __restrict__ cache_K, + const scalar_t* __restrict__ cache_V, + scalar_t* __restrict__ O, + const int32_t* __restrict__ seq_kv_lens, + const ptrdiff_t XQ_stride_b, + const ptrdiff_t XQ_stride_m, + const ptrdiff_t XQ_stride_g, + const ptrdiff_t XQ_stride_h, + const ptrdiff_t K_stride_b, + const ptrdiff_t K_stride_m, + const ptrdiff_t K_stride_g, + const ptrdiff_t K_stride_h, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const int32_t K_size_m, + const bool multiquery, + const float qk_scale, + const dim3 grid_dim, + const dim3 block_dim, + const size_t lds_bytes) + : XQ(XQ), + cache_K(cache_K), + cache_V(cache_V), + O(O), + seq_kv_lens(seq_kv_lens), + XQ_stride_b(XQ_stride_b), + XQ_stride_m(XQ_stride_m), + XQ_stride_g(XQ_stride_g), + XQ_stride_h(XQ_stride_h), + K_stride_b(K_stride_b), + K_stride_m(K_stride_m), + K_stride_g(K_stride_g), + K_stride_h(K_stride_h), + Q_size_m(Q_size_m), + Q_size_g(Q_size_g), + Q_size_h(Q_size_h), + Q_size_k(Q_size_k), + K_size_m(K_size_m), + multiquery(multiquery), + qk_scale(qk_scale), + grid_dim(grid_dim), + block_dim(block_dim), + lds_bytes(lds_bytes) {} + }; + + struct Invoker : public BaseInvoker { + using Argument = DeviceOp::Argument; + float Run( + const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { + auto threads_per_wavefront = arg.block_dim.x; + + auto Q_size_k_alignment_necessary = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.Q_size_k <= vec_size * threads_per_wavefront) { + Q_size_k_alignment_necessary = vec_size; + } + } + + if (!Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported Q_size_k"); + } + + if (arg.Q_size_k % Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported alignment for Q_size_k"); + } + + return launch_and_time_kernel( + stream_config, + Q_size_k_alignment_necessary == 4 + ? efficient_attention_forward_decoder_ck_kernel + : Q_size_k_alignment_necessary == 2 + ? efficient_attention_forward_decoder_ck_kernel + : Q_size_k_alignment_necessary == 1 + ? 
efficient_attention_forward_decoder_ck_kernel + : nullptr, + arg.grid_dim, + arg.block_dim, + arg.lds_bytes, + arg.XQ, + arg.cache_K, + arg.cache_V, + arg.O, + arg.seq_kv_lens, + arg.XQ_stride_b, + arg.XQ_stride_m, + arg.XQ_stride_g, + arg.XQ_stride_h, + arg.K_stride_b, + arg.K_stride_m, + arg.K_stride_g, + arg.K_stride_h, + arg.Q_size_m, + arg.Q_size_g, + arg.Q_size_h, + arg.Q_size_k, + arg.K_size_m, + arg.multiquery, + arg.qk_scale); + } + }; +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h b/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h new file mode 100644 index 0000000000..9eed4f001b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_attention_forward_decoder_splitk.h @@ -0,0 +1,693 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ typename ck::vector_type::type scalar_scale_acc( + typename ck::vector_type::type acc, + typename ck::vector_type::type a, + float b) { + union { + decltype(acc) vec; + float arr[vec_size]; + } acc_u{acc}; + union { + decltype(a) vec; + data_t arr[vec_size]; + } a_u{a}; + +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + acc_u.arr[i] += ck::type_convert(a_u.arr[i]) * b; + } + + return acc_u.vec; +} + +template +float __device__ __forceinline__ wavefrontReduce(float val, F f) { +#pragma unroll + for (int32_t mask = n_threads_per_wavefront >> 1; mask > 0; mask >>= 1) { + val = f(__shfl_xor(val, mask, n_threads_per_wavefront), val); + } + return val; +} + +template +__forceinline__ __device__ void load_v( + const TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec* __restrict__ load_to) { + *load_to = *(reinterpret_cast(data_ptr) + vector_offset); +} + +template +__forceinline__ __device__ void store_v( + TData* __restrict__ data_ptr, + int32_t vector_offset, + TDataVec value) { + *(reinterpret_cast(data_ptr) + vector_offset) = value; +} + +template +__global__ void efficient_attention_forward_decoder_splitk_reduce_ck_kernel( + const scalar_t* __restrict__ O_splits, + const compute_t* __restrict__ split_max, + const compute_t* __restrict__ split_sumexp, + scalar_t* __restrict__ O, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const ptrdiff_t O_stride_split, + const ptrdiff_t O_stride_b, + const ptrdiff_t O_stride_m, + const ptrdiff_t O_stride_g, + const ptrdiff_t O_stride_h, + const int32_t split_k) { + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); + const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; + const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; + const int32_t h = blockIdx.x % Q_size_h; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; + + union { + data_vec_t vec; + data_t arr[vec_size]; + } O_split_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } O_split_compute; + union { + data_vec_t vec; + data_t arr[vec_size]; + } global_O_data; + union { + compute_vec_t vec; + compute_t arr[vec_size]; + } global_O_compute; + + global_O_compute.vec = 0; + + const int32_t lane_idx = threadIdx.x; + const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; + + if (!lane_active_for_io) { + return; + } + + compute_t global_sumexp = 0; + compute_t 
global_max = ck::NumericLimits::Lowest(); + + for (int32_t split_idx = 0; split_idx < split_k; ++split_idx) { + load_v( + O_splits + b * O_stride_b + m * O_stride_m + g * O_stride_g + + h * O_stride_h + split_idx * O_stride_split, + lane_idx, + &O_split_data.vec); +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + O_split_compute.arr[i] = ck::type_convert(O_split_data.arr[i]); + } + compute_t local_max = *(split_max + blockIdx.x * split_k + split_idx); + compute_t local_sumexp = *(split_sumexp + blockIdx.x * split_k + split_idx); + + compute_t log_alpha = -std::abs(local_max - global_max); + compute_t alpha = + isnan(log_alpha) ? compute_t{1.} : ck::math::exp(log_alpha); + + bool pick_new = local_max < global_max; + compute_t pick_current_coef = pick_new ? 1. : alpha; + compute_t pick_new_coef = pick_new ? alpha : 1.; + + global_sumexp = + pick_current_coef * global_sumexp + pick_new_coef * local_sumexp; + global_O_compute.vec = pick_current_coef * global_O_compute.vec + + pick_new_coef * O_split_compute.vec; + global_max = ck::math::max(local_max, global_max); + } + global_O_compute.vec /= global_sumexp; +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + global_O_data.arr[i] = ck::type_convert(global_O_compute.arr[i]); + } + store_v( + O + b * O_stride_b + m * O_stride_m + g * O_stride_g + h * O_stride_h, + lane_idx, + global_O_data.vec); +} + +template < + typename scalar_t, + int32_t vec_size = 4, + int32_t n_loop_unroll = 16, + int32_t n_loop_unroll_tail = 2, + int32_t KV_M_MAX = 8192, + typename compute_t = float> +__global__ void efficient_attention_forward_decoder_splitk_ck_kernel( + const scalar_t* __restrict__ XQ, + const scalar_t* __restrict__ cache_K, + const scalar_t* __restrict__ cache_V, + scalar_t* __restrict__ O_splits, + compute_t* __restrict__ split_max, + compute_t* __restrict__ split_sumexp, + const int32_t* __restrict__ seq_kv_lens, + const ptrdiff_t XQ_stride_b, + const ptrdiff_t XQ_stride_m, + const ptrdiff_t XQ_stride_g, + const ptrdiff_t XQ_stride_h, + const ptrdiff_t K_stride_b, + const ptrdiff_t K_stride_m, + const ptrdiff_t K_stride_g, + const ptrdiff_t K_stride_h, + const ptrdiff_t O_stride_split, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const int32_t K_size_m, + const bool multiquery, + const float qk_scale, + const int32_t split_k) { + static_assert( + n_loop_unroll_tail < n_loop_unroll || n_loop_unroll_tail == 1, + "tail unroll must be smaller than main loop untoll; pragma unroll 0 is illegal " + "(and tail is no-op)"); + + // Each block handles a single batch and head and query and group + const int32_t b = blockIdx.x / (Q_size_m * Q_size_g * Q_size_h); + const int32_t m = (blockIdx.x / (Q_size_g * Q_size_h)) % Q_size_m; + const int32_t g = (blockIdx.x / Q_size_h) % Q_size_g; + const int32_t h = blockIdx.x % Q_size_h; + const int32_t split_idx = blockIdx.y; + + // Note: this is decoding case where we attend to current and all previous + // tokens. + const int32_t t_max = seq_kv_lens ? 
seq_kv_lens[b] : K_size_m; + + const int32_t lane_idx = threadIdx.x; + const int32_t wavefront_idx = threadIdx.y; + // TODO: `threads_per_wavefront` and `wavefronts_per_block` may be compile + // time constants; investigate when optimizing + const int32_t threads_per_wavefront = blockDim.x; + const int32_t wavefronts_per_block = blockDim.y; + const int32_t threads_per_block = + threads_per_wavefront * wavefronts_per_block; + const int32_t thread_linear_idx = + lane_idx + wavefront_idx * threads_per_wavefront; + // const auto* q_ = &(XQ_acc[b][m][g][h][0]); + const auto XQO_base_offset = + b * XQ_stride_b + m * XQ_stride_m + g * XQ_stride_g + h * XQ_stride_h; + const auto* __restrict__ q_ = XQ + XQO_base_offset; + + const auto cache_KV_base_offset = b * K_stride_b + 0 * K_stride_m + + g * K_stride_g + (multiquery ? 0 : h * K_stride_h); + const auto* __restrict__ cache_K_base = cache_K + cache_KV_base_offset; + const auto* __restrict__ cache_V_base = cache_V + cache_KV_base_offset; + + using data_t = scalar_t; + using data_vec_t = typename ck::vector_type::type; + using compute_vec_t = typename ck::vector_type::type; + + const bool lane_active_for_io = lane_idx * vec_size < Q_size_k; + + extern __shared__ __align__(16) compute_t smem[]; + + data_vec_t q_thread = 0; + // Load Q into registers in all wavefronts. + // Each thread handles `vec_size` D dimensions + if (lane_active_for_io) { + load_v(q_, lane_idx, &q_thread); + } + + compute_t max_qk_acc = ck::NumericLimits::Lowest(); + + // Compute S[0:t_max] = + // ``` + // for t in range(t_max): + // S[t] = dot(Q, K[t]) + // ``` + // Split the 0:t_max range across wavefronts in a block, + // unroll loads to expose more parallelism. + // Reduce the dot product with cross-lane operation; + // Q and K[t] are in the registers of threads in a single wavefront. + + data_vec_t k_loads[n_loop_unroll] = {}; + + const auto dtt = wavefronts_per_block * n_loop_unroll; + // only last split gets the tail. + // the first (split_k - 1) splits have a number of iterations divisible by + // `dtt` + const auto n_unrolled_loops = t_max / dtt / split_k; // +1? + const int32_t tt_low = + wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * split_idx; + const int32_t tt_high = + wavefront_idx * n_loop_unroll + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t dtt_tail = wavefronts_per_block * n_loop_unroll_tail; + const int32_t tt_tail_low = wavefront_idx * n_loop_unroll_tail + + n_unrolled_loops * dtt * (split_idx + 1); + const int32_t tt_tail_high = (split_idx == split_k - 1) ? 
t_max : tt_tail_low; + + for (auto tt = tt_low; tt < tt_high; tt += dtt) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + } + } +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + compute_t qk_acc = 0; + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= qk_scale; + + qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + if (lane_idx == 0) { + smem[tt + ttt - n_unrolled_loops * dtt * split_idx] = qk_acc; + } + } + } + + for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { + if (lane_active_for_io) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the K[b][t][g][h|0][:] row into registers + load_v( + cache_K_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + } + } + } +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + compute_t qk_acc = 0; + const int32_t t = tt + ttt; + if (t < t_max) { + ck::inner_product( + q_thread, k_loads[ttt], qk_acc); + qk_acc *= qk_scale; + + qk_acc = wavefrontReduce(qk_acc, [](auto a, auto b) { return a + b; }); + max_qk_acc = ck::math::max(qk_acc, max_qk_acc); + + // write accumulated sums to smem. + if (lane_idx == 0) { + smem[t - n_unrolled_loops * dtt * split_idx] = qk_acc; + } + } + } + } + + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = max_qk_acc; + } + __syncthreads(); + if (lane_idx < wavefronts_per_block) { + max_qk_acc = ck::math::max(max_qk_acc, smem[KV_M_MAX + lane_idx]); + } + // shared across all threads in block + max_qk_acc = + wavefrontReduce(max_qk_acc, [](auto a, auto b) { return a > b ? a : b; }); + + if (wavefront_idx == 0 && lane_idx == 0) { + split_max[blockIdx.x * split_k + split_idx] = max_qk_acc; + } + + // each wavefront computes partial sum of exp. + { // softmax reduce begin + compute_t softmax_denominator = 0.0f; + const int32_t t_low = n_unrolled_loops * dtt * split_idx; + const int32_t t_high = (split_idx + 1 < split_k) + ? n_unrolled_loops * dtt * (split_idx + 1) + : t_max; + for (int32_t t = t_low + thread_linear_idx; t < t_high; + t += threads_per_block) { + const auto s = ck::math::exp(smem[t - t_low] - max_qk_acc); + softmax_denominator += s; + smem[t - t_low] = s; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + if (lane_idx == 0) { + smem[KV_M_MAX + wavefront_idx] = softmax_denominator; + } + __syncthreads(); + + // now, compute sum of exp(x - max(x)) over all intermediate results. 
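
The per-split statistics written to split_max and split_sumexp here are what the split-k reduce kernel earlier in this file recombines with a running-max rescale (the pick_current_coef / pick_new_coef logic). A standalone host-side sketch of both steps for a single scalar output channel; the flattened scores vector and function names are illustrative only, not part of the patch:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Per-split statistics: row max and sum of exp(s - max) over this split's tokens.
struct SplitStats {
  float row_max;
  float sumexp;
};

SplitStats split_softmax_stats(const std::vector<float>& scores, int t_low, int t_high) {
  SplitStats st{-std::numeric_limits<float>::infinity(), 0.f};
  for (int t = t_low; t < t_high; ++t) st.row_max = std::max(st.row_max, scores[t]);
  for (int t = t_low; t < t_high; ++t) st.sumexp += std::exp(scores[t] - st.row_max);
  return st;
}

// Merge one split's (max, sumexp, partial output) into the running result,
// mirroring the reduce kernel's rescaling; `o` stands in for one element of
// the accumulated O row.
void merge_split(float local_max, float local_sumexp, float local_o,
                 float& global_max, float& global_sumexp, float& global_o) {
  const float log_alpha = -std::fabs(local_max - global_max);
  const float alpha = std::isnan(log_alpha) ? 1.f : std::exp(log_alpha);
  const bool pick_new = local_max < global_max;
  const float pick_current_coef = pick_new ? 1.f : alpha;
  const float pick_new_coef = pick_new ? alpha : 1.f;
  global_sumexp = pick_current_coef * global_sumexp + pick_new_coef * local_sumexp;
  global_o = pick_current_coef * global_o + pick_new_coef * local_o;
  global_max = std::max(local_max, global_max);
  // After the last split, global_o / global_sumexp is the softmax-weighted output.
}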
+ softmax_denominator = 0.0; + if (lane_idx < wavefronts_per_block) { + softmax_denominator = smem[KV_M_MAX + lane_idx]; + } + softmax_denominator = wavefrontReduce( + softmax_denominator, [](auto a, auto b) { return a + b; }); + + if (wavefront_idx == 0 && lane_idx == 0) { + split_sumexp[blockIdx.x * split_k + split_idx] = softmax_denominator; + } + } // softmax reduce end + + // Split T across wavefronts in a block + // each wavefront compute sum(t_subset) P[t] * V[t_subset, d] + // outputs are of size float[D] + + compute_t ps[n_loop_unroll] = {}; + compute_vec_t o_acc = 0; + if (lane_active_for_io) { + for (auto tt = tt_low; tt < tt_high; tt += dtt) { +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + const int32_t t = tt + ttt; + // load the V[b][t][g][h|0][:] row into registers, reusing K register + // storage + load_v( + cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; + } + +#pragma unroll n_loop_unroll + for (auto ttt = 0; ttt < n_loop_unroll; ++ttt) { + o_acc = + scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + + for (auto tt = tt_tail_low; tt < tt_tail_high; tt += dtt_tail) { +#pragma unroll n_loop_unroll_tail + for (auto ttt = 0; ttt < n_loop_unroll_tail; ++ttt) { + const int32_t t = tt + ttt; + if (t < t_max) { + // load the V[b][t][g][h|0][:] row into registers, reusing K register + // storage + load_v( + cache_V_base + t * K_stride_m, lane_idx, &k_loads[ttt]); + ps[ttt] = smem[t - n_unrolled_loops * dtt * split_idx]; + o_acc = + scalar_scale_acc(o_acc, k_loads[ttt], ps[ttt]); + } + } + } + } + __syncthreads(); + + // NB: needs sizeof(smem) >= `vec_size` * (sizeof(float)==4) * threadsPerBlock + if (lane_active_for_io) { + store_v(&smem[0], thread_linear_idx, o_acc); + } + + __syncthreads(); + // sum up partial D rows from other wavefronts + if (wavefront_idx == 0 && lane_active_for_io) { + union { + compute_vec_t vec = 0; + compute_t arr[vec_size]; + } r; + for (int32_t w = 0; w < wavefronts_per_block; ++w) { + compute_vec_t partial_r; + load_v( + smem, w * threads_per_wavefront + lane_idx, &partial_r); + r.vec += partial_r; + } + // elementwise convert from compute_t result to data_t out to be written + union { + data_vec_t vec; + data_t arr[vec_size]; + } bf_r; +#pragma unroll + for (int32_t i = 0; i < vec_size; ++i) { + bf_r.arr[i] = ck::type_convert(r.arr[i]); + } + // write output row O[b][m][g][h][:] + data_t* __restrict__ o_ = + O_splits + XQO_base_offset + split_idx * O_stride_split; + store_v(o_, lane_idx, bf_r.vec); + } +} + +} // namespace + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct FMHADecoderSplitKDeviceOp : public BaseOperator { + using DeviceOp = FMHADecoderSplitKDeviceOp; + struct Argument : public BaseArgument { + const scalar_t* __restrict__ XQ; + const scalar_t* __restrict__ cache_K; + const scalar_t* __restrict__ cache_V; + scalar_t* __restrict__ O; + scalar_t* __restrict__ split_O; + compute_t* __restrict__ split_max; + compute_t* __restrict__ split_sumexp; + const int32_t* __restrict__ seq_kv_lens; + const ptrdiff_t XQ_stride_b; + const ptrdiff_t XQ_stride_m; + const ptrdiff_t XQ_stride_g; + const ptrdiff_t XQ_stride_h; + const ptrdiff_t K_stride_b; + const ptrdiff_t K_stride_m; + const ptrdiff_t K_stride_g; + const ptrdiff_t K_stride_h; + const ptrdiff_t O_stride_split; + const int32_t Q_size_m; + const int32_t Q_size_g; + const int32_t Q_size_h; + const int32_t Q_size_k; + const int32_t K_size_m; + const 
bool multiquery; + const float qk_scale; + const int32_t split_k; + + const dim3 grid_dim; + const dim3 block_dim; + const size_t lds_bytes; + + Argument( + const scalar_t* __restrict__ XQ, + const scalar_t* __restrict__ cache_K, + const scalar_t* __restrict__ cache_V, + scalar_t* __restrict__ O, + scalar_t* __restrict__ split_O, + compute_t* __restrict__ split_max, + compute_t* __restrict__ split_sumexp, + const int32_t* __restrict__ seq_kv_lens, + const ptrdiff_t XQ_stride_b, + const ptrdiff_t XQ_stride_m, + const ptrdiff_t XQ_stride_g, + const ptrdiff_t XQ_stride_h, + const ptrdiff_t K_stride_b, + const ptrdiff_t K_stride_m, + const ptrdiff_t K_stride_g, + const ptrdiff_t K_stride_h, + const ptrdiff_t O_stride_split, + const int32_t Q_size_m, + const int32_t Q_size_g, + const int32_t Q_size_h, + const int32_t Q_size_k, + const int32_t K_size_m, + const bool multiquery, + const float qk_scale, + const int32_t split_k, + // launch params + const dim3 grid_dim, + const dim3 block_dim, + const size_t lds_bytes) + : XQ(XQ), + cache_K(cache_K), + cache_V(cache_V), + O(O), + split_O(split_O), + split_max(split_max), + split_sumexp(split_sumexp), + seq_kv_lens(seq_kv_lens), + XQ_stride_b(XQ_stride_b), + XQ_stride_m(XQ_stride_m), + XQ_stride_g(XQ_stride_g), + XQ_stride_h(XQ_stride_h), + K_stride_b(K_stride_b), + K_stride_m(K_stride_m), + K_stride_g(K_stride_g), + K_stride_h(K_stride_h), + O_stride_split(O_stride_split), + Q_size_m(Q_size_m), + Q_size_g(Q_size_g), + Q_size_h(Q_size_h), + Q_size_k(Q_size_k), + K_size_m(K_size_m), + multiquery(multiquery), + qk_scale(qk_scale), + split_k(split_k), + // launch params + grid_dim(grid_dim), + block_dim(block_dim), + lds_bytes(lds_bytes) {} + + std::string str() const { + std::ostringstream oss; + oss << "Argument { " << std::endl + << " XQ: " << XQ << std::endl + << " cache_K: " << cache_K << std::endl + << " cache_V: " << cache_V << std::endl + << " O: " << O << std::endl + << " split_O: " << split_O << std::endl + << " split_max: " << split_max << std::endl + << " split_sumexp: " << split_sumexp << std::endl + << " seq_kv_lens: " << seq_kv_lens << std::endl + << " XQ_stride_b: " << XQ_stride_b << std::endl + << " XQ_stride_m: " << XQ_stride_m << std::endl + << " XQ_stride_g: " << XQ_stride_g << std::endl + << " XQ_stride_h: " << XQ_stride_h << std::endl + << " K_stride_b: " << K_stride_b << std::endl + << " K_stride_m: " << K_stride_m << std::endl + << " K_stride_g: " << K_stride_g << std::endl + << " K_stride_h: " << K_stride_h << std::endl + << " O_stride_split: " << O_stride_split << std::endl + << " Q_size_m: " << Q_size_m << std::endl + << " Q_size_g: " << Q_size_g << std::endl + << " Q_size_h: " << Q_size_h << std::endl + << " Q_size_k: " << Q_size_k << std::endl + << " K_size_m: " << K_size_m << std::endl + << " multiquery: " << multiquery << std::endl + << " qk_scale: " << qk_scale << std::endl + << " split_k: " << split_k << std::endl + << std::endl + << " grid_dim: " << grid_dim.x << "." << grid_dim.y << "." + << grid_dim.z << std::endl + << " block_dim: " << block_dim.x << "." << block_dim.y << "." 
+ << block_dim.z << std::endl + << " lds_bytes: " << lds_bytes << std::endl + << "}"; + return oss.str(); + } + }; + + struct Invoker : public BaseInvoker { + using Argument = DeviceOp::Argument; + float Run( + const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}) { + auto threads_per_wavefront = arg.block_dim.x; + auto Q_size_k_alignment_necessary = 0; + + for (auto vec_size : {4, 2, 1}) { + if (arg.Q_size_k <= vec_size * threads_per_wavefront) { + Q_size_k_alignment_necessary = vec_size; + } + } + + if (!Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported Q_size_k"); + } + + if (arg.Q_size_k % Q_size_k_alignment_necessary) { + throw std::runtime_error("Unsupported alignment for Q_size_k"); + } + + float split_attention_result = launch_and_time_kernel( + stream_config, + Q_size_k_alignment_necessary == 4 + ? efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 4> + : Q_size_k_alignment_necessary == 2 + ? efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 2> + : Q_size_k_alignment_necessary == 1 + ? efficient_attention_forward_decoder_splitk_ck_kernel< + scalar_t, + 1> + : nullptr, + arg.grid_dim, + arg.block_dim, + arg.lds_bytes, + arg.XQ, + arg.cache_K, + arg.cache_V, + arg.split_O, + arg.split_max, + arg.split_sumexp, + arg.seq_kv_lens, + arg.XQ_stride_b, + arg.XQ_stride_m, + arg.XQ_stride_g, + arg.XQ_stride_h, + arg.K_stride_b, + arg.K_stride_m, + arg.K_stride_g, + arg.K_stride_h, + arg.O_stride_split, + arg.Q_size_m, + arg.Q_size_g, + arg.Q_size_h, + arg.Q_size_k, + arg.K_size_m, + arg.multiquery, + arg.qk_scale, + arg.split_k); + + const dim3 reduce_gridsize = {arg.grid_dim.x}; + const dim3 reduce_blocksize = {arg.block_dim.x}; + constexpr int32_t reduce_lds_bytes = 0; + float reduce_result = launch_and_time_kernel( + stream_config, + Q_size_k_alignment_necessary == 4 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 4> + : Q_size_k_alignment_necessary == 2 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 2> + : Q_size_k_alignment_necessary == 1 + ? efficient_attention_forward_decoder_splitk_reduce_ck_kernel< + scalar_t, + 1> + : nullptr, + reduce_gridsize, + reduce_blocksize, + reduce_lds_bytes, + arg.split_O, + arg.split_max, + arg.split_sumexp, + arg.O, + arg.Q_size_m, + arg.Q_size_g, + arg.Q_size_h, + arg.Q_size_k, + arg.O_stride_split, + arg.XQ_stride_b, + arg.XQ_stride_m, + arg.XQ_stride_g, + arg.XQ_stride_h, + arg.split_k); + return split_attention_result + reduce_result; + } + }; +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/xformers/csrc/attention/hip_fmha/ck_fmha_test.cpp b/xformers/csrc/attention/hip_fmha/ck_fmha_test.cpp new file mode 100644 index 0000000000..08825f1a88 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_fmha_test.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
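
The Invoker above picks the per-thread vector width from {4, 2, 1}: the loop keeps the smallest width whose single wavefront still spans the whole head dimension, and Q_size_k must then divide evenly by that width. A minimal host-side sketch of that selection (the helper name is illustrative; the 64-lane figure in the usage comment is an assumption):

#include <initializer_list>
#include <stdexcept>

int pick_vec_size(int Q_size_k, int threads_per_wavefront) {
  int alignment = 0;
  for (int vec_size : {4, 2, 1}) {
    // Later (smaller) candidates overwrite earlier ones, so the smallest
    // width that still covers Q_size_k with one wavefront wins.
    if (Q_size_k <= vec_size * threads_per_wavefront)
      alignment = vec_size;
  }
  if (alignment == 0)
    throw std::runtime_error("Unsupported Q_size_k");
  if (Q_size_k % alignment != 0)
    throw std::runtime_error("Unsupported alignment for Q_size_k");
  return alignment;
}
// e.g. with 64-lane wavefronts: pick_vec_size(256, 64) == 4, pick_vec_size(128, 64) == 2.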
+ */ +#include + +#include + +namespace { + +// For testing xFormers building and binding +bool is_ck_fmha_available(double val) { + std::cout << "ck fmha is really here, val=" << val << std::endl; + return (true); +}; + +} // namespace + +TORCH_LIBRARY_FRAGMENT(xformers, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::is_ck_fmha_available(float val) -> bool")); + m.impl( + TORCH_SELECTIVE_NAME("xformers::is_ck_fmha_available"), + TORCH_FN(is_ck_fmha_available)); +} diff --git a/xformers/csrc/attention/hip_fmha/ck_fmha_util.h b/xformers/csrc/attention/hip_fmha/ck_fmha_util.h new file mode 100644 index 0000000000..a6ea50d780 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_fmha_util.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include + +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::ostringstream ostr; \ + ostr << "'" #COND "' failed: " << ERR; \ + throw std::runtime_error(ostr.str()); \ + } + +#define DISPATCH_TYPES(InDataType, func) \ + { \ + if (InDataType == at::ScalarType::Half) { \ + using scalar_t = ck::half_t; \ + func(); \ + } else if (InDataType == at::ScalarType::BFloat16) { \ + using scalar_t = ck::bhalf_t; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, "Only half & bf16 input type supported at the moment"); \ + } \ + } + +template +struct CkToAtenDtype; + +template <> +struct CkToAtenDtype { + using scalar_t = ck::half_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Half; + } +}; + +template <> +struct CkToAtenDtype { + using scalar_t = ck::bhalf_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::BFloat16; + } +}; + +template <> +struct CkToAtenDtype { + using scalar_t = float; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Float; + } +}; + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous(), #TENSOR " must be contiguous"); + +#define CHECK_NOSPARSE_CONTIGUOUS_CPU(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cpu(), #TENSOR " must be a CPU tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous(), #TENSOR " must be contiguous"); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define HIP_CALL_CHECK(flag) \ + do { \ + hipError_t _tmpVal; \ + if ((_tmpVal = flag) != hipSuccess) { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ + << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while (0) + +static inline size_t get_size_in_bytes(size_t n, at::ScalarType dtype) { + if (dtype == at::ScalarType::Float) { + return n * 4; + } else if (dtype == at::ScalarType::Half) { + return n * 2; + } else if (dtype == at::ScalarType::BFloat16) { + return n * 2; + } else if (dtype == 
at::ScalarType::Short) { + return n * 2; + } else if (dtype == at::ScalarType::Int) { + return n * 4; + } else if (dtype == at::ScalarType::Byte) { + return n; + } + return 0; +} + +/** + * kernels expect 4D bias/bias.grad with shape + * (batch_sz, n_heads, n_queries, n_keys). common bias shapes users may pass + * are: + * - (n_queries, n_keys) + * - (batch_sz * n_heads, n_queries, n_keys) + * - (batch_sz, n_heads, n_queries, n_keys) + * + * expand the bias as needed - be careful to only create a view with different + * shape/strides, no copies allowed. + */ +inline at::Tensor get_bias_4d_view( + const at::Tensor& bias, + int batch_sz, + int n_heads, + int n_queries, + int n_keys) { + TORCH_CHECK( + bias.size(-2) == n_queries, + "bias.size(-2) != n_queries: ", + bias.size(-2), + " != ", + n_queries); + TORCH_CHECK( + bias.size(-1) == n_keys, + "bias.size(-1) != n_keys: ", + bias.size(-1), + " != ", + n_keys); + switch (bias.dim()) { + case 2: // (n_queries, n_keys) - broadcast across all batches and heads + return bias.unsqueeze(0).unsqueeze(0).expand( + {batch_sz, n_heads, n_queries, n_keys}); + case 3: // (batch_sz * n_heads, n_queries, n_keys) - just reshape + TORCH_CHECK(bias.size(0) == batch_sz * n_heads); + return bias.view({batch_sz, n_heads, n_queries, n_keys}); + case 4: // (batch_sz, n_heads, n_queries, n_keys) - do nothing + TORCH_CHECK(bias.size(0) == batch_sz); + TORCH_CHECK(bias.size(1) == n_heads) + return bias; + default: + TORCH_CHECK(false, "bias can only have ndims in {2, 3, 4}"); + } +} diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_bool_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_bool_switch.h new file mode 100644 index 0000000000..c07559a3ca --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_bool_switch.h @@ -0,0 +1,9 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h new file mode 100644 index 0000000000..3dc0c47177 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
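
get_bias_4d_view above only broadcasts or reshapes; no copy is made. For a contiguous bias tensor, the three accepted layouts map onto the following element strides of the logical (batch_sz, n_heads, n_queries, n_keys) view. The Strides4D helper is an illustrative sketch, not part of the patch, and assumes contiguity:

#include <cstddef>
#include <stdexcept>

struct Strides4D {
  long b, h, q, k; // element strides of the logical 4-D view
};

Strides4D bias_view_strides(std::size_t bias_ndim, long n_heads, long n_queries, long n_keys) {
  switch (bias_ndim) {
    case 2: // (n_queries, n_keys): broadcast across batch and head via stride 0
      return {0, 0, n_keys, 1};
    case 3: // (batch_sz * n_heads, n_queries, n_keys): fold the leading dim
      return {n_heads * n_queries * n_keys, n_queries * n_keys, n_keys, 1};
    case 4: // already (batch_sz, n_heads, n_queries, n_keys)
      return {n_heads * n_queries * n_keys, n_queries * n_keys, n_keys, 1};
    default:
      throw std::runtime_error("bias can only have ndims in {2, 3, 4}");
  }
}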
+ */ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tiled_fmha_definitions.h" +#include "ck_tiled_fmha_forward_kernel.h" +#include "ck_tiled_fmha_fwd_epilogue.h" +#include "ck_tiled_fmha_fwd_tile_partitioner.h" +#include "ck_tiled_fmha_params.h" + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_headdim_switch.h" + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +struct batched_forward_causalmask_attnbias_dispatched { + using FmhaEpilogue = FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType>>; + + template + using FmhaPipelineProblemTemp = + ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + FmhaFwdShape, + false, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + const bool has_local_attention = (param.window_size > 0) ? true : false; + + BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { + constexpr bool has_masking = has_causal_mask || USE_LOCAL_ATTENTION; + + using FmhaMask = ck::tile_program::block:: + GenericAttentionMask; + + using FmhaShape = FmhaFwdShape; + using FmhaTilePartitioner = FmhaFwdTilePartitioner; + constexpr ck::index_t occupancy = + (MaxK == 64) ? 3 : ((MaxK == 256) ? 
1 : 2); + + bool pad_seqlen_q = !(param.M % FmhaShape::kM0 == 0); + bool pad_seqlen_k = !(param.N % FmhaShape::kN0 == 0); + bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); + bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); + + if constexpr (MaxK == 256) { + BOOL_SWITCH_4( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + true, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }); + } else { + BOOL_SWITCH_4( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + true, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + constexpr bool no_any_padding = + !(kPadSeqLenQ || kPadSeqLenK || kPadHeadDimQ || kPadHeadDimV); + + if constexpr (no_any_padding) { + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + } else { + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }; + }); + }; + }); + }; + + template + static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.M, // seqlen_q + param.N, // seqlen_k + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + param.M, // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + param.Hq * param.M, // batch_stride_lse + param.out_strides[0], + static_cast(param.custom_mask_type), + param.window_size); + }(); + + dim3 kGridSize = FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); + constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; + + (void)launch_kernel( + StreamConfig{stream, false}, + FmhaKernel{}, + kGridSize, + kBlockSize, + 0, + kargs); + }; +}; + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, + hipStream_t stream) { + batched_forward_causalmask_attnbias_dispatched< 
+ scalar_t, + has_causal_mask, + has_attn_bias, + MaxK>::Run(param, stream); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bp16.cpp new file mode 100644 index 0000000000..8d90c7cd51 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_bp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_batched_forward.h" + +// clang-format off +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +// clang-format on + +void batched_forward_bp16(BatchedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp 
b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp new file mode 100644 index 0000000000..3e65849715 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward_fp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_batched_forward.h" + +// clang-format off +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_forward_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +// clang-format on + +void batched_forward_fp16(BatchedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h new file mode 100644 index 0000000000..8696e04378 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h @@ -0,0 +1,232 @@ +/* + * 
Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tiled_fmha_definitions.h" +#include "ck_tiled_fmha_forward_kernel.h" +#include "ck_tiled_fmha_fwd_epilogue.h" +#include "ck_tiled_fmha_fwd_tile_partitioner.h" +#include "ck_tiled_fmha_params.h" + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_headdim_switch.h" + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +struct batched_infer_causalmask_attnbias_dispatched { + using FmhaEpilogue = FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType>>; + + template + using FmhaPipelineProblemTemp = + ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + FmhaFwdShape, + false, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(BatchedForwardParams& param, hipStream_t stream) { + const bool has_local_attention = (param.window_size > 0) ? true : false; + + BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { + constexpr bool has_masking = has_causal_mask || USE_LOCAL_ATTENTION; + + using FmhaMask = ck::tile_program::block:: + GenericAttentionMask; + + using FmhaShape = FmhaFwdShape; + using FmhaTilePartitioner = FmhaFwdTilePartitioner; + constexpr ck::index_t occupancy = + (MaxK == 64) ? 3 : ((MaxK == 256) ? 
1 : 2); + + bool pad_seqlen_q = !(param.M % FmhaShape::kM0 == 0); + bool pad_seqlen_k = !(param.N % FmhaShape::kN0 == 0); + bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); + bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); + + if constexpr (MaxK == 256) { + BOOL_SWITCH_4( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + false, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }); + } else { + BOOL_SWITCH_4( + pad_seqlen_q, + kPadSeqLenQ, + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_q, + kPadHeadDimQ, + pad_headdim_v, + kPadHeadDimV, + [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + false, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + constexpr bool no_any_padding = + !(kPadSeqLenQ || kPadSeqLenK || kPadHeadDimQ || kPadHeadDimV); + + if constexpr (no_any_padding) { + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + } else { + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }; + }); + }; + }); + }; + + template + static void RunWithKernel(BatchedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + nullptr, // lse_ptr + param.out_ptr, + param.M, // seqlen_q + param.N, // seqlen_k + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + param.q_strides[1], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[2], + param.out_strides[1], + param.q_strides[2], // q, k, v, bias, lse, out tensor head-dim stride + param.k_strides[2], + param.v_strides[2], + param.attn_bias_strides[1], + 0, // nhead_stride_lse + param.out_strides[2], + param.q_strides[0], // q, k, v, bias, lse, out tensor batch-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[0], + 0, // batch_stride_lse + param.out_strides[0], + static_cast(param.custom_mask_type), + param.window_size); + }(); + + dim3 kGridSize = FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv); + constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); + constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; + + (void)launch_kernel( + StreamConfig{stream, false}, + FmhaKernel{}, + kGridSize, + kBlockSize, + 0, + kargs); + }; +}; + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, + hipStream_t stream) { + batched_infer_causalmask_attnbias_dispatched< + scalar_t, + 
has_causal_mask, + has_attn_bias, + MaxK>::Run(param, stream); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bp16.cpp new file mode 100644 index 0000000000..f4a2e064e3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_bp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_batched_infer.h" + +// clang-format off +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +// clang-format on + +void batched_infer_bp16(BatchedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp new file mode 100644 index 
0000000000..653cfacbd5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer_fp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_batched_infer.h" + +// clang-format off +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); + +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +extern template void run_batched_infer_causalmask_attnbias_dispatched( + BatchedForwardParams& param, hipStream_t stream); +// clang-format on + +void batched_infer_fp16(BatchedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_definitions.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_definitions.h new file mode 100644 index 0000000000..4e3767fd2a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_definitions.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
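
The launchers above wrap their body in FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, ...), whose real definition lives in ck_tiled_headdim_switch.h. The sketch below only illustrates the assumed pattern: the larger of the two head dimensions is bucketed into one of the supported MaxK values (32, 64, 128, 256, matching the explicit instantiations) and the body is invoked with that compile-time constant:

#include <algorithm>
#include <stdexcept>
#include <type_traits>

template <typename F>
void headdim_switch(int hdim_q, int hdim_v, F&& body) {
  const int max_hdim = std::max(hdim_q, hdim_v);
  if (max_hdim <= 32)
    body(std::integral_constant<int, 32>{});
  else if (max_hdim <= 64)
    body(std::integral_constant<int, 64>{});
  else if (max_hdim <= 128)
    body(std::integral_constant<int, 128>{});
  else if (max_hdim <= 256)
    body(std::integral_constant<int, 256>{});
  else
    throw std::runtime_error("unsupported head dimension");
}
// usage sketch:
//   headdim_switch(param.K, param.Kv, [&](auto max_k) {
//     constexpr int MaxK = decltype(max_k)::value;
//     /* instantiate the MaxK-specialized dispatch here */
//   });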
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +enum struct CausalMaskType { + MaskDisabled, + MaskUpperTriangleFromTopLeft, + MaskUpperTriangleFromBottomRight +}; + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig { + using QDataType = ck::half_t; + using KDataType = ck::half_t; + using VDataType = ck::half_t; + using BiasDataType = ck::half_t; + using LSEDataType = + float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::half_t; +}; + +template <> +struct FmhaFwdTypeConfig { + using QDataType = ck::bhalf_t; + using KDataType = ck::bhalf_t; + using VDataType = ck::bhalf_t; + using BiasDataType = ck::bhalf_t; + using LSEDataType = + float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::bhalf_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::bhalf_t; +}; + +template +struct FmhaFwdBlockTile; + +template <> +struct FmhaFwdBlockTile<32> { + using type = ck::Sequence<128, 64, 16, 32, 32, 32>; +}; + +template <> +struct FmhaFwdBlockTile<64> { + using type = ck::Sequence<128, 64, 32, 64, 32, 64>; +}; + +template <> +struct FmhaFwdBlockTile<128> { + using type = ck::Sequence<128, 128, 32, 128, 32, 128>; +}; + +template <> +struct FmhaFwdBlockTile<256> { + using type = ck::Sequence<128, 128, 32, 256, 32, 256>; +}; + +using FmhaFwdBlockWarps = ck::Sequence<4, 1, 1>; +using FmhaFwdWarpTile = ck::Sequence<32, 32, 16>; + +static constexpr bool IsVLayoutRowMajor = true; + +template +struct FmhaFwdShape; + +template <> +struct FmhaFwdShape<32> : ck::tile_program::TileFmhaShape< + typename FmhaFwdBlockTile<32>::type, + ck::Sequence<2, 1, 1>, + FmhaFwdWarpTile, + ck::Sequence<2, 1, 1>, + FmhaFwdWarpTile, + IsVLayoutRowMajor> {}; + +template <> +struct FmhaFwdShape<64> : ck::tile_program::TileFmhaShape< + typename FmhaFwdBlockTile<64>::type, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + IsVLayoutRowMajor> {}; + +template <> +struct FmhaFwdShape<128> : ck::tile_program::TileFmhaShape< + typename FmhaFwdBlockTile<128>::type, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + IsVLayoutRowMajor> {}; + +template <> +struct FmhaFwdShape<256> : ck::tile_program::TileFmhaShape< + typename FmhaFwdBlockTile<256>::type, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + FmhaFwdBlockWarps, + FmhaFwdWarpTile, + IsVLayoutRowMajor> {}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h new file mode 100644 index 0000000000..bb4d43d5f6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
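
The launchers map custom_mask_type 0/1/2 onto the CausalMaskType values declared above. The sketch below spells out the usual convention behind those names, i.e. which key columns a query row may attend to; it is a reference illustration, not code from the kernels:

enum struct CausalMaskTypeSketch {
  MaskDisabled,                    // custom_mask_type == 0
  MaskUpperTriangleFromTopLeft,    // custom_mask_type == 1
  MaskUpperTriangleFromBottomRight // custom_mask_type == 2
};

// Returns whether query row q may attend to key column k.
bool allowed(CausalMaskTypeSketch type, int q, int k, int seqlen_q, int seqlen_k) {
  switch (type) {
    case CausalMaskTypeSketch::MaskDisabled:
      return true;
    case CausalMaskTypeSketch::MaskUpperTriangleFromTopLeft:
      return k <= q; // diagonal anchored at element (0, 0)
    case CausalMaskTypeSketch::MaskUpperTriangleFromBottomRight:
      return k <= q + (seqlen_k - seqlen_q); // diagonal anchored at the bottom-right corner
  }
  return false;
}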
+ */ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "ck_tiled_fmha_definitions.h" +#include "ck_tiled_fmha_forward_kernel.h" +#include "ck_tiled_fmha_fwd_epilogue.h" +#include "ck_tiled_fmha_fwd_tile_partitioner.h" +#include "ck_tiled_fmha_params.h" + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_headdim_switch.h" + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +struct grouped_forward_causalmask_attnbias_dispatched { + using FmhaEpilogue = FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType>>; + + template + using FmhaPipelineProblemTemp = + ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + FmhaFwdShape, + true, // kIsGroupMode + FmhaMask, + FmhaTraits>; + + static void Run(GroupedForwardParams& param, hipStream_t stream) { + const bool has_local_attention = (param.window_size > 0) ? true : false; + + BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] { + constexpr bool has_masking = has_causal_mask || USE_LOCAL_ATTENTION; + + using FmhaMask = ck::tile_program::block:: + GenericAttentionMask; + + using FmhaShape = FmhaFwdShape; + using FmhaTilePartitioner = FmhaFwdTilePartitioner; + constexpr ck::index_t occupancy = (MaxK == 64) ? 3 + : (MaxK == 256) ? 
1 + : 2; + + constexpr bool kPadSeqLenQ = true; + constexpr bool kPadSeqLenK = true; + + bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0); + bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0); + + if constexpr (MaxK == 256) { + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + true, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }); + } else { + BOOL_SWITCH_2( + pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] { + using FmhaTraits = ck::tile_program::TileFmhaTraits< + kPadSeqLenQ, + kPadSeqLenK, + kPadHeadDimQ, + kPadHeadDimV, + has_attn_bias, + true, // kStoreLSE + occupancy>; + + using FmhaPipelineProblem = + FmhaPipelineProblemTemp; + + using FmhaPipeline = + ck::tile_program::block::BlockFmhaPipelineQRKSVS< + FmhaPipelineProblem>; + using FmhaKernel = FmhaFwdKernel< + FmhaTilePartitioner, + FmhaPipeline, + FmhaEpilogue>; + + RunWithKernel(param, stream); + }); + }; + }); + }; + + template + static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) { + const auto kargs = [&] { + return FmhaKernel::MakeKargs( + param.q_ptr, + param.k_ptr, + param.v_ptr, + param.attn_bias_ptr, + param.logsumexp_ptr, + param.out_ptr, + param.seqstart_q_dev_ptr, + param.seqstart_k_dev_ptr, + param.seqlen_k_dev_ptr, + param.K, // hdim_q + param.Kv, // hdim_v + param.Hq / param.Hkv, // nhead_ratio_qk + param.scale, + param.q_strides[0], // q, k, v, bias, out tensor seq-dim stride + param.k_strides[0], + param.v_strides[0], + param.attn_bias_strides[2], + param.out_strides[0], + param.q_strides[1], // q, k, v, bias, lse, out tensor head-dim stride + param.k_strides[1], + param.v_strides[1], + param.attn_bias_strides[1], + param.max_seqlen_q, // nhead_stride_lse + param.out_strides[1], + static_cast(param.custom_mask_type), + param.window_size); + }(); + + dim3 kGridSize = FmhaKernel::GridSize( + param.num_batches, param.Hq, param.max_seqlen_q, param.Kv); + constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); + constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; + + (void)launch_kernel( + StreamConfig{stream, false}, + FmhaKernel{}, + kGridSize, + kBlockSize, + 0, + kargs); + }; +}; + +template < + typename scalar_t, + bool has_causal_mask, + bool has_attn_bias, + ck::index_t MaxK> +void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, + hipStream_t stream) { + grouped_forward_causalmask_attnbias_dispatched< + scalar_t, + has_causal_mask, + has_attn_bias, + MaxK>::Run(param, stream); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bp16.cpp new file mode 100644 index 0000000000..b417156f53 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_bp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
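
The grouped (variable sequence length) path above takes seqstart_q_dev_ptr / seqstart_k_dev_ptr plus seqlen_k_dev_ptr and num_batches. A common way to build such arrays is an exclusive prefix sum over per-batch lengths, so that batch b owns rows [seqstart[b], seqstart[b + 1]) of the packed tensors; the helper below is a host-side illustration based on the argument names only:

#include <vector>

std::vector<int> make_seqstart(const std::vector<int>& seqlens) {
  std::vector<int> seqstart(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b)
    seqstart[b + 1] = seqstart[b] + seqlens[b];
  return seqstart; // seqstart.back() is the total number of packed rows
}
// e.g. seqlens {3, 5, 2} -> seqstart {0, 3, 8, 10}; max_seqlen_q would be 5.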
+ */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_grouped_forward.h" + +// clang-format off +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +// clang-format on + +void grouped_forward_bp16(GroupedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp new file mode 100644 index 0000000000..b7c278c53a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward_fp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
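
The BOOL_SWITCH / BOOL_SWITCH_2 / BOOL_SWITCH_4 macros used throughout these launchers turn runtime booleans into compile-time template parameters so a distinct kernel specialization is selected per combination. The real macros come via ck_tiled_bool_switch.h; the sketch below only shows the general pattern, with hypothetical names:

#include <type_traits>

template <typename F>
void bool_switch(bool cond, F&& body) {
  if (cond)
    body(std::true_type{});
  else
    body(std::false_type{});
}
// usage sketch:
//   bool_switch(param.has_attn_bias, [&](auto has_bias) {
//     constexpr bool HAS_ATTN_BIAS = decltype(has_bias)::value;
//     /* dispatch to the HAS_ATTN_BIAS specialization here */
//   });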
+ */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_grouped_forward.h" + +// clang-format off +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_forward_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +// clang-format on + +void grouped_forward_fp16(GroupedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h new file mode 100644 index 0000000000..c371b0aa14 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
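
Each of these fp16/bp16 launcher files only declares extern template specializations of the dispatch function; the definitions are instantiated once each in dedicated translation units, which keeps compile time and object size per file manageable. A generic illustration of that split, with hypothetical file names and a toy run_example_dispatch standing in for run_grouped_forward_causalmask_attnbias_dispatched:

// example_dispatch.h -- the full template definition lives in the header
template <typename scalar_t, int MaxK>
void run_example_dispatch(const scalar_t* in, scalar_t* out, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = in[i]; // placeholder body
}

// example_dispatch_f32_maxk64.cpp -- the one TU that instantiates this combination
// #include "example_dispatch.h"
template void run_example_dispatch<float, 64>(const float*, float*, int);

// example_launcher.cpp -- references the combination without re-instantiating it
// #include "example_dispatch.h"
extern template void run_example_dispatch<float, 64>(const float*, float*, int);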
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ck_tiled_fmha_definitions.h"
+#include "ck_tiled_fmha_forward_kernel.h"
+#include "ck_tiled_fmha_fwd_epilogue.h"
+#include "ck_tiled_fmha_fwd_tile_partitioner.h"
+#include "ck_tiled_fmha_params.h"
+
+#include "ck_tiled_bool_switch.h"
+#include "ck_tiled_headdim_switch.h"
+
+template <
+    typename scalar_t,
+    bool has_causal_mask,
+    bool has_attn_bias,
+    ck::index_t MaxK>
+struct grouped_infer_causalmask_attnbias_dispatched {
+  using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<
+      typename FmhaFwdTypeConfig<scalar_t>::OaccDataType,
+      typename FmhaFwdTypeConfig<scalar_t>::ODataType>>;
+
+  template <typename FmhaTraits, typename FmhaMask>
+  using FmhaPipelineProblemTemp =
+      ck::tile_program::block::BlockFmhaPipelineProblem<
+          typename FmhaFwdTypeConfig<scalar_t>::QDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::KDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::VDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::SaccDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::SMPLComputeDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::BiasDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::LSEDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::PDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::OaccDataType,
+          typename FmhaFwdTypeConfig<scalar_t>::ODataType,
+          FmhaFwdShape<MaxK>,
+          true, // kIsGroupMode
+          FmhaMask,
+          FmhaTraits>;
+
+  static void Run(GroupedForwardParams& param, hipStream_t stream) {
+    const bool has_local_attention = (param.window_size > 0) ? true : false;
+
+    BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] {
+      constexpr bool has_masking = has_causal_mask || USE_LOCAL_ATTENTION;
+
+      using FmhaMask = ck::tile_program::block::
+          GenericAttentionMask<has_masking, USE_LOCAL_ATTENTION>;
+
+      using FmhaShape = FmhaFwdShape<MaxK>;
+      using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;
+      constexpr ck::index_t occupancy =
+          (MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);
+
+      constexpr bool kPadSeqLenQ = true;
+      constexpr bool kPadSeqLenK = true;
+
+      bool pad_headdim_q = !(param.K % FmhaShape::kK0BlockLength == 0);
+      bool pad_headdim_v = !(param.Kv % FmhaShape::kN1 == 0);
+
+      if constexpr (MaxK == 256) {
+        BOOL_SWITCH_2(
+            pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
+              using FmhaTraits = ck::tile_program::TileFmhaTraits<
+                  kPadSeqLenQ,
+                  kPadSeqLenK,
+                  kPadHeadDimQ,
+                  kPadHeadDimV,
+                  has_attn_bias,
+                  false, // kStoreLSE
+                  occupancy>;
+
+              using FmhaPipelineProblem =
+                  FmhaPipelineProblemTemp<FmhaTraits, FmhaMask>;
+
+              using FmhaPipeline =
+                  ck::tile_program::block::BlockFmhaPipelineQRKSVS<
+                      FmhaPipelineProblem>;
+              using FmhaKernel = FmhaFwdKernel<
+                  FmhaTilePartitioner,
+                  FmhaPipeline,
+                  FmhaEpilogue>;
+
+              RunWithKernel<FmhaKernel>(param, stream);
+            });
+      } else {
+        BOOL_SWITCH_2(
+            pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
+              using FmhaTraits = ck::tile_program::TileFmhaTraits<
+                  kPadSeqLenQ,
+                  kPadSeqLenK,
+                  kPadHeadDimQ,
+                  kPadHeadDimV,
+                  has_attn_bias,
+                  false, // kStoreLSE
+                  occupancy>;
+
+              using FmhaPipelineProblem =
+                  FmhaPipelineProblemTemp<FmhaTraits, FmhaMask>;
+
+              using FmhaPipeline =
+                  ck::tile_program::block::BlockFmhaPipelineQRKSVS<
+                      FmhaPipelineProblem>;
+              using FmhaKernel = FmhaFwdKernel<
+                  FmhaTilePartitioner,
+                  FmhaPipeline,
+                  FmhaEpilogue>;
+
+              RunWithKernel<FmhaKernel>(param, stream);
+            });
+      };
+    });
+  };
+
+  template <typename FmhaKernel>
+  static void RunWithKernel(GroupedForwardParams& param, hipStream_t stream) {
+    const auto kargs = [&] {
+      return FmhaKernel::MakeKargs(
+          param.q_ptr,
+          param.k_ptr,
+          param.v_ptr,
+          param.attn_bias_ptr,
+          nullptr, // lse_ptr
+          param.out_ptr,
+          param.seqstart_q_dev_ptr,
+          param.seqstart_k_dev_ptr,
+          param.seqlen_k_dev_ptr,
+          param.K, // hdim_q
+          param.Kv, // hdim_v
+          param.Hq / param.Hkv, // nhead_ratio_qk
+          param.scale,
+          param.q_strides[0], // q, k, v, bias, out tensor seq-dim stride
+          param.k_strides[0],
+          param.v_strides[0],
+          param.attn_bias_strides[2],
+          param.out_strides[0],
+          param.q_strides[1], // q, k, v, bias, lse, out tensor head-dim stride
+          param.k_strides[1],
+          param.v_strides[1],
+          param.attn_bias_strides[1],
+          0, // nhead_stride_lse
+          param.out_strides[1],
+          static_cast<CausalMaskType>(param.custom_mask_type),
+          param.window_size);
+    }();
+
+    dim3 kGridSize = FmhaKernel::GridSize(
+        param.num_batches, param.Hq, param.max_seqlen_q, param.Kv);
+    constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
+    constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu;
+
+    (void)launch_kernel<kBlockSize.x, kBlockPerCu>(
+        StreamConfig{stream, false},
+        FmhaKernel{},
+        kGridSize,
+        kBlockSize,
+        0,
+        kargs);
+  };
+};
+
+template <
+    typename scalar_t,
+    bool has_causal_mask,
+    bool has_attn_bias,
+    ck::index_t MaxK>
+void run_grouped_infer_causalmask_attnbias_dispatched(
+    GroupedForwardParams& param,
+    hipStream_t stream) {
+  grouped_infer_causalmask_attnbias_dispatched<
+      scalar_t,
+      has_causal_mask,
+      has_attn_bias,
+      MaxK>::Run(param, stream);
+};
diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bp16.cpp
new file mode 100644
index 0000000000..7ee53261d7
--- /dev/null
+++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_bp16.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
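The Run() method above turns runtime flags (local attention, head-dim padding) into template parameters through the BOOL_SWITCH helpers before a concrete kernel type is instantiated. The self-contained sketch below illustrates that dispatch pattern only; EXAMPLE_BOOL_SWITCH and run_dummy_kernel are invented names standing in for BOOL_SWITCH from ck_tiled_bool_switch.h, which is not shown in this excerpt.

// Stand-in illustration of the boolean-to-template dispatch used by Run();
// not part of the patch.
#include <cstdio>

template <bool kPadHeadDim>
void run_dummy_kernel() {
  std::printf("kernel instantiated with kPadHeadDim = %d\n",
              kPadHeadDim ? 1 : 0);
}

#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                            \
    if (COND) {                                    \
      constexpr bool CONST_NAME = true;            \
      __VA_ARGS__();                               \
    } else {                                       \
      constexpr bool CONST_NAME = false;           \
      __VA_ARGS__();                               \
    }                                              \
  }()

int main() {
  int head_dim = 96;  // runtime value, e.g. param.K
  bool pad_headdim = (head_dim % 128 != 0);
  EXAMPLE_BOOL_SWITCH(pad_headdim, kPadHeadDim, [&] {
    // Each branch sees a different constexpr value, so both template
    // instantiations exist and the matching one runs at runtime.
    run_dummy_kernel<kPadHeadDim>();
  });
  return 0;
}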
+ */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_grouped_infer.h" + +// clang-format off +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +// clang-format on + +void grouped_infer_bp16(GroupedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp new file mode 100644 index 0000000000..2d03119db8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer_fp16.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include + +#include "ck_tiled_bool_switch.h" +#include "ck_tiled_fmha_grouped_infer.h" + +// clang-format off +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); + +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +extern template void run_grouped_infer_causalmask_attnbias_dispatched( + GroupedForwardParams& param, hipStream_t stream); +// clang-format on + +void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream) { + BOOL_SWITCH(param.has_attn_bias, HAS_ATTN_BIAS, [&] { + FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] { + if (param.custom_mask_type == 0) + run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else if (param.custom_mask_type == 1 || param.custom_mask_type == 2) + run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + HAS_ATTN_BIAS, + MaxK>(param, stream); + else + throw std::runtime_error("Invalid custom_mask_type value"); + }); + }); +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h new file mode 100644 index 0000000000..5d2c232ba1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_params.h @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
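For orientation, the bp16/fp16 forward and inference dispatchers above are the runtime entry points the extension's attention operators reach once a GroupedForwardParams has been populated. A minimal, hedged sketch of such a call follows; launch_grouped_infer_example and the default-stream choice are illustrative and not part of the patch.

// Illustrative call path only: the real glue that fills GroupedForwardParams
// lives elsewhere in the extension. This shows which dispatcher an fp16,
// inference-only call would reach.
#include <hip/hip_runtime.h>

#include "ck_tiled_fmha_params.h"

// Defined in ck_tiled_fmha_grouped_infer_fp16.cpp above.
void grouped_infer_fp16(GroupedForwardParams& param, hipStream_t stream);

void launch_grouped_infer_example(GroupedForwardParams& param) {
  hipStream_t stream = nullptr;  // default HIP stream; an assumption here
  // Inside the dispatcher, BOOL_SWITCH on param.has_attn_bias and
  // FMHA_FWD_HEADDIM_SWITCH on param.K / param.Kv select one of the sixteen
  // pre-instantiated kernel templates.
  grouped_infer_fp16(param, stream);
}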
+ */ +#pragma once + +#include +#include + +struct BatchedInferParams { + int B; // batch size + int M; // seq_len for Query + int N; // seq_len for Key and Value + int Hq; // number of heads for Query + int Hkv; // number of heads for Key and Value + int K; // embed_dim for Query and Key + int Kv; // embed_dim for Value + + float scale; + bool has_attn_bias; + + // BMHK mode strides + std::array q_strides; + std::array k_strides; + std::array v_strides; + std::array out_strides; + std::array attn_bias_strides; // 4d tensor_view [B, H, M, N] + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* attn_bias_ptr; + + int custom_mask_type; + int window_size; // local-attention + + void* out_ptr; +}; + +struct BatchedForwardParams : public BatchedInferParams { + bool use_dropout; + bool compute_logsumexp; + + float dropout_prob; + int64_t philox_seed; + int64_t philox_offset; + + // completely contiguous + void* logsumexp_ptr; +}; + +struct GroupedInferParams { + int num_batches; + int M; // total seq_len for all queries in the batch + int N; // total seq_len for all keys/values in the batch + int Hq; // number of heads for Query + int Hkv; // number of heads for Key and Value + int K; // embed_dim for Query and Key + int Kv; // embed_dim for Value + + int max_seqlen_q; + + void* seqstart_q_dev_ptr; + void* seqstart_k_dev_ptr; + void* seqlen_k_dev_ptr; + + float scale; + bool has_attn_bias; + + // MHK mode strides, last-dim contiguous + std::array q_strides; + std::array k_strides; + std::array v_strides; + std::array out_strides; + + // 4d tensor view [B, H, M, N] + std::array attn_bias_strides; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* attn_bias_ptr; + + int custom_mask_type; + int window_size; // local-attention + + void* out_ptr; +}; + +struct GroupedForwardParams : public GroupedInferParams { + bool use_dropout; + bool compute_logsumexp; + + float dropout_prob; + int64_t philox_seed; + int64_t philox_offset; + + // completely contiguous + void* logsumexp_ptr; + + // TODO: need remove this after dev-op fix + std::vector randvals_ptrs; +}; + +struct BatchedBackwardParams { + int B; // batch size + int M; // seq_len for Query + int N; // seq_len for Key and Value + int Hq; // number of heads for Query + int Hkv; // number of heads for Key and Value + int K; // embed_dim for Query and Key + int Kv; // embed_dim for Value + + float scale; + bool has_attn_bias; + bool bias_has_grad; + + bool use_fp32_qkv_grad; + bool is_mqa_gqa; + + // BMHK mode strides, last-dim contiguous + std::array q_strides; + std::array k_strides; + std::array v_strides; + std::array attn_bias_strides; // 4d tensor_view [B, H, M, N] + std::array out_strides; + + std::array tmp_grad_k_strides; + std::array tmp_grad_v_strides; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* attn_bias_ptr; + const void* grad_out_ptr; + const void* out_ptr; + + uint8_t custom_mask_type; + + void* grad_q_ptr; + void* grad_k_ptr; + void* grad_v_ptr; + void* grad_bias_ptr; + + float dropout_prob; + int64_t philox_seed; + int64_t philox_offset; + + // BHM mode lengths, completely contiguous + const void* logsumexp_ptr; +}; + +struct GroupedBackwardParams { + int num_batches; + int M; // total seq_len for all queries in the batch + int N; // total seq_len for all keys/values in the batch + int Hq; // number of heads for Query + int Hkv; // number of heads for Key and Value + int K; // embed_dim for Query and Key + int Kv; // embed_dim for Value + + int 
max_seqlen_q; + + std::vector host_seqstart_q; + std::vector host_seqstart_k; + std::vector host_seqlen_k; + + float scale; + bool has_attn_bias; + bool bias_has_grad; + + bool use_fp32_qkv_grad; + bool is_mqa_gqa; + + // MHK mode strides, last-dim contiguous + std::array q_strides; + std::array k_strides; + std::array v_strides; + std::array out_strides; + // 4d tensor view [B, H, M, N] + std::array attn_bias_strides; + + std::array tmp_grad_k_strides; + std::array tmp_grad_v_strides; + + std::vector q_ptrs; + std::vector k_ptrs; + std::vector v_ptrs; + std::vector attn_bias_ptrs; + std::vector grad_out_ptrs; + std::vector out_ptrs; + + // used by the light_v2 kernel + // TODO use these as workspace + std::vector ydotdy_ptrs; + + uint8_t custom_mask_type; + + std::vector grad_q_ptrs; + std::vector grad_k_ptrs; + std::vector grad_v_ptrs; + std::vector grad_bias_ptrs; + + float dropout_prob; + int64_t philox_seed; + int64_t philox_offset; + + // BHM mode lengths, completely contiguous + std::vector logsumexp_ptrs; + + // TODO: need remove this after dev-op fix + std::vector randvals_ptrs; +}; diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h new file mode 100644 index 0000000000..6de737c80a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_headdim_switch.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +#define FMHA_FWD_HEADDIM_SWITCH(HEAD_DIM1, HEAD_DIM2, CONST_NAME, ...) \ + [&] { \ + if (HEAD_DIM1 <= 32 && HEAD_DIM2 <= 32) { \ + constexpr ck::index_t CONST_NAME = 32; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 64 && HEAD_DIM2 <= 64) { \ + constexpr ck::index_t CONST_NAME = 64; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 128 && HEAD_DIM2 <= 128) { \ + constexpr ck::index_t CONST_NAME = 128; \ + __VA_ARGS__(); \ + } else if (HEAD_DIM1 <= 256 && HEAD_DIM2 <= 256) { \ + constexpr ck::index_t CONST_NAME = 256; \ + __VA_ARGS__(); \ + } else { \ + throw std::runtime_error("Head-dim sizes not supported!"); \ + } \ + }() diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..1482336abf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
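To make the head-dim bucketing in FMHA_FWD_HEADDIM_SWITCH concrete, a hedged usage sketch follows: for K = Kv = 96 the switch takes the 128 branch, so the MaxK = 128 kernel instance (with head-dim padding enabled inside the dispatcher) is the one launched. The wrapper function name is illustrative; the template it calls is the one declared in ck_tiled_fmha_grouped_infer.h above.

// Illustration only: shows how FMHA_FWD_HEADDIM_SWITCH buckets K = Kv = 96
// into MaxK = 128 before dispatching to a pre-built kernel instance.
#include <hip/hip_runtime.h>

#include "ck_tiled_fmha_grouped_infer.h"
#include "ck_tiled_headdim_switch.h"

void headdim_switch_example(GroupedForwardParams& param, hipStream_t stream) {
  // With param.K == param.Kv == 96, the "<= 128" branch is taken and
  // MaxK becomes a constexpr 128 inside the lambda below.
  FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
    run_grouped_infer_causalmask_attnbias_dispatched<
        ck::half_t,
        false, // has_causal_mask
        false, // has_attn_bias
        MaxK>(param, stream);
  });
}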
+ */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..f1ba383daf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..3b9f3026b3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..c38716ce22 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..ed91bf4bf0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..eca8592290 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..ec258aeda0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..feb78a115c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..59c6550f4c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..a30775e77c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..594c4a68ce --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..39ea429139 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..6ea24c5ca2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..a675c95be0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..dc4bb0ea0f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..334eb891f1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..606d9db860 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..7dc799605f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..566b1bf6a3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..3b72b97d12 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..c2c124dbe7 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..1cdd7e0781 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..50ea226597 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..58ac17e394 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..070ed44ef2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..e535f40f3b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..a24884bff3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..524e1ab867 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..58013ca642 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..fcb6d8b546 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..38e7fb026c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..1c0b277b71 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_forward.h" + +template void run_batched_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..b95c3fdb97 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..dce1496ea1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..fa81f80c11 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..fd118cd222 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..4772d56ab2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..b95f0d5ae8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..7fe7a3f69d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..3ae7733695 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..9757278dba --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..6caed9563c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..4dfaa36785 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..fa0416c5c2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..ecc90b3661 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..dff3a317a3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..fa084941bd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..d0ece69d02 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..8e9843a5e4 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..20580c11e6 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..4e4d90f820 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..b36864534a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..2f16639ed7 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..41f8249e99 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..bfdf01423b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..550831036b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..8caa116d80 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..0468ba8afe --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..cd8077b510 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..ed22d8fc5f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..1ae833e7d0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 128>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..bb9a177b54 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 256>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..88945231f9 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 32>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..330e0dfbcd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_batched_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_batched_infer.h" + +template void run_batched_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 64>(BatchedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..d278e2b0bc --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..2bd6d042a0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..732381a8a0 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..352d94bb4a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..ebd002ef4e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..844444629a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..52b5cb8953 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..35a0583687 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..697ce6345b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..cc24c03c0f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..e0d0f9e03f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..c658c89f2f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..785e62d78a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..83001360bd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..ed45ccf363 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..f0b639ef65 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_bp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..08bf47cd57 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..8c4c0c440e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..2ff6c73e75 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..b5ec1a7817 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..c7ba7f09e7 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..577f1a1aee --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..cd1bda5d13 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..caa6f0d164 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..e0349f471d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..58d7cec792 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. 
All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..a9a2a191e2 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..8eb2447a8f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..c83769098b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..fe21d52feb --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..6bedae2d29 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..a45a99b804 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_forward_fp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_forward.h" + +template void run_grouped_forward_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..54cbec7ec3 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..12b67ea453 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..d6c6c1a5d1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..c74dbe2000 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..35b522a6ab --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..4fb8bdd598 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..1d2cd2656f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..2ccb25769a --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + false, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..2f8ea04e7f --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..f10999c7cd --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..f877720240 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..d2b85141cf --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..fe5b8db516 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..593d4fda19 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..941dcd50ed --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..82183313ac --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_bp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::bhalf_t, + true, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..c3f52f074b --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..5d4882d2b1 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..6e0b2914d8 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..b49d099089 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..1741265b25 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..4197ba831d --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..88ac7b42c5 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..c717aed649 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_no_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + false, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..5449dfd322 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..73bf0e6d69 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..55c80b4c9e --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..76cafe4e03 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_no_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + false, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp new file mode 100644 index 0000000000..8fe0d31e7c --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_128.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 128>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp new file mode 100644 index 0000000000..aeff1e2c67 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_256.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 256>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp new file mode 100644 index 0000000000..f8fed71069 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_32.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 32>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp new file mode 100644 index 0000000000..ec5f029d78 --- /dev/null +++ b/xformers/csrc/attention/hip_fmha/instances/ck_tiled_fmha_grouped_infer_fp16_with_causalmask_with_attnbias_maxk_64.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
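The instance files in this part of the patch all follow one pattern: each translation unit explicitly instantiates run_grouped_infer_causalmask_attnbias_dispatched for a single combination of element type (ck::half_t or ck::bhalf_t), causal-mask flag, attention-bias flag and maximum head dimension (32, 64, 128 or 256), presumably so that no single compile unit gets too heavy and the build can be parallelized. As a rough sketch (not part of the patch; the helper name is made up for illustration), the grid these grouped-infer instances appear to cover can be enumerated like this:

    # Sketch only: enumerate the (dtype, causal, bias, max_k) grid covered by
    # the ck_tiled_fmha_grouped_infer_*.cpp instance files.
    from itertools import product

    def grouped_infer_instance_names():
        names = []
        for dtype, causal, bias, max_k in product(
            ("fp16", "bp16"),       # ck::half_t / ck::bhalf_t
            (False, True),          # with/without causal mask
            (False, True),          # with/without attention bias
            (32, 64, 128, 256),     # MaxK template argument
        ):
            names.append(
                "ck_tiled_fmha_grouped_infer_"
                f"{dtype}_"
                f"{'with' if causal else 'no'}_causalmask_"
                f"{'with' if bias else 'no'}_attnbias_"
                f"maxk_{max_k}.cpp"
            )
        return names

    # 2 dtypes x 2 causal x 2 bias x 4 head sizes = 32 instance files
    assert len(grouped_infer_instance_names()) == 32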
+ */ +#include + +#include "ck_tiled_fmha_grouped_infer.h" + +template void run_grouped_infer_causalmask_attnbias_dispatched< + ck::half_t, + true, + true, + 64>(GroupedForwardParams& param, hipStream_t stream); diff --git a/xformers/info.py b/xformers/info.py index 1a17586e66..af0fa5b2f4 100644 --- a/xformers/info.py +++ b/xformers/info.py @@ -49,6 +49,7 @@ def print_info(): if build_info is not None: features["build.info"] = "available" features["build.cuda_version"] = build_info.cuda_version + features["build.hip_version"] = build_info.hip_version features["build.python_version"] = build_info.python_version features["build.torch_version"] = build_info.torch_version for k, v in build_info.build_env.items(): diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py index d638dc73f3..55db6a0888 100644 --- a/xformers/ops/__init__.py +++ b/xformers/ops/__init__.py @@ -11,10 +11,12 @@ AttentionOpBase, AttentionOpDispatch, LowerTriangularMask, + MemoryEfficientAttentionCkOp, MemoryEfficientAttentionCutlassFwdFlashBwOp, MemoryEfficientAttentionCutlassOp, MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionOp, + MemoryEfficientAttentionSplitKCkOp, MemoryEfficientAttentionTritonFwdFlashBwOp, TritonFlashAttentionOp, memory_efficient_attention, @@ -87,6 +89,8 @@ def masked_matmul(a, b, mask=None): "MemoryEfficientAttentionOp", "MemoryEfficientAttentionTritonFwdFlashBwOp", "TritonFlashAttentionOp", + "MemoryEfficientAttentionCkOp", + "MemoryEfficientAttentionSplitKCkOp", "memory_efficient_attention", "memory_efficient_attention_backward", "memory_efficient_attention_forward", diff --git a/xformers/ops/common.py b/xformers/ops/common.py index f86305c083..7cb0a04f0f 100644 --- a/xformers/ops/common.py +++ b/xformers/ops/common.py @@ -38,7 +38,11 @@ class BaseOperator: @classmethod def is_available(cls) -> bool: - if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator": + # cls.OPERATOR can be either a kernel or a Triton Autotuner object, which doesn't have __name__ + if ( + cls.OPERATOR is None + or getattr(cls.OPERATOR, "__name__", "") == "no_such_operator" + ): return False return True diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index c20b6f7688..8a79e7d64a 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -7,7 +7,18 @@ import torch -from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk +from . 
import ( + attn_bias, + ck, + ck_decoder, + ck_splitk, + cutlass, + decoder, + flash, + small_k, + triton, + triton_splitk, +) from .attn_bias import ( AttentionBias, BlockDiagonalCausalWithOffsetPaddedKeysMask, @@ -33,7 +44,10 @@ MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp) MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp) MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp) -TritonFlashAttentionOp = (triton.FwOp, triton.BwOp) +TritonFlashAttentionOp = (triton.FwOp, cutlass.BwOp if torch.version.cuda else ck.BwOp) +MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp) +MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp) +MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp) class _fMHA(torch.autograd.Function): @@ -560,7 +574,7 @@ def merge_attentions( ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [ - cutlass.FwOp, + cutlass.FwOp if torch.version.cuda else ck.FwOp, flash.FwOp, triton.FwOp, small_k.FwOp, @@ -568,9 +582,8 @@ def merge_attentions( ] ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [ - cutlass.BwOp, + cutlass.BwOp if torch.version.cuda else ck.BwOp, flash.BwOp, - triton.BwOp, small_k.BwOp, ] @@ -587,6 +600,8 @@ def merge_attentions( "MemoryEfficientAttentionOp", "TritonFlashAttentionOp", "memory_efficient_attention", + "MemoryEfficientAttentionCkOp", + "MemoryEfficientAttentionCkDecoderOp", "ALL_FW_OPS", "ALL_BW_OPS", "attn_bias", diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py index bf6b12842d..0687df62bf 100644 --- a/xformers/ops/fmha/attn_bias.py +++ b/xformers/ops/fmha/attn_bias.py @@ -458,8 +458,9 @@ def from_seqlens_padded( seqlen <= padding for seqlen in seqlens ), f"Seqlens {seqlens} Padding {padding}" seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + seqlen = torch.tensor(seqlens, dtype=torch.int32) return cls( - seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen=seqlen, seqlen_py=seqlens, max_seqlen=max(seqlens), min_seqlen=min(seqlens), diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py new file mode 100644 index 0000000000..aaca59113d --- /dev/null +++ b/xformers/ops/fmha/ck.py @@ -0,0 +1,514 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from enum import Enum +from functools import partial +from typing import Any, List, Mapping, Optional, Set, Tuple, Union + +import torch + +from ..common import get_xformers_operator, register_operator +from . 
import attn_bias +from .attn_bias import ( + AttentionBias, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularFromBottomRightMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + _attn_bias_apply, + check_lastdim_alignment_stride1, +) + + +def _minimum_gemm_alignment(inp: Inputs) -> int: + return 1 + + +def _get_seqlen_info( + inp: Inputs, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int]: + attn_bias = inp.attn_bias + if isinstance( + attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask) + ): + seqstart_k = attn_bias.k_seqinfo.seqstart + seqstart_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + else: + seqstart_k = None + seqstart_q = None + max_seqlen_q = -1 + + return ( + seqstart_k, + seqstart_q, + max_seqlen_q, + ) + + +def _get_tensor_bias( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Optional[torch.Tensor]: + if isinstance(attn_bias, torch.Tensor): + return attn_bias + elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias): + return attn_bias._bias + return None + + +def _check_bias_alignment( + reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> None: + attn_bias_tensor = _get_tensor_bias(attn_bias) + if attn_bias_tensor is not None: + alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits + show_padding_hint = False + for d in range(attn_bias_tensor.ndim - 1): + if attn_bias_tensor.stride(d) % alignment != 0: + reasons.append( + f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})" + ) + show_padding_hint = True + if show_padding_hint: + reasons.append( + """\ +HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \ +you need to ensure memory is aligned by slicing a bigger tensor. \ +Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`""" + ) + # We can have stride=0 sometimes if dimension=1 + if attn_bias_tensor.stride(-1) > 1: + reasons.append( + f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - " + "you should call `.contiguous()` on the bias" + ) + + +def _check_large_shapes(reasons: List[str], inp: Inputs) -> None: + """CK kernel throws "Memory access fault by GPU node-2" when B * T >= 2**20, might be some index overflow. + To reproduce, remove this function and run benchmark_mem_eff_attention with ParlAI model shape (256, 4096, 16, 64). + This needs further debugging, for now let's not support such shapes. + """ + b_t_limit = 1024**2 + q_too_large = inp.query.shape[0] * inp.query.shape[1] >= b_t_limit + k_too_large = inp.key.shape[0] * inp.key.shape[1] >= b_t_limit + v_too_large = inp.value.shape[0] * inp.value.shape[1] >= b_t_limit + if q_too_large or k_too_large or v_too_large: + reasons.append( + "Input is too large: product of first two dimensions of q/k/v must be < 2**20" + ) + + +class _CustomMaskType(int, Enum): + """ + (Matches CustomMaskType in C++.) 
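The two checks above translate into concrete requirements on the caller's tensors: every leading stride of a tensor attn_bias must be a multiple of 128 bits worth of elements (8 for fp16/bf16), its last dimension must be contiguous, and the product of the first two dimensions of q/k/v must stay below 2**20. A small sketch of how to satisfy the alignment check by allocating a padded bias and slicing it back (sizes here are illustrative only):

    # Sketch: building an additive attn_bias that passes _check_bias_alignment.
    import torch

    B, H, Mq, Mk = 2, 4, 5, 5                      # illustrative sizes
    align = 128 // torch.finfo(torch.half).bits    # 8 elements for fp16

    # Round the key dimension up to the alignment, then slice back: leading
    # strides stay multiples of `align` and the last dim keeps stride 1.
    Mk_padded = -(-Mk // align) * align
    bias = torch.zeros(B, H, Mq, Mk_padded, dtype=torch.half)[..., :Mk]
    assert all(bias.stride(d) % align == 0 for d in range(bias.ndim - 1))

    # _check_large_shapes: batch 256 with seqlen 4096 already hits the
    # 2**20 limit mentioned in the docstring and would be rejected.
    assert 256 * 4096 >= 1024**2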
+ """ + + NoCustomMask = 0 + CausalFromTopLeft = 1 + CausalFromBottomRight = 2 + + +def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + ), + ): + return int(_CustomMaskType.CausalFromTopLeft) + if isinstance( + bias, + ( + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ): + return int(_CustomMaskType.CausalFromBottomRight) + return int(_CustomMaskType.NoCustomMask) + + +@register_operator +class FwOp(AttentionFwOpBase): + """xFormers' MHA kernel based on Composable Kernel.""" + + OPERATOR = get_xformers_operator("efficient_attention_forward_ck") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + } + + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = True + SUPPORTS_BMGHK = True + NAME = "ckF" + + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 3e-4, + torch.half: 4e-3, + torch.bfloat16: 2.8e-2, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 2e-5, + torch.half: 4e-4, + torch.bfloat16: 2e-2, + } + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + if inp.query.ndim in [3, 4]: + return cls.apply_bmhk(inp, needs_gradient=needs_gradient) + assert inp.query.ndim == 5, f"query has shape {inp.query.shape}" + ctx: Optional[Context] = None + # XXX: Hackfix for BMGHK with H=1 + # In that case we don't want to run G different streams because it adds + # some overhead + if inp.query.ndim == 5 and inp.query.shape[3] == 1: + slice_op = partial(torch.squeeze, dim=3) + inp = replace( + inp, + query=slice_op(inp.query), + key=slice_op(inp.key), + value=slice_op(inp.value), + attn_bias=_attn_bias_apply( + inp.attn_bias, partial(torch.squeeze, dim=2) + ), + ) + out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient) + out = out.unsqueeze(3) + if ctx is not None: + ctx = replace(ctx, lse=ctx.lse.unsqueeze(1), out=out) + return out, ctx + + # Workaround until this is properly implemented in C++ + # run each head group in a different stream + n_groups = inp.key.shape[2] + main_stream = torch.cuda.current_stream() + streams = [main_stream] + [ + torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1) + ] + outs = [] + for group, stream in enumerate(streams): + stream.wait_stream(main_stream) + with torch.cuda.stream(stream): + query = inp.query[:, :, group] + key = inp.key[:, :, group] + value = inp.value[:, :, group] + 
bias = _attn_bias_apply( + inp.attn_bias, partial(torch.select, dim=1, index=group) + ) + outs.append( + cls.apply_bmhk( + replace(inp, query=query, key=key, value=value, attn_bias=bias), + needs_gradient=needs_gradient, + ) + ) + for s in streams[1:]: + main_stream.wait_stream(s) + out = torch.stack([o[0] for o in outs], dim=2) + if needs_gradient: + ctx = Context( + out=out, + lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore + op_bw=outs[0][1].op_bw, # type: ignore + ) + return out, ctx + + @classmethod + def apply_bmhk( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + seqstart_k, seqstart_q, max_seqlen_q = _get_seqlen_info(inp) + out, lse, rng_seed, rng_offset = cls.OPERATOR( + query=inp.query, + key=inp.key, + value=inp.value, + attn_bias=_get_tensor_bias(inp.attn_bias), + seqstart_q=seqstart_q, + seqstart_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + dropout_p=inp.p, + compute_logsumexp=needs_gradient, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + seqlen_k=( + inp.attn_bias.k_seqinfo.seqlen + if isinstance( + inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) + else None + ), + window_size=( + inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ) + else None + ), + ) + + ctx: Optional[Context] = None + if needs_gradient: + ctx = Context( + out=out, + lse=lse, + # cutlass forward is only compatible with cutlass backward if + # dropout is used (because of the way RNG states are passed and the + # way random numbers are generated during backward) + op_bw=BwOp if inp.p != 0 else None, + ) + if inp.p != 0: + ctx.rng_state = torch.tensor( + [rng_seed, rng_offset], dtype=torch.int64, device="cpu" + ) + return out, ctx + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + _check_large_shapes(reasons, d) + requires_grad = ( + d.query.requires_grad or d.key.requires_grad or d.value.requires_grad + ) + if requires_grad: + reasons.append("Gradience is currently not supported by ck-tiled!") + return reasons + + @classmethod + # type: ignore + def operator_flop( + cls, + q, + k, + v, + b, + seqstart_q, + seqstart_k, + max_seqlen_q_, + compute_lse, + custom_mask_type, + *a, + ) -> int: + return cls.attn_operator_flop( + q, + k, + v, + causal=custom_mask_type > 0, + seqstart_k=seqstart_k, + seqstart_q=seqstart_q, + ) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_xformers_operator("efficient_attention_backward_ck") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + torch.Tensor, + LowerTriangularMask, + # LowerTriangularFromBottomRightMask, + # TODO: Still some infs/nans in the BW pass for + # local + causal + # LowerTriangularFromBottomRightLocalAttentionMask, + # TODO: Fix handling of gradient through 
the fMHA autograd function + # LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + # attn_bias.BlockDiagonalCausalLocalAttentionMask, + } + SUPPORTS_ATTN_BIAS_GRAD = True + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + NAME = "ckB" + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128/128x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + attn_bias_tensor = _get_tensor_bias(d.attn_bias) + + # Backprop of gradient through broadcasted bias is not supported + if attn_bias_tensor is not None and attn_bias_tensor.requires_grad: + # Don't forget that inputs are either in BMK or BMHK! + if d.query.ndim == 3 and attn_bias_tensor.ndim == 3: + expected_bias_shape = (*d.query.shape[:2], d.key.shape[1]) + else: + # bias is B H Mq Mk + expected_bias_shape = ( + d.query.shape[0], + d.query.shape[2] if d.query.ndim == 4 else 1, + d.query.shape[1], + d.key.shape[1], + ) + if tuple(attn_bias_tensor.shape) != expected_bias_shape: + reasons.append( + "Broadcasting the `attn_bias` tensor is not supported " + f"(shape: {tuple(attn_bias_tensor.shape)}" + f"/ expected: {expected_bias_shape})" + ) + _check_large_shapes(reasons, d) + + reasons.append("Backward is currently not supported by ck-tiled!") + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + + seqstart_k, seqstart_q, max_seqlen_q = _get_seqlen_info(inp) + dtype = inp.query.dtype + + rng_seed = rng_offset = 0 + if inp.p != 0.0: + if ( + ctx.rng_state is None + or ctx.rng_state.dtype != torch.int64 + or ctx.rng_state.device.type != "cpu" + or ctx.rng_state.shape != (2,) + ): + raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}") + rng_seed, rng_offset = ctx.rng_state.tolist() + + (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( + grad.to(dtype), + inp.query, + inp.key, + inp.value, + attn_bias=_get_tensor_bias(inp.attn_bias), + seqstart_q=seqstart_q, + seqstart_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + seqlen_k=( + inp.attn_bias.k_seqinfo.seqlen + if isinstance( + inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) + else None + ), + logsumexp=ctx.lse, + output=ctx.out.to(dtype), + dropout_p=inp.p, + # if not using dropout, seed and offset are irrelevant but still expected + # in function signature so just pass 0 + # seed and offset could be None if a different FW op other than cutlass + # was used. 
+ rng_seed=rng_seed, + rng_offset=rng_offset, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + ) + + # c++/CUDA implementation returns an uninitialized tensor if bias doesn't + # require grad + if not ( + isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad + ): + grad_bias = None + + return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias) + + @classmethod + # type: ignore + def operator_flop( + cls, + dO, + q, + k, + v, + b, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + logsumexp, + output, + dropout_p, + rng_seed, + rng_offset, + custom_mask_type, + scale, + ) -> int: + return cls.attn_operator_flop( + q, + k, + v, + seqstart_q=cu_seqlens_q, + seqstart_k=cu_seqlens_k, + causal=custom_mask_type > 0, + ) diff --git a/xformers/ops/fmha/ck_decoder.py b/xformers/ops/fmha/ck_decoder.py new file mode 100644 index 0000000000..dfbbd581f5 --- /dev/null +++ b/xformers/ops/fmha/ck_decoder.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Optional, Set, Tuple + +import torch + +from ..common import get_xformers_operator, register_operator +from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from .common import AttentionFwOpBase, Context, Inputs + + +@register_operator +class FwOp(AttentionFwOpBase): + """ + An operator optimized for K=256 (so the contiguous dim fits into registers). + Tested to work on MI250x. + """ + + OPERATOR = get_xformers_operator("efficient_attention_forward_decoder_ck") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float} + SUPPORTED_MAX_K: int = 256 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + BlockDiagonalCausalWithOffsetPaddedKeysMask, + } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "ck_decoderF" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + + attn_bias = d.attn_bias + if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + if d.query.shape[0] != 1: + reasons.append( + f"One formal batch element expected; got {d.query.shape[0]}" + ) + + if d.query.shape[-1] > cls.SUPPORTED_MAX_K: + reasons.append( + f"Got head_dim={d.query.shape[-1]}; only head_dim<={cls.SUPPORTED_MAX_K} is supported for now." 
+ ) + + threads_per_warp = 64 # TODO: ideally query the platform here + required_alignment = 0 + head_dim = d.query.shape[-1] + for vec_size in (4, 2, 1): + if head_dim <= vec_size * threads_per_warp: + required_alignment = vec_size + + if not required_alignment: + reasons.append(f"Got head_dim={head_dim} which is too large") + + if head_dim % required_alignment != 0: + reasons.append( + f"Got head_dim={head_dim}; it needs to be divisible by {required_alignment}" + ) + + if d.key.stride(-1) != 1: + reasons.append("expect keys to have last dim contiguous") + + if d.value.stride(-1) != 1: + reasons.append("expect values to have last dim contiguous") + + q_starts = attn_bias.q_seqinfo.seqstart_py + padding = attn_bias.k_seqinfo.padding + bsz = d.key.shape[1] // padding + num_queries = d.query.shape[1] // bsz + + if q_starts != list(range(0, 1 + bsz, num_queries)): + reasons.append("expect to have same num_queries in each batch") + if bsz != len(q_starts) - 1: + reasons.append("empty lanes not supported yet") + + if attn_bias.k_seqinfo.padding > 8192: + reasons.append("key padding exceeds 8192") + + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if needs_gradient: + raise NotImplementedError("backward pass is not supported") + attn_bias = inp.attn_bias + q, k, v = inp.get_qkv_in_bmghk() + if attn_bias is not None: + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + attn_bias.k_seqinfo.to(k.device) + attn_bias.q_seqinfo.to(q.device) + padding = attn_bias.k_seqinfo.padding + seq_positions_gpu = attn_bias.k_seqinfo.seqlen + else: + padding = k.shape[1] + seq_positions_gpu = None + + if attn_bias is not None: + # key: (1, B * padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (1, B * q_seqlen, G, Hq, D) + multiquery = k.stride(3) == 0 + if multiquery: + key = k[0, :, :, :1].unflatten(0, (-1, padding)) + value = v[0, :, :, :1].unflatten(0, (-1, padding)) + else: + key = k[0].unflatten(0, (-1, padding)) + value = v[0].unflatten(0, (-1, padding)) + query = q[0].unflatten(0, (key.shape[0], -1)) + else: + # key: (B, padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (B, q_seqlen, G, Hq, D) + key = k + query = q + value = v + + if inp.scale is not None: + qk_scale = inp.scale + else: + qk_scale = torch.rsqrt( + torch.tensor(key.shape[-1], dtype=torch.float32) + ).item() + + out = cls.OPERATOR( + query=query, + key=key, + value=value, + seq_positions=seq_positions_gpu, + scale=qk_scale, + ) + return out, None diff --git a/xformers/ops/fmha/ck_splitk.py b/xformers/ops/fmha/ck_splitk.py new file mode 100644 index 0000000000..3d37dcdf14 --- /dev/null +++ b/xformers/ops/fmha/ck_splitk.py @@ -0,0 +1,208 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
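Before the split-K variant that follows, it is worth spelling out the reshape that ck_decoder.FwOp.apply performs when a padded block-diagonal mask is present: the packed (1, B*padding, G, H, D) key/value layout is unflattened back into per-batch form. A minimal standalone sketch of the same logic, with made-up sizes and plain tensors instead of a real attention bias:

    # Sketch of the (1, B*padding, G, H, D) -> (B, padding, G, H, D) unflatten
    # used by the CK decoder op for padded block-diagonal decoding inputs.
    import torch

    B, padding, G, Hq, D = 3, 16, 1, 8, 128   # made-up sizes
    q_seqlen = 1                              # decoding: one query per lane

    k = torch.randn(1, B * padding, G, Hq, D)
    v = torch.randn_like(k)
    q = torch.randn(1, B * q_seqlen, G, Hq, D)

    multiquery = k.stride(3) == 0             # broadcast head dim => MQA/GQA
    if multiquery:
        key = k[0, :, :, :1].unflatten(0, (-1, padding))
        value = v[0, :, :, :1].unflatten(0, (-1, padding))
    else:
        key = k[0].unflatten(0, (-1, padding))
        value = v[0].unflatten(0, (-1, padding))
    query = q[0].unflatten(0, (key.shape[0], -1))

    assert key.shape == (B, padding, G, Hq, D)
    assert query.shape == (B, q_seqlen, G, Hq, D)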
+ +from typing import Any, List, Optional, Set, Tuple + +import torch + +from xformers.ops.common import get_xformers_operator, register_operator +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from xformers.ops.fmha.common import ( + AttentionFwOpBase, + Context, + Inputs, + check_lastdim_alignment_stride1, +) + + +@register_operator +class FwOp(AttentionFwOpBase): + + OPERATOR = get_xformers_operator("efficient_attention_forward_decoder_splitk_ck") + SUPPORTED_DEVICES = {"cuda"} + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + torch.float, + } # Those are dtypes of Q. In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + BlockDiagonalCausalWithOffsetPaddedKeysMask, + } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "ck_splitKF" + + SPLIT_K: Optional[int] = None + BLOCK_M = 16 + BLOCK_N = 64 + + NUM_GROUPS = 1 # Default quantization is row-wise + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + # if K not in {16, 32, 64, 128}: + # reasons.append(f"Embed dim {K} not supported") + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + if d.key.dtype != torch.int32: + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + # Has only been tested on 8.0 / 9.0. + if torch.cuda.get_device_capability(d.device) < (7, 0): + reasons.append( + "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + ) + + q_len = d.query.shape[1] + if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + seqinfo = d.attn_bias.q_seqinfo + if q_len != seqinfo.seqstart_py[-1]: + reasons.append( + f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}" + ) + q_len = seqinfo.min_seqlen + if q_len != seqinfo.max_seqlen: + reasons.append( + "Variable query len is not supported in the presence of causal mask." 
+ ) + + if d.key.ndim in [4, 5] and d.key.shape[-2] != 1: + if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1: + reasons.append("multiquery is only supported with query seqlen=1") + + if d.attn_bias is not None and q_len > 1: + reasons.append( + "query with seqlen > 1 is not supported in the presence of causal mask" + ) + return reasons + + @classmethod + def get_split_k(cls, B: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + split_k = min(split_k, 64) + split_k = max(split_k, 1) + return split_k + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + attn_bias = inp.attn_bias + q, k, v = inp.get_qkv_in_bmghk() + + if attn_bias is not None: + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + attn_bias.k_seqinfo.to(k.device) + attn_bias.q_seqinfo.to(q.device) + padding = attn_bias.k_seqinfo.padding + seq_positions_gpu = attn_bias.k_seqinfo.seqlen + else: + padding = k.shape[1] + seq_positions_gpu = None + + if attn_bias is not None: + # key: (1, B * padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (1, B * q_seqlen, G, Hq, D) + multiquery = k.stride(3) == 0 + if multiquery: + key = k[0, :, :, :1].unflatten(0, (-1, padding)) + value = v[0, :, :, :1].unflatten(0, (-1, padding)) + else: + key = k[0].unflatten(0, (-1, padding)) + value = v[0].unflatten(0, (-1, padding)) + query = q[0].unflatten(0, (key.shape[0], -1)) + else: + # key: (B, padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (B, q_seqlen, G, Hq, D) + key = k + query = q + value = v + + B, _, _, H, _ = query.shape + _, Mk, _, _, _ = key.shape + + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = cls.get_split_k(B, H, Mk) + + if inp.scale is not None: + qk_scale = inp.scale + else: + qk_scale = torch.rsqrt( + torch.tensor(k.shape[-1], dtype=torch.float32) + ).item() + + out = cls.OPERATOR( + query=query, + key=key, + value=value, + seq_positions=seq_positions_gpu, + scale=qk_scale, + split_k=split_k, + ) + + return out, None + + +class FwOp_S1(FwOp): + SPLIT_K = 1 + NAME = "ck_splitK1" + + +class FwOp_S2(FwOp): + SPLIT_K = 2 + NAME = "ck_splitK2" + + +class FwOp_S4(FwOp): + SPLIT_K = 4 + NAME = "ck_splitK4" + + +class FwOp_S8(FwOp): + SPLIT_K = 8 + NAME = "ck_splitK8" + + +class FwOp_S16(FwOp): + SPLIT_K = 16 + NAME = "ck_splitK16" + + +class FwOp_S32(FwOp): + SPLIT_K = 32 + NAME = "ck_splitK32" + + +class FwOp_S64(FwOp): + SPLIT_K = 64 + NAME = "ck_splitK64" + + +class FwOp_S128(FwOp): + SPLIT_K = 128 + NAME = "ck_splitK128" diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index 62b88051bc..be863b6008 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -110,7 +110,7 @@ def validate_inputs(self) -> None: x.ndim != self.query.ndim for x in qkv ): raise ValueError( - f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n" + f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n" f" query.shape: {self.query.shape}\n" f" key.shape : {self.key.shape}\n" f" value.shape: {self.value.shape}" @@ -310,7 +310,11 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: dtype = d.query.dtype if device_type not in cls.SUPPORTED_DEVICES: 
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") - if device_type == "cuda" and not _built_with_cuda: + if ( + device_type == "cuda" + and not _built_with_cuda + and (torch.version.hip is None) + ): reasons.append("xFormers wasn't build with CUDA support") if device_type == "cuda": device_capability = torch.cuda.get_device_capability(d.device) diff --git a/xformers/ops/fmha/dispatch.py b/xformers/ops/fmha/dispatch.py index 30d6ec6155..aaabe5c8cf 100644 --- a/xformers/ops/fmha/dispatch.py +++ b/xformers/ops/fmha/dispatch.py @@ -8,7 +8,20 @@ from collections import deque from typing import List, Sequence, Type, TypeVar -from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk +import torch + +from . import ( + attn_bias, + ck, + ck_decoder, + ck_splitk, + cutlass, + decoder, + flash, + small_k, + triton, + triton_splitk, +) from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs @@ -66,17 +79,25 @@ def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T: def _dispatch_fw_priority_list( inp: Inputs, needs_gradient: bool ) -> Sequence[Type[AttentionFwOpBase]]: - priority_list_ops = deque( - [ - flash.FwOp, - triton.FwOp, - cutlass.FwOp, - small_k.FwOp, - ] - ) - if _is_cutlass_fwd_faster_than_flash(inp): - priority_list_ops.remove(cutlass.FwOp) - priority_list_ops.appendleft(cutlass.FwOp) + if torch.version.cuda: + priority_list_ops = deque( + [ + flash.FwOp, + triton.FwOp, + cutlass.FwOp, + small_k.FwOp, + ] + ) + if _is_cutlass_fwd_faster_than_flash(inp): + priority_list_ops.remove(cutlass.FwOp) + priority_list_ops.appendleft(cutlass.FwOp) + else: + priority_list_ops = deque( + [ + triton.FwOp, + ck.FwOp, + ] + ) if _is_triton_fwd_fastest(inp): priority_list_ops.remove(triton.FwOp) priority_list_ops.appendleft(triton.FwOp) @@ -87,7 +108,9 @@ def _dispatch_fw_priority_list( if not mqa_or_gqa: # With multiquery, cutlass is sometimes faster than decoder # but it's not currently clear when. - priority_list_ops.appendleft(decoder.FwOp) + priority_list_ops.appendleft( + decoder.FwOp if torch.version.cuda else ck_decoder.FwOp + ) # Split-KV is useful with MQA # for short Q-seqlen / long K-seqlen if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256: @@ -99,6 +122,7 @@ def _dispatch_fw_priority_list( elif inp.query.ndim == 5: # BMGHK parallelism_BH = inp.query.shape[0] * inp.query.shape[2] if parallelism_BH > 0 and parallelism_BH < 64: + priority_list_ops.appendleft(ck_splitk.FwOp) priority_list_ops.appendleft(triton_splitk.FwOp) # Without variable seqlen flash is fastest if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask): diff --git a/xformers/ops/fmha/triton.py b/xformers/ops/fmha/triton.py index 2d6e2a059a..3758cf8e35 100644 --- a/xformers/ops/fmha/triton.py +++ b/xformers/ops/fmha/triton.py @@ -3,63 +3,441 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
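With the dispatch changes above, a ROCm build picks the Triton and CK forward ops automatically (the priority list becomes triton.FwOp followed by ck.FwOp), while CUDA builds keep the previous flash/cutlass ordering. The new operators can also be requested explicitly through the usual op argument; a hedged usage sketch, assuming an xFormers build with the ROCm/HIP extensions compiled in and arbitrary tensor shapes:

    # Sketch: selecting the CK ops explicitly instead of relying on dispatch.
    import torch
    import xformers.ops as xops

    q = torch.randn(2, 1024, 8, 64, dtype=torch.half, device="cuda")
    k = torch.randn(2, 1024, 8, 64, dtype=torch.half, device="cuda")
    v = torch.randn(2, 1024, 8, 64, dtype=torch.half, device="cuda")

    # Automatic dispatch (on ROCm: triton.FwOp, then ck.FwOp).
    out = xops.memory_efficient_attention(q, k, v)

    # Explicit selection of the CK forward/backward pair added in this patch.
    out = xops.memory_efficient_attention(
        q, k, v, op=xops.MemoryEfficientAttentionCkOp
    )
    # A split-K decoder pair is exposed as well:
    # xops.MemoryEfficientAttentionSplitKCkOp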
+""" +Triton Flash Attention 2 +Based on +https://github.com/openai/triton/blob/293b7fd592a1602f2305c1bd0bc978bbd97337d6/python/tutorials/06-fused-attention.py # noqa: E501 +https://github.com/openai/triton/blob/293b7fd592a1602f2305c1bd0bc978bbd97337d6/python/triton/ops/flash_attention.py # noqa: E501 +https://github.com/Dao-AILab/flash-attention/blob/dd9a6fa45a9b90ff954d2b3f3f44241b9216190e/flash_attn/flash_attn_triton.py # noqa: E501 +https://github.com/ROCmSoftwarePlatform/triton/blob/670ae8054da008424097989a5b6e3816aa601e07/python/perf-kernels/06-fused-attention-transV.py # noqa: E501 +""" from dataclasses import replace -from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +from typing import Any, List, Mapping, Optional, Set, Tuple import torch -from ... import _is_triton_available +from xformers import _is_triton_available + from ..common import register_operator +from .attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + LowerTriangularMask, +) +from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1 + +if _is_triton_available(): + import triton + import triton.language as tl + + @triton.jit + def _fwd_kernel_triton_flash_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + q_seq_start, + lo, + hi, + start_m, + qk_scale, + kv_len, + offs_m, + offs_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + CAST_BEFORE_MATMUL: tl.constexpr, + ALLOW_TF32: tl.constexpr, + STAGE: tl.constexpr, + pre_load_v: tl.constexpr, + ): + BOUNDS_CHECKS_STAGE: tl.constexpr = BOUNDS_CHECKS_N and STAGE == 2 + # Doesn't seem to make a difference + if STAGE == 1: + lo = 0 + else: + lo = tl.multiple_of(lo, BLOCK_N) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of( + start_n, BLOCK_N + ) # doesn't seem to make a difference + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_STAGE else ()) + # Moving masking here seems to introduce num errors, + # e.g. 
in test_forward[tritonflashattF-cuda-torch.bfloat16-NoneType-1-256-15-1-32-32-False-BMHK] + # if BOUNDS_CHECKS_N or USE_SEQ_LEN: + # k = tl.where(hi - tl.arange(0, BLOCK_N) > start_n, k, float("-inf")) + if pre_load_v: + v = tl.load( + V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_STAGE else () + ) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q.to(k.dtype), k, allow_tf32=ALLOW_TF32) * qk_scale + if CAST_BEFORE_MATMUL: + k = k.to(tl.float32) + if STAGE == 2: + if IS_CAUSAL: + # For some reason this is faster than start_n <= q_seq_start + offs_m[:, None] - offs_n[None, :] + qk = tl.where( + q_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), + qk, + float("-inf"), + ) + if BOUNDS_CHECKS_N: + qk = tl.where( + tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf") + ) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_i_new[:, None] + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk) + + # -- scale and update acc -- + acc *= alpha[:, None] + if not pre_load_v: + v = tl.load( + V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_STAGE else () + ) + if CAST_BEFORE_MATMUL: + v = v.to(tl.float32) + acc += tl.dot(p.to(v.dtype), v, allow_tf32=ALLOW_TF32) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + return acc, l_i, m_i -# This implementation needs pre-MLIR triton -# The BW pass is not stable/well tested -# And also does not have the latest improvements -if TYPE_CHECKING or (False and _is_triton_available()): - try: - from flash_attn.flash_attn_triton import ( - _flash_attn_backward, - _flash_attn_forward, + @triton.jit + def _fwd_kernel_triton_flash( + Q, + K, + V, + sm_scale, + L, + Out, + Seq_len, + Seq_pos_q, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + Mkv, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + BOUNDS_CHECKS_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + CAST_BEFORE_MATMUL: tl.constexpr, + USE_SEQ_LEN_KV: tl.constexpr, + USE_SEQ_POS_Q: tl.constexpr, + IS_KV_PADDED: tl.constexpr, # Switch between padded and non-padded block-diagonal causal masks + pre_load_v: tl.constexpr, # TODO: understand if that matters + ): + start_m = tl.program_id(0).to(tl.int64) + off_hz = tl.program_id(1).to(tl.int64) + + tl.static_assert((IS_KV_PADDED and USE_SEQ_POS_Q) or not IS_KV_PADDED) + + off_z = off_hz // H + off_h = off_hz % H + if USE_SEQ_POS_Q: + seqpos = tl.load(Seq_pos_q + off_z) + seqpos_next = tl.load(Seq_pos_q + off_z + 1) + q_len = seqpos_next - seqpos + q_offset = seqpos * stride_qm + off_h * stride_qh + out_offset = seqpos * stride_om + off_h * stride_oh + if not IS_KV_PADDED: + # BlockDiagonalCausalMask, no padding, use same sequence positions as for Q + kv_offset = seqpos * stride_kn + off_h * stride_kh + kv_len = q_len + q_seq_start = 0 + else: + # BlockDiagonalCausalWithOffsetPaddedKeysMask + kv_offset = off_z * stride_kz + off_h * stride_kh + if USE_SEQ_LEN_KV: + kv_len = tl.load(Seq_len + off_z) + q_seq_start = kv_len - q_len + else: + # if no variable K/V seqlens are provided, assume full length + kv_len = Mkv + q_seq_start = 0 + else: + # 
No mask or simple causal mask + q_len = N_CTX + q_offset = off_z * stride_qz + off_h * stride_qh + out_offset = off_z * stride_oz + off_h * stride_oh + + kv_len = Mkv + q_seq_start = 0 + kv_offset = off_z * stride_kz + off_h * stride_kh + + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(q_len, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + kv_offset, + shape=(BLOCK_DMODEL, kv_len), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), ) - except ImportError: - import importlib - import pathlib - import sys - import types - - def import_module_from_path(path: str) -> types.ModuleType: - """Import a module from the given path, w/o __init__.py""" - module_path = pathlib.Path(path).resolve() - module_name = module_path.stem # 'path/x.py' -> 'x' - spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore - assert isinstance(spec, importlib.machinery.ModuleSpec) - module = importlib.util.module_from_spec(spec) # type: ignore - sys.modules[module_name] = module - assert isinstance(spec.loader, importlib.abc.Loader) - spec.loader.exec_module(module) - return module - - flash_attn = import_module_from_path( - "third_party/flash-attention/flash_attn/flash_attn_triton.py" + V_block_ptr = tl.make_block_ptr( + base=V + kv_offset, + shape=(kv_len, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1), ) - _flash_attn_backward = flash_attn._flash_attn_backward - _flash_attn_forward = flash_attn._flash_attn_forward + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # For Q + offs_n = tl.arange(0, BLOCK_N) # For K/V + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + q = tl.load( + Q_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_M or USE_SEQ_POS_Q else () + ) + + # The loop over K/V sequence blocks is divided into two stages: + # Stage 1: (many) blocks which don't need boundary conditions checks - not touching sequence end or diagonal + # Stage 2: (few) blocks which need boundary conditions checks + # Following https://github.com/openai/triton/blob/293b7fd592a1602f2305c1bd0bc978bbd97337d6/python/tutorials/06-fused-attention.py # noqa: E501 + + """ + Iteration doesn't need masking if + - 1) block doesn't cross the diagonal: max(kv_pos) <= min(q_pos) + - 2) block doesn't cross the end of the sequence: max(kv_pos) < kv_len + Find maximum start_n for which condition 1 is satisifed. 
+ Remember that + q_pos = q_seq_start + offs_m[:, None] + kv_pos = start_n + offs_n[None, :] + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + min(q_pos) = q_seq_start + start_m * BLOCK_M + max(kv_pos) = start_n + BLOCK_N - 1 + So the condition becomes + q_seq_start + start_m * BLOCK_M >= start_n + BLOCK_N - 1 + So: + 1) start_n <= q_seq_start + start_m * BLOCK_M - BLOCK_N + 1 + 2) start_n <= kv_len - BLOCK_N + + So the last allowed start_n without masking is min(q_seq_start + start_m * BLOCK_M + 1, kv_len) - BLOCK_N + """ + # Second stage can only be skipped if no mask is used and K/V length is divisible by the tile size + TWO_STAGES: tl.constexpr = BOUNDS_CHECKS_N or ( + IS_CAUSAL or (USE_SEQ_LEN_KV or (USE_SEQ_POS_Q and not IS_KV_PADDED)) + ) + if TWO_STAGES: + # Border between two stages + hi_stage_1 = min(q_seq_start + start_m * BLOCK_M + 1, kv_len) - BLOCK_N + hi_stage_1 = ( + hi_stage_1 // BLOCK_N + ) * BLOCK_N # Don't understand why it doesn't work without this + else: + hi_stage_1 = kv_len + + # Stage 1 - no boundary conditions + acc, l_i, m_i = _fwd_kernel_triton_flash_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + q_seq_start, + 0, + hi_stage_1, + start_m, + qk_scale, + kv_len, + offs_m, + offs_n, + BLOCK_M, + BLOCK_N, + IS_CAUSAL, + BOUNDS_CHECKS_N, + CAST_BEFORE_MATMUL, + ALLOW_TF32, + STAGE=1, + pre_load_v=pre_load_v, + ) + if TWO_STAGES: + hi = ( + tl.minimum(kv_len, q_seq_start + (start_m + 1) * BLOCK_M) + if IS_CAUSAL + else kv_len + ) + # Do we need this barrier? + # tl.debug_barrier() + # Stage 2 - with boundary conditions + acc, l_i, m_i = _fwd_kernel_triton_flash_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + q_seq_start, + hi_stage_1, + hi, + start_m, + qk_scale, + kv_len, + offs_m, + offs_n, + BLOCK_M, + BLOCK_N, + IS_CAUSAL, + BOUNDS_CHECKS_N, + CAST_BEFORE_MATMUL, + ALLOW_TF32, + STAGE=2, + pre_load_v=pre_load_v, + ) + + # write back l and m + acc1 = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + # Save LSE, converting from log2 to natural logarithm + l_mask = ( + start_m * BLOCK_M + tl.arange(0, BLOCK_M) < q_len + if BOUNDS_CHECKS_M + else None + ) + tl.store(l_ptrs, (m_i + tl.math.log2(l_i)) / 1.44269504, mask=l_mask) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + out_offset, + shape=(q_len, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + O_block_ptr, + acc1.to(Out.dtype.element_ty), + boundary_check=(0,) if BOUNDS_CHECKS_M or USE_SEQ_POS_Q else (), + ) + + _autotuner_config_amd_full = [ + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "pre_load_v": False}, + num_stages=1, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "pre_load_v": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2, "pre_load_v": False}, + num_stages=1, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "pre_load_v": True}, + num_stages=1, + num_warps=4, + ), # d64-False + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "pre_load_v": False}, + num_stages=1, + num_warps=4, + ), # d64-True + ] + + _autotuner_config_amd_dummy = [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "pre_load_v": False}, + num_stages=1, + num_warps=8, + ), + ] + + _autotuner_config_nvidia_dummy = [ + 
triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "pre_load_v": False}, + num_stages=1, + num_warps=8, + ), + ] + + def autotune_kernel(kernel, autotune): + + kernel = triton.heuristics( + values={ + "BOUNDS_CHECKS_N": lambda args: ((args["Mkv"] % args["BLOCK_N"]) != 0) + or (args["USE_SEQ_POS_Q"] and not args["IS_KV_PADDED"]), + "BOUNDS_CHECKS_M": lambda args: (args["N_CTX"] % args["BLOCK_M"]) != 0, + } + )(kernel) - triton_flash_backward = _flash_attn_backward - triton_flash_forward = _flash_attn_forward + if torch.version.cuda: + configs = _autotuner_config_nvidia_dummy + elif autotune: + configs = _autotuner_config_amd_full + else: + configs = _autotuner_config_amd_dummy + + kernel = triton.autotune( + configs=configs, + key=["Z", "H", "N_CTX", "IS_CAUSAL", "BLOCK_DMODEL"], + )(kernel) + return kernel + + _fwd_kernel_triton_flash_maybe_autotuned = { + True: autotune_kernel(_fwd_kernel_triton_flash, True), + False: autotune_kernel(_fwd_kernel_triton_flash, False), + } else: - triton_flash_backward = None - triton_flash_forward = None - -from .attn_bias import LowerTriangularMask -from .common import ( - AttentionBwOpBase, - AttentionFwOpBase, - Context, - Gradients, - Inputs, - check_lastdim_alignment_stride1, -) + _fwd_kernel_triton_flash = None + _fwd_kernel_triton_flash_maybe_autotuned = dict() def _prepare_inputs(inp: Inputs) -> Inputs: @@ -85,7 +463,7 @@ class FwOp(AttentionFwOpBase): `Phil Tillet's code `_ """ - OPERATOR = triton_flash_forward + OPERATOR = _fwd_kernel_triton_flash SUPPORTED_DEVICES = {"cuda"} CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) SUPPORTED_DTYPES = {torch.half, torch.bfloat16} @@ -93,33 +471,88 @@ class FwOp(AttentionFwOpBase): SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { type(None), LowerTriangularMask, - # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now. - # torch.Tensor, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, } SUPPORTS_DROPOUT = False SUPPORTS_CUSTOM_SCALE = True NAME = "tritonflashattF" + # Off by default to avoid slowing down tests. 
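+    # (When enabled, autotune_kernel above searches the full _autotuner_config_amd_full
+    # space on ROCm; when disabled, a single pre-chosen config is used.)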
+ # Needs to be turned on explicitly in benchmarks, in prod, and in a small number of tests + AUTOTUNE = False + + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.half: 2e-2, + torch.bfloat16: 2e-2, + } + + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.half: 2e-2, + torch.bfloat16: 2e-2, + } + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + if K not in {32, 64, 128}: + reasons.append(f"Embed dim {K} not supported") + return reasons + @classmethod def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(FwOp, cls).not_supported_reasons(d) check_lastdim_alignment_stride1(reasons, "query", d.query, 8) check_lastdim_alignment_stride1(reasons, "key", d.key, 8) check_lastdim_alignment_stride1(reasons, "value", d.value, 8) - if cls.OPERATOR is None: - reasons.append("triton is not available") - if d.device.type == "cuda": + + if isinstance( + d.attn_bias, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + ): + # Support padded causal block-diagonal mask if the distance between each two consecutive key starts + # is equal to the padding (key lengths can vary) + batch_size = len(d.attn_bias.q_seqinfo.seqstart_py) - 1 + B_T = d.key.shape[ + 1 + ] # For these mask types the shapes of Q/K/V are (1, B_T, H, K) + if B_T % batch_size: + reasons.append( + f"K/V should be padded, but batch size {batch_size} doesn't divide B*T={B_T}" + ) + else: + kv_maxlen = d.attn_bias.k_seqinfo.padding + for i, seqstart in enumerate(d.attn_bias.k_seqinfo.seqstart_py): + if seqstart != i * kv_maxlen: + reasons.append( + "Variable K/V start positions are not supported, they should be determined " + f"by kv_maxlen/padding: {d.attn_bias.k_seqinfo.seqstart_py=} {kv_maxlen=} {batch_size=}" + ) + break + if isinstance( + d.attn_bias, + BlockDiagonalCausalMask, + ): + # Support padded causal block-diagonal mask if for each batch element number of queries is equal + # to the number of key/values, i.e. each block is square + for q_pos, kv_pos in zip( + d.attn_bias.q_seqinfo.seqstart_py, d.attn_bias.k_seqinfo.seqstart_py + ): + if q_pos != kv_pos: + reasons.append( + f"Position starts of Q and K/V should be the same, but got {q_pos} != {kv_pos}" + f"{d.attn_bias.q_seqinfo.seqstart_py=}, {d.attn_bias.k_seqinfo.seqstart_py=}" + ) + + if d.device.type == "cuda" and torch.version.cuda: # Has only been tested on 8.0 / 9.0. # Fails on 7.5 with illegal memory access if torch.cuda.get_device_capability(d.device) < (8, 0): reasons.append( "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" ) - if _is_triton_available(): - import triton - - if triton.__version__ > "2.0.0": - reasons.append("Only work on pre-MLIR triton for now") return reasons @classmethod @@ -127,75 +560,101 @@ def apply( cls, inp: Inputs, needs_gradient: bool ) -> Tuple[torch.Tensor, Optional[Context]]: inp = _prepare_inputs(inp) + attn_bias = inp.attn_bias + seq_len_kv = None + seqstart_q = None - out, lse, softmax_scale = triton_flash_forward( - q=inp.query, - k=inp.key, - v=inp.value, - bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None, - softmax_scale=inp.scale_float, - causal=isinstance(inp.attn_bias, LowerTriangularMask), + q = inp.query + k = inp.key + v = inp.value + + if isinstance( + attn_bias, + (BlockDiagonalCausalWithOffsetPaddedKeysMask, BlockDiagonalCausalMask), + ): + # q ~ [1, B*T, H, K] + # TODO: do we really need to do this cast? 
seems fishy but + # I just copied it from the split-k kernel + assert isinstance( + attn_bias, + (BlockDiagonalCausalWithOffsetPaddedKeysMask, BlockDiagonalCausalMask), + ) + attn_bias.k_seqinfo.to(inp.query.device) + attn_bias.q_seqinfo.to(inp.query.device) + seqstart_q = attn_bias.q_seqinfo.seqstart + B = len(seqstart_q) - 1 + H, Kq = inp.query.shape[-2:] + H2, Kkv = inp.key.shape[-2:] + + Mq = attn_bias.q_seqinfo.max_seqlen + if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + seq_len_kv = attn_bias.k_seqinfo.seqlen + # assume kv has been padded + k = k.reshape(B, -1, H2, Kkv) + v = v.reshape(B, -1, H2, Kkv) + else: + B, Mq, H, _ = q.shape + + # Coded for BHMK format + q, k, v = ( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), ) - return out, Context(lse=lse, out=out) + out = torch.empty_like(q) -@register_operator -class BwOp(AttentionBwOpBase): - __doc__ = FwOp.__doc__ - - OPERATOR = triton_flash_backward - SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES - CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY - SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES - SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K - SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES - SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT - SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE - SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED - NAME = "tritonflashattB" + _, _, Mkv, K = k.shape - @classmethod - def not_supported_reasons(cls, d: Inputs) -> List[str]: - reasons = super(BwOp, cls).not_supported_reasons(d) - check_lastdim_alignment_stride1(reasons, "query", d.query, 8) - check_lastdim_alignment_stride1(reasons, "key", d.key, 8) - check_lastdim_alignment_stride1(reasons, "value", d.value, 8) - if cls.OPERATOR is None: - reasons.append("triton is not available") - if d.device.type == "cuda": - if torch.cuda.get_device_capability(d.device) != (8, 0): - reasons.append("requires A100 GPU") - if _is_triton_available(): - import triton - - if triton.__version__ > "2.0.0": - reasons.append("Only work on pre-MLIR triton for now") - return reasons + sm_scale = K**-0.5 if inp.scale is None else inp.scale + L = torch.empty((B * H, Mq), device=q.device, dtype=torch.float32) + is_causal = inp.attn_bias is not None + use_seq_len_kv = seq_len_kv is not None + use_seq_pos_q = seqstart_q is not None + is_kv_padded = isinstance( + attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) - @classmethod - def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: - inp = _prepare_inputs(inp) + grid = lambda META: (triton.cdiv(Mq, META["BLOCK_M"]), B * H, 1) # noqa: E731 + kernel = _fwd_kernel_triton_flash_maybe_autotuned[cls.AUTOTUNE] + kernel[grid]( + q, + k, + v, + sm_scale, + L, + out, + seq_len_kv, + seqstart_q, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + B, + H, + Mq, + Mkv, + BLOCK_DMODEL=K, + IS_CAUSAL=is_causal, + USE_SEQ_LEN_KV=use_seq_len_kv, + USE_SEQ_POS_Q=use_seq_pos_q, + IS_KV_PADDED=is_kv_padded, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + CAST_BEFORE_MATMUL=False, + ) - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
- with torch.inference_mode(): - grads = Gradients( - dq=torch.empty_like(inp.query), - dk=torch.empty_like(inp.key), - dv=torch.empty_like(inp.value), - ) - cls.OPERATOR( - grad, - inp.query, - inp.key, - inp.value, - ctx.out, - ctx.get_padded_lse(128), - grads.dq, - grads.dk, - grads.dv, - bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None, - softmax_scale=inp.scale_float, - causal=isinstance(inp.attn_bias, LowerTriangularMask), - ) - return grads + out = out.transpose(1, 2) + L = L.reshape(B, H, Mq) + return out, Context(lse=L, out=out) diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py index 8d95785bbd..42b834d90a 100644 --- a/xformers/ops/fmha/triton_splitk.py +++ b/xformers/ops/fmha/triton_splitk.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. - import functools import sys from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type @@ -168,7 +167,9 @@ def _fwd_kernel_splitK( # Align boundaries of split-k chunk to page boundaries # In the last chunk, shift hi to the right, in the other chunks, shift it to the left is_last_chunk = splitk_idx == tl.num_programs(2) - 1 - shift = PAGE_SIZE - 1 if is_last_chunk else 0 + shift = 0 + if is_last_chunk: + shift = PAGE_SIZE - 1 lo = (chunk_lo // PAGE_SIZE) * PAGE_SIZE hi = ((chunk_hi + shift) // PAGE_SIZE) * PAGE_SIZE hi = tl.minimum(hi, kv_len) @@ -682,6 +683,17 @@ def _splitK_reduce( _splitK_reduce = None +def _is_cuda() -> bool: + return torch.version.cuda is not None + + +def _is_cuda_at_least_sm80(device: torch.device) -> bool: + return _is_cuda() and torch.cuda.get_device_capability(device) >= ( + 8, + 0, + ) + + @register_operator class FwOp(AttentionFwOpBase): """Flash-Attention with Split-K. Supports fused int-4 K/V quantization. @@ -757,6 +769,8 @@ def shape_not_supported_reasons( @classmethod def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(FwOp, cls).not_supported_reasons(d) + if (sys.version_info.major, sys.version_info.minor) < (3, 9): + reasons.append("triton_splitk requires python 3.9 or above!") check_lastdim_alignment_stride1(reasons, "query", d.query, 8) if d.key.dtype != torch.int32: check_lastdim_alignment_stride1(reasons, "key", d.key, 8) @@ -765,10 +779,11 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons.append("triton is not available") if d.device.type == "cuda": # Has only been tested on 8.0 / 9.0. - if torch.cuda.get_device_capability(d.device) < (8, 0): + if _is_cuda() and not _is_cuda_at_least_sm80(d.device): reasons.append( - "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + "requires NVidia GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" ) + # TODO: AMD GPU support matrix needs to be figured out. MI300X is tested to work. q_len = d.query.shape[1] is_block_diagonal = isinstance(