Fast Multi-head Attention support on AMD ROCm (facebookresearch#978)
* add option to build a standalone runner for splitk decoder; debugging numerics in reduction
* fix a few bugs
* fix an indexing bug
* stash changes
* Add benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark mqa/gqa performance on ck-tiled fmha
* Synchronize with latest update in composable_kernel_tiled feature/fmha-pad-support branch
* Tiny fix in benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
* Synchronize with latest update in composable_kernel_tiled and make all unit tests pass
* Switch to new branch for composable_kernel_tiled submodule
* Add bfp16 instances for ck-tiled inference
* Update test and benchmark scripts to include bfloat16
* Tiny update to ck_tiled kernel
* Change benchmark_mem_eff_attn_mqa_gqa_ck_tiled benchmark cases
* stash changes
* Use Async pipeline for no M/N0K1 padding cases
* Add CF_FMHA_FWD_FAST_EXP2 to building
* Add Triton FA2 forward op
* Add Triton Flash Attention 2 to benchmarks
* Synchronize with latest third_party/composable_kernel and remove the inner_product bhalf_t overloading in ck_attention_forward_decoder.h
* stash split attention testing wip
* Synchronize with latest third_party/composable_kernel again
* Synchronize with latest third_party/composable_kernel_tiled
* Change to make ck decoder buildable with either the ck-tiled or the non-tiled fmha kernel
* fix gqa for split-k=1
* Skip backward tests, fix import
* fix the mask for decoding; row max and lse are computed correctly; debugging must go on
* make libtorch split-1 decoder implementation pass numerical-correctness checks
* Disable CK kernel for large shapes, to better catch OOMs
* Actually remove submodule composable_kernel_tiled from the branch
* Change the domain for the repo of the composable_kernel submodule to ROCm
* Update validate_inputs() in common.py to support 4D mqa/gqa
* synchronize test_mem_eff_attention_ck.py with test_mem_eff_attention.py
* Tiny update in benchmark_mem_eff_attn_decoder_ck.py
* Synchronize benchmark_mem_eff_attention_ck.py with benchmark_mem_eff_attention.py
* Remove benchmark_mem_eff_attn_decoder_ck_tiled.py
* Support for Generic Attention Mask Coordinate
* Add ck.FwOp and ck.BwOp to dispatched operations
* Add ck.FwOp and ck.BwOp to ALL_FW_OPS and ALL_BW_OPS
* Update tests/readme_test_on_rocm.txt
* Add ckF and ck_decoder to benchmark_mem_eff_attn_decoder.py
* Synchronize with the latest ck-tiled commits
* Add is_ck_tiled_used() C++ extension interface for checking whether ck-tiled is used
* Remove composable_kernel_tiled submodule
* remove inner_product from splitk kernel code
* remove some commented-out debug code
* comment out debug code calling libtorch instead of the hip implementation
* remove commented-out old and incorrect code fragments
* add python version override to CMakeLists
* add conversion from Argument struct to string; fix split1 test crash -- FYI, a device guard needs to be declared to avoid segfaults in the kernel
* add f32 support in the python op
* refactor out input generation in the cpp standalone
* set loop unrolls to 1 in order to avoid index errors (will need to be fixed later for perf)
* fix output splits allocation
* fix bug in split attention: sumexp needs timestep bounds in each split
* clang-format-10
* Enable support of attn-bias types with LocalAttention
* Synchronize submodule composable_kernel to the latest commits
* Make the efficient_attention_forward_ck() C++ interface consistent with the updated xformers/ops/fmha API
* Tiny fix in ck.py to make test_backward pass
* some refactorings for standalone tests
* cleanup testing
* fix split1 attention csrc test
* Enable support of flexible head-dim size (but <= 128) for ck-tiled fmha forward
* Use Async pipeline when no padding is used
* implement general split-k split-attention in libtorch, use for testing
* fix split-max and split-sumexp shapes for split attention in libtorch
* implement generic reduce split attention with libtorch (the split-k reduction sketch after this list shows the idea)
* implement testing split reduce hip vs libtorch; TBD: debug split-k=2 numerical mismatch in this test
* refactor repetitive testing code
* address code review: rearrange loops
* address code review: add comment about number of iterations per split
* address code review: remove comments
* address code review: possibly eliminate a bug by using the correct timestep range for scaling sumexp in smem
* address code review: add TODO
* address code review: shift LDS access by tt_low to avoid smem overbooking
* address code review: simplify reduction loops in split attention
* Tiny update in ck-tiled forward kernel
* address code review: merge for loops
* address code review: simplify coefficient pick
* fix runtime error message in testing code
* fix split reduce test
* address code review: fix smem offsets
* remove redundant comment
* address code review: initialize split attention workspace as empty
* address code review: rename local vars
* address code review: remove unused _rand_seqlens
* address code review: clean up python tests
* remove redundant new_max local var
* address code review: rename seq_acc
* re-enable loop unroll; adjust tests to handle splits with size divisible by block size; handle empty splits correctly
* test a wider range of split-k in cpp tests; fix the torch implementation one more time to handle empty splits
* Synchronize with ck-tiled update to support head-dim-256 and LSE storing
* Add definition of FMHA_FWD_HEADDIM_SWITCH
* Split the ck-tiled inference instances based on head-dim sizes to improve compile times
* Set k0n1_need_padding according to the pipeline kQLoadOnce implementation
* Add fmha forward C++ extension for ck-tiled
* Set SUPPORTED_MAX_K=256 in ck.py
* fix index in split-k attention
* fix index in softmax reduce and complete the wavefronts-per-block optimization fix
* clang-format-10
* Fix v_dram_transposed transpose transform in the kernel
* Skip triton_splitk for test_forward in test_mem_eff_attention.py
* clean up commented-out dead code
* enable ck split-k in benchmark_attn_decoding
* add rocm_ci workflow
* move scipy import from file level to under the function, similar to _vec_binom_test; saves a few keystrokes when setting up an environment (see the deferred-import sketch after this list)
* Add inclusion of math_v2.hpp to ck_attention_forward_decoder_splitk.h
* move forward_splitk to ck_splitk; make dispatch aware of ck_splitk and ck_decoder
* Synchronize to latest ck-tiled and update accordingly
* fix benchmark_attn_decoding
* Remove third_party/composable_kernel_tiled
* [Fix] use kK0BlockLength for the HeadDim256 padding check
* Tiny type change for custom_mask_type in the param class
* Change to use the ROCm repo for the ck-tiled submodule
* Remove tests/test_forward_ck_tiled.py
* Update test_mqa_forward_ck_tiled.py to use the common create_attn_bias method
* Add ck-tiled checking in test_mqa_forward_ck_tiled.py
* rearrange smem access in softmax reduction
* Add test_decoder and test_splitk_decoder for ROCm into test_mem_eff_attention.py
* Add ref_attention_splitk and its test to tests/test_mem_eff_attention.py
* Rename test_mem_eff_attention_ck.py as discarded
* Add test_mqa_forward and ref_attention_mqa (for BMHK-format mqa/gqa verification) into test_mem_eff_attention.py
* Rename test_mqa_forward_ck_tiled.py as discarded
* Remove CK-specific script benchmark_mem_eff_attn_decoder_ck.py
* Refine benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
* Rename benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark_mem_eff_attention_mqa.py
* Remove the runtime_error when using logsumexp in attention_forward_generic_ck_tiled.cpp
* Add ck-tiled checking in ck.py
* Remove CK-specific benchmark scripts
* Don't require is_cpu_tensor for seqstart_q/seqstart_k/seqlen_k in attention_forward_generic_ck_tiled
* Remove seqlen_cpu from _PaddedSeqLenInfo in attn_bias.py
* Change the branch for the composable_kernel_tiled submodule and update to latest
* Remove the use of seqlen_cpu in BwOp of ck.py
* Align .clang-format with the main branch and re-format C++ files
* Synchronize to latest ck-tiled commit
* Add checking of IS_CK_TILED into some testing scripts
* Update test_mem_eff_attention.py and ck.py
* Build xformers using ck-tiled as default
* ensure ck_decoder does not dispatch
* Add disable_on_rocm to some test scripts
* Update test_mem_eff_attention.py
* apply isort
* apply black
* fix flake8 suggestions
* add license headers and reapply black
* Tiny update to rocm_ci.yml
* Add conditional compilation for CUDA-dependent code on ROCm
* Update benchmark scripts
* Rename one script file
* Revert "Add conditional compilation for CUDA-dependent code on ROCm" (reverts commit 12fb41c)
* Update scripts
* Change and add READMEs for tests and benchmarks
* Remove the code for supporting the old CK
* Remove old composable_kernel from the submodule list
* Remove folder third_party/composable_kernel
* Rename the folder
* Remove unused script file
* apply black
* pacify mypy
* fix clang-format
* reapply black
* fix lints
* make test_splitk_reference run on CPU
* add ck modules to docs
* try fixing the nvidia build by re-including the sparse24 cpp folder in the extension sources
* update cutlass to upstream commit
* update flash-attention to upstream commit
* simplify setup.py
* remove duplicate run_batched_infer_causalmask_attnbias_dispatched<f16, true, true, 128>
* add HIP version and PyTorch HIP arch list to xformers build info
* fix build
* patch around the unhappy path in get_hip_version
* skip test_grad_checkpointing for triton_splitk since it doesn't have a BwOp
* re-enable test_mqa_forward since ck-tiled is the current implementation
* make the test_wrong_alignment skip more generic
* reapply black
* simplify test_decoder
* put the python version check inside the triton_splitk op
* fix logic
* clean up python3.9 checks in tests
* clean up test_attentions
* clean up test_checkpoint, as tests running on CPU do not depend on the GPU platform
* fix lints
* try fixing the Windows build by conditionally importing triton in the triton op
* re-enable test_triton_layernorm as it passes
* re-enable test_triton_blocksparse as it passes
* clean up test_sparse_tensors
* clean up test_custom_ops
* reapply black
* clean up test_core_attention
* benchmark ck ops on ROCm only
* fix mypy
* fix lint: black
* fix lints: mypy
* Rename HDim/headdim to MaxK/maxk
* Move some header files to ck examples for later reuse
* Replace the qs_ks_vs pipeline with the qr_ks_vs pipeline when HeadDim is 256, for better performance
* rm test_ck_7
* fix lints
* unskip test_unsupported_alignment
* move out test_splitk_reference
* add license header to the file created in the previous commit
* roll back fmha/common.py ... so users are forced to provide rank-5 inputs for mqa/gqa (see the rank-5 BMGHK sketch after this list)
* fix lint
* remove unused ref_attention_mqa
* resolve error in triton_splitk on ROCm: "Triton does not support if expressions (ternary operators) with dynamic conditions, use if statements instead" (see the Triton if-statement sketch after this list)
* disable partial attention tests on ROCm
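The split-attention commits above all revolve around one idea: shard the key/value sequence into splits, compute a partial softmax-attention per split (keeping each split's row max and sum of exponentials), then merge the partials with a log-sum-exp correction. A minimal PyTorch sketch of that reduction, in the spirit of ref_attention_splitk; names and shapes are illustrative, not the actual CK kernel API:

```python
import torch

def attention_splitk_ref(q, k, v, num_splits=2):
    # q: (B, M, K); k, v: (B, N, K). Process K/V in chunks along the
    # sequence axis and merge partial softmax results numerically stably.
    scale = q.shape[-1] ** -0.5
    acc = m_run = l_run = None
    for k_i, v_i in zip(k.chunk(num_splits, dim=1), v.chunk(num_splits, dim=1)):
        s = torch.einsum("bmk,bnk->bmn", q, k_i) * scale
        m_i = s.amax(dim=-1, keepdim=True)        # per-split row max
        p = torch.exp(s - m_i)
        l_i = p.sum(dim=-1, keepdim=True)         # per-split sumexp
        o_i = p @ v_i                             # unnormalized split output
        if acc is None:
            acc, m_run, l_run = o_i, m_i, l_i
        else:
            m_new = torch.maximum(m_run, m_i)     # new global row max
            a, b = torch.exp(m_run - m_new), torch.exp(m_i - m_new)
            acc = acc * a + o_i * b               # rescale both sides, accumulate
            l_run = l_run * a + l_i * b
            m_run = m_new
    return acc / l_run                            # normalize once at the end

q, k, v = torch.randn(2, 8, 16), torch.randn(2, 32, 16), torch.randn(2, 32, 16)
ref = torch.softmax(q @ k.transpose(-2, -1) * 16 ** -0.5, dim=-1) @ v
assert torch.allclose(attention_splitk_ref(q, k, v, num_splits=4), ref, atol=1e-5)
```

The "sumexp needs timestep bounds in each split" and "handle empty splits correctly" fixes correspond to getting `l_i` and the chunk boundaries right in the real kernel; the reduction itself is the same running-max rescaling shown here.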
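The "rank-5 inputs for mqa/gqa" item refers to the BMGHK calling convention: queries carry G groups of H heads each, while keys/values carry one head per group and are broadcast (expanded, not copied) across the H query heads. A hedged sketch of what such inputs look like, assuming the `xformers.ops.memory_efficient_attention` entry point and an fp16-capable GPU (on ROCm builds the `"cuda"` device maps to HIP):

```python
import torch
from xformers.ops import memory_efficient_attention

B, M, G, H, K = 2, 256, 2, 4, 64  # batch, seq len, groups, heads per group, head dim
device, dtype = "cuda", torch.float16

q = torch.randn(B, M, G, H, K, device=device, dtype=dtype)
# One K/V head per group, expanded as a zero-copy broadcast over the H query heads:
k = torch.randn(B, M, G, 1, K, device=device, dtype=dtype).expand(B, M, G, H, K)
v = torch.randn(B, M, G, 1, K, device=device, dtype=dtype).expand(B, M, G, H, K)

out = memory_efficient_attention(q, k, v)  # -> (B, M, G, H, K)
```

Requiring the explicit rank-5 form keeps the broadcast visible at the call site instead of being inferred inside fmha/common.py.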
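The triton_splitk error quoted above is a Triton frontend limitation: a Python ternary whose condition is only known at runtime is rejected, while the equivalent `if` statement on a scalar condition compiles. An illustrative kernel (not the actual triton_splitk code) showing the rewrite:

```python
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, use_double, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Rejected by the Triton frontend, because use_double is a runtime value:
    #     scale = 2.0 if use_double else 1.0
    # The equivalent if statement is supported:
    if use_double:
        scale = 2.0
    else:
        scale = 1.0
    tl.store(y_ptr + offs, x * scale, mask=mask)
```

For element-wise (tensor-valued) conditions the usual alternative is `tl.where`; the if-statement rewrite applies when the condition is a scalar kernel argument, as here.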
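The scipy-import item uses the deferred-import pattern: importing inside the function keeps scipy out of the module's import-time dependencies, so the test file still imports in environments without it. A generic sketch of the pattern (the real helper, _vec_binom_test, has a different body):

```python
def binom_pvalue(successes: int, trials: int, p: float = 0.5) -> float:
    # Deferred import: scipy is only required when this helper actually runs,
    # not when the module containing it is imported.
    from scipy.stats import binomtest

    return binomtest(successes, trials, p).pvalue
```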
---------

Co-authored-by: Max Podkorytov <[email protected]>
Co-authored-by: Grigory Sizov <[email protected]>