Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement in ROCM fmha-backward #1082

Merged
merged 696 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
696 commits
Select commit Hold shift + click to select a range
1a3ce52
Change the branch for composable_kernel_tiled submodule and update to…
qianfengz Feb 4, 2024
f7bf9b4
Remove the using of seqlen_cpu in BwOp of ck.py
qianfengz Feb 4, 2024
15d2a72
Remove the using of seqlen_cpu in BwOp of ck.py
qianfengz Feb 4, 2024
bcd1936
Align .clang_format with main branch and re-format c++ files
qianfengz Feb 4, 2024
52ae8a3
Synchronize to latest ck-tiled commit
qianfengz Feb 4, 2024
af2aa86
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 4, 2024
7dd3aee
Add checking of IS_CK_TILED into some testing scripts
qianfengz Feb 4, 2024
5eb1235
Update to test_mem_eff_attention.py and ck.py
qianfengz Feb 5, 2024
dc0e67a
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
58e6101
Building xformers using ck-tiled as default
qianfengz Feb 5, 2024
1276abc
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
389dfb4
ensure ck_decoder does not dispatch
tenpercent Feb 5, 2024
f8d9043
Add disable_on_rocm on some test scripts
qianfengz Feb 5, 2024
78df6a9
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
6dae63c
Update to test_mem_eff_attention.py
qianfengz Feb 5, 2024
a7ed88c
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
20e178a
Merge pull request #16 from ROCm/fix_test_attn_bias_padded
qianfengz Feb 6, 2024
0624c92
apply isort
tenpercent Feb 6, 2024
b8ebf08
apply black
tenpercent Feb 6, 2024
3b33c5d
fix flake8 suggestions
tenpercent Feb 6, 2024
0a9c933
add license headers and reapply black
tenpercent Feb 6, 2024
47367a4
Merge pull request #17 from ROCm/linters
qianfengz Feb 6, 2024
fb46611
Merge pull request #10 from ROCm/enable-ci
qianfengz Feb 6, 2024
28d3672
Tiny update to rocm_ci.yml
qianfengz Feb 6, 2024
12fb41c
Add conditional compiling for cuda-depending codes in ROCM
qianfengz Feb 6, 2024
a9d83c6
Update to benchmark scripts
qianfengz Feb 7, 2024
9ab3831
Rename the one script file
qianfengz Feb 7, 2024
243dc6a
Revert "Add conditional compiling for cuda-depending codes in ROCM"
qianfengz Feb 7, 2024
3240ba1
Update to scripts
qianfengz Feb 7, 2024
0c51af1
Change and add readme for tests and benchmarks
qianfengz Feb 7, 2024
f36c93b
Remove the stuffs for supporting old ck
qianfengz Feb 7, 2024
9e4582d
Remove old composable_kernel from submodule list
qianfengz Feb 7, 2024
356cafd
Remove folder third_party/composable_kernel
qianfengz Feb 7, 2024
8415b00
Merge branch 'develop' into dev_to_upstream
qianfengz Feb 7, 2024
79c554c
Rename the folder
qianfengz Feb 8, 2024
2be6c04
Remove unused script file
qianfengz Feb 8, 2024
61d875a
apply black
tenpercent Feb 9, 2024
4616121
pacify mypy
tenpercent Feb 9, 2024
832e223
fix clang-format
tenpercent Feb 9, 2024
2b2967e
reapply black
tenpercent Feb 9, 2024
89fb7d6
Merge pull request #3 from tenpercent/lints
tenpercent Feb 12, 2024
3c9d4e5
fix lints
tenpercent Feb 13, 2024
1d474c5
make test_splitk_reference run on cpu
tenpercent Feb 13, 2024
d38a684
add ck modules to docs
tenpercent Feb 13, 2024
eccbf54
try fixing nvidia build by re-including sparse24 cpp folder into exte…
tenpercent Feb 13, 2024
1ef6c20
update cutlass to upstream commit
tenpercent Feb 13, 2024
9dfec0d
update flash-attention to upstream commit
tenpercent Feb 13, 2024
9fcda18
simplify setup.py
tenpercent Feb 13, 2024
01c2bfd
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 13, 2024
58d38d4
remove duplicate run_batched_infer_causalmask_attnbias_dispatched<f16…
tenpercent Feb 13, 2024
07183f0
add hip version and pytorch hip arch list to xformers build info
tenpercent Feb 14, 2024
993a90c
fix build
tenpercent Feb 14, 2024
d4a374b
patch around the unhappy path in get_hip_version
tenpercent Feb 14, 2024
ff59f19
skip test_grad_checkpointing for triton_splitk since it doesn't have …
tenpercent Feb 15, 2024
81bcfd5
re-enable test_mqa_forward since ck tiled is the current implementation
tenpercent Feb 15, 2024
a0f7f27
make skip test_wrong_alignment more generic
tenpercent Feb 15, 2024
a0d8dcc
reapply black
tenpercent Feb 15, 2024
bc7035c
simplify test_decoder
tenpercent Feb 15, 2024
f02d0d4
put python version check inside triton_splitk op
tenpercent Feb 15, 2024
77a6c13
fix logic
tenpercent Feb 15, 2024
a7cd678
cleanup python3.9 checks in tests
tenpercent Feb 15, 2024
dea783d
cleanup test_attentions
tenpercent Feb 15, 2024
acd6b7a
cleanup test_checkpoint as test running on cpu does not depend on gpu…
tenpercent Feb 16, 2024
f467a1d
fix lints
tenpercent Feb 16, 2024
d758eac
try fixing win build by conditional import of triton in triton op
tenpercent Feb 16, 2024
21f1904
re-enable test_triton_layernorm as it passes
tenpercent Feb 17, 2024
d880c36
re-enable test_triton_blocksparse as it passes
tenpercent Feb 17, 2024
059c84f
cleanup test_sparse_tensors
tenpercent Feb 17, 2024
8aa0bdc
cleanup test_custom_ops
tenpercent Feb 17, 2024
5bc7bbe
reapply black
tenpercent Feb 17, 2024
5b4ebe4
cleanup test_core_attention
tenpercent Feb 17, 2024
473ebc7
benchmark ck ops on rocm only
tenpercent Feb 17, 2024
2a7272e
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 19, 2024
5d3247f
fix mypy
tenpercent Feb 19, 2024
9be7f8d
Merge branch 'dev_upstream' of https://github.com/ROCm/xformers into …
tenpercent Feb 20, 2024
58b0f75
fix lint: black
tenpercent Feb 21, 2024
03b7294
fix lints: mypy
tenpercent Feb 21, 2024
0666088
split-k decoder: move all tunable parameters to the top of cpp file
tenpercent Feb 8, 2024
04eec8d
apply clang-format
tenpercent Feb 21, 2024
a02ab9b
Rename HDim/headdim to MaxK/maxk
qianfengz Feb 22, 2024
fd36725
Move some headers files to ck examples for later reusing
qianfengz Feb 22, 2024
41f5ada
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Feb 22, 2024
d8384c1
Replace using qs_ks_vs pipeline by qr_ks_vs pipeline while HeadDim is…
qianfengz Feb 22, 2024
10346df
rm test_ck_7
tenpercent Feb 22, 2024
bbfe112
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 26, 2024
dd3f4a9
Merge branch 'main' into develop
qianfengz Mar 5, 2024
08b4159
dump kernel resource usage to compilation logs similar to nv
tenpercent Mar 12, 2024
ce99d22
Merge branch 'facebookresearch:main' into develop
tenpercent Mar 13, 2024
7637c61
Merge pull request #4 from ROCm/move-splitk-tune-params
qianfengz Mar 19, 2024
2da2927
Add the c++ extension to the latest change of ck_tile/dev fwd kernel …
qianfengz Mar 20, 2024
9189e45
Add the c++ extension to use ck_tile/dev/ fmha bwd kernel
qianfengz Mar 27, 2024
28e713d
Update to add dropout for fmah backward
qianfengz Mar 27, 2024
4ef7eba
Update in attention.cpp to align efficient_attention_backward_ck inte…
qianfengz Mar 27, 2024
48a5f3e
Enable BwdOp in ck.py
qianfengz Mar 27, 2024
2e45012
Support grad_out to have different strides as out
qianfengz Mar 28, 2024
b382f23
Merge branch 'facebookresearch:main' into develop
tenpercent Mar 28, 2024
566d26f
Force seqstart_q/seqstart_k to be in device memory in ck.py
qianfengz Mar 29, 2024
fc6c4a6
Remove duplicated codes in ck_tiled_fmha_grouped_forward.h/infer.h
qianfengz Mar 29, 2024
ff0db07
Use optimized async pipeline where 8x headdim length is assumed
qianfengz Mar 29, 2024
0f4a171
Fix in batched_infer
qianfengz Mar 30, 2024
0d6b915
Update to track ck_tile/opt_padding_fa_train_xformers branch
qianfengz Apr 1, 2024
df43559
Update rocm_ci.yml
tenpercent Apr 1, 2024
4713576
Update to use the newer FmhaFwdEpilogue
qianfengz Apr 1, 2024
9c2f5ce
Merge branch 'facebookresearch:main' into develop
tenpercent Apr 1, 2024
a745c45
Update rocm_ci.yml
tenpercent Apr 1, 2024
95d0260
Update rocm_ci.yml
tenpercent Apr 1, 2024
4069efe
copy rocm_ci workflow from main branch
tenpercent Apr 1, 2024
724354c
Update rocm_ci.yml
tenpercent Apr 1, 2024
b1a1009
Update to use the newer FmhaFwdEpilogue for grouped infer/forward
qianfengz Apr 2, 2024
97e4e20
Temporarily disable the using of QRKSVSAsync() pipeline
qianfengz Apr 3, 2024
e98877a
Update rocm_ci.yml
tenpercent Apr 3, 2024
6fbd05d
Implement the ck_rand_uniform interface for generating random number …
qianfengz Apr 3, 2024
2ef3b3f
Add dropout to the infer path (needed by xformers test_dropout)
qianfengz Apr 7, 2024
930bb25
Update to support test_dropout and test_dropout_backward tests
qianfengz Apr 8, 2024
bdbc956
Update the padding method in batched_backward.h
qianfengz Apr 9, 2024
44fff29
Update the OGradDotO kernel padding method
qianfengz Apr 9, 2024
d5c2d88
Change the backward padding checking condition
qianfengz Apr 9, 2024
ce9c23c
Add batch_stride_lse/d parameters to adapt grouped mode forward/backw…
qianfengz Apr 10, 2024
dafea78
Fill the grad_bias in advance
qianfengz Apr 10, 2024
06ad689
Add support for kHasBiasGrad as instance template
qianfengz Apr 11, 2024
bdd6291
Remove using hdim_stride_do in fmha backward
qianfengz Apr 11, 2024
410f814
Force kPadSeqLenQ/kPadSeqLenK to be true in batched-backward to save …
qianfengz Apr 11, 2024
2712dff
Fix missing passing of {philox_seed, philox_offset} in inference path
qianfengz Apr 12, 2024
7c27a82
Use SimplifiedGenericAttentionMask to replace GenericAttentionMask
qianfengz Apr 14, 2024
46c491e
Shorten the instance file names
qianfengz Apr 14, 2024
4c6c08d
Rename the template parameters
qianfengz Apr 14, 2024
411ccd6
Simplify the names of the dispatch class and interfaces
qianfengz Apr 15, 2024
812a529
Changes to reuse the kernel files under ck_tile examples/91_tile_prog…
qianfengz Apr 16, 2024
51b4223
Update test_mem_eff_attention.py for test_dropout/test_dropout_backwa…
qianfengz Apr 16, 2024
d10ef79
Tiny change to the philox_cuda_state input setting
qianfengz Apr 16, 2024
25bd720
Allocate logsumexp to ensure aligned access by each thread-group
qianfengz Apr 16, 2024
abfdc27
Add checking for query/key headdim size attention_backward_generic
qianfengz Apr 16, 2024
ff95367
Using ck_tile/opt_padding_fa_train_pr2 and synchronize the backward c…
qianfengz Apr 22, 2024
93469ab
Enable using async pipeline in the batched inference path for perform…
qianfengz Apr 22, 2024
2c8626b
Re-organize cpp instances for calling fmha infer kernel
qianfengz Apr 23, 2024
bdd716c
Re-organize cpp instances for calling fmha forward kernel
qianfengz Apr 23, 2024
44d4592
Re-organize cpp instances for calling fmha backward kernel
qianfengz Apr 23, 2024
51ca91b
Position the composable_kernel_tiled to ck_tile/opt_padding_fa_train …
qianfengz Apr 23, 2024
1693683
Update to synchronize with the latest commits in ck_tile/opt_padding_…
qianfengz Apr 23, 2024
b7aa908
update submodule to public
carlushuang Apr 26, 2024
9a878d9
Merge pull request #7 from ROCm/origin/test_opt_padding_train_public
qianfengz Apr 26, 2024
b4fa26d
Update to the criteria for padding seqlen_k in batched infer/forward
qianfengz May 6, 2024
ee7950f
Keep latest track of ck-tile commits
qianfengz May 6, 2024
74dfdfe
Tiny fixing to the decoder including
qianfengz May 8, 2024
410757e
Position the ck-tiled to ck_tile/opt_padding branch
qianfengz May 9, 2024
fa155eb
Merge branch 'test_opt_padding_train' of https://github.com/ROCm/xfor…
qianfengz May 9, 2024
77514d5
Merge branch 'develop' into test_opt_padding_train
qianfengz May 9, 2024
92924d4
Enable some attn_bias types which were previously disabled by old-ck …
qianfengz May 11, 2024
23f64bd
Add script generate_instances.py which helps to generate instances
qianfengz May 14, 2024
d94b2c1
Simplify logic for seqstart_q/k
xw285cornell May 15, 2024
2486b56
Add Async pipeline to grouped mode inference path
qianfengz May 15, 2024
18b43c9
Use explict true for kPadSeqLenQ/kPadHeadDimQ/kPadHeadDimV templates …
qianfengz May 15, 2024
cf6cca0
Merge pull request #11 from xw285cornell/develop
qianfengz May 16, 2024
14f7abe
Synchronize to the update of composable_kernel_tiled for better perfo…
qianfengz May 21, 2024
ee4aa87
Update rocm_ci.yml - clean up dangling images after ci run
tenpercent May 23, 2024
b0b5547
Avoid unused-const-variable warning
xw285cornell May 25, 2024
dfc196d
Tiny change in the BlockTile/Shape setting overriddings
qianfengz May 29, 2024
2490166
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz May 29, 2024
f50861a
try to align fmha C++ extension to the ck_tile in ck develop branch
qianfengz Jun 12, 2024
76fb485
Synchronize composable_kernel_tiled to latest ck develop
qianfengz Jun 13, 2024
1f3add7
Use FmhaFwdTilePartitioner_HBS only with seqlen_k padded cases
qianfengz Jun 14, 2024
ed226f4
Merge branch 'main' into develop
qianfengz Jun 16, 2024
9df93e5
Tiny fix/change to make test_forward/test_backward/test_dropout/test_…
qianfengz Jun 17, 2024
d6ccfa1
Fix compiling issue with regard to Invoker definitions in forward_dec…
qianfengz Jun 17, 2024
a7c7475
Keep using -Woverloaded-virtual
qianfengz Jun 18, 2024
b157b49
Fix clang-format for headers and cpp files
qianfengz Jun 18, 2024
b2fb213
Fix format in python scripts
qianfengz Jun 18, 2024
fdf8b8e
Add noqa: C801 for generate_instances.py
qianfengz Jun 18, 2024
633a161
Align dispatch_bw with main branch
qianfengz Jun 19, 2024
00cf683
Align ops/fmha/common.py with main branch
qianfengz Jun 19, 2024
252844d
Synchronize the thirty-party/composable_kernel_tiled to latest ck_til…
qianfengz Jun 20, 2024
610909e
Relax the atol for test_forward and test_dropout due to the using of …
qianfengz Jun 20, 2024
10bf99c
Generate html report for tests run with rocm_ci.yml
tenpercent Jul 1, 2024
16bb10b
archive test results when tests have failed
tenpercent Jul 1, 2024
29c782b
Always clean up dangling docker images in rocm_ci
tenpercent Jul 1, 2024
782d5a3
Bump python to 3.11 in rocm_ci.yml
tenpercent Jul 2, 2024
bd8ca1b
Disable flash attention tests rocm_ci.yml
tenpercent Jul 2, 2024
77beb19
Try to fix rocm_ci.yml
tenpercent Jul 2, 2024
b0ae707
try to fix rocm_ci.yml flow by overriding PATH
tenpercent Jul 2, 2024
d2eeaf0
Fix setup.py path in rocm_ci.yml
tenpercent Jul 2, 2024
a62c93e
cd to xformers dir before running install in rocm_ci.yml
tenpercent Jul 2, 2024
d3ae25f
Use pip to install xformers in rocm_ci.yml
tenpercent Jul 2, 2024
d4e6abc
Possibly fix python version resolution in rocm_ci.yml
tenpercent Jul 2, 2024
490b63d
Set the correct path for pytest in rocm_ci.yml
tenpercent Jul 2, 2024
addd2f2
remove test_reference_splitk as it was moved to a different file duri…
tenpercent Jul 2, 2024
33810ff
make sure ck operators have a name to be visible in the dispatcher
tenpercent Jul 3, 2024
f3faa1a
fix sm version checks to happen only on CUDA, not ROCm
tenpercent Jul 8, 2024
04e9481
(2/n) fix sm version checks to happen only on CUDA, not ROCm
tenpercent Jul 8, 2024
9440282
Merge pull request #13 from xw285cornell/xdwang-develop
qianfengz Jul 9, 2024
bd49f48
Remove _check_large_shapes checking in fmha/ck.py (#1067)
qianfengz Jul 14, 2024
0d1d1be
make xformers install editable to fix cpp extensions detection
tenpercent Jul 18, 2024
9390d6a
Update to using the improved fmha-bwd (compiling passed)
qianfengz Jul 23, 2024
22fce7e
Update to get 80% of the test_backward and test_dropout_backward_ck c…
qianfengz Jul 23, 2024
463a475
Replace the using of ConvertGradQ by using torch tensor type converting
qianfengz Jul 25, 2024
3427a6f
Change the tile settings for MaxK=32
qianfengz Jul 25, 2024
fbc7c50
Fix padding setting bug in grouped_backward
qianfengz Jul 26, 2024
6e08666
Change -DCK_FMHA_FWD_FAST_EXP2=1 to -DCK_TILE_FMHA_FWD_FAST_EXP2=1
qianfengz Jul 26, 2024
94ab599
Point the composable_kernel_tiled submodule to ck_tile/fa_bwd_opt branch
qianfengz Jul 26, 2024
830697c
Disable flshattF and flshattB on ROCM
qianfengz Jul 27, 2024
afd7e02
Add -mllvm and -enable-post-misched=0 compiling options for ROCM on s…
qianfengz Jul 27, 2024
e67de41
Disable flshattF and flshattB on ROCM
qianfengz Jul 27, 2024
d72c2b3
Update to support separate grad_q_f32_strides do to the API change in…
qianfengz Jul 28, 2024
5ddff31
Use old method for setting BlockDropout due to the revert in fmha_fwd…
qianfengz Jul 28, 2024
cf2b622
Tiny fix in grouped_backward
qianfengz Jul 28, 2024
112aaed
Use packed tensor allocation for grad_q_f32
qianfengz Jul 28, 2024
dd83c62
Update to the ConvertGradQ kernel calling
qianfengz Jul 28, 2024
3e9b99d
Tiny update
qianfengz Jul 28, 2024
019448e
Fix the parameter location in grouped_backward
qianfengz Jul 29, 2024
c55966a
Adjust headdim128 tile shapes for better performance
qianfengz Aug 5, 2024
e22829a
Update backward kernel calling due to adding of nhead_stride_dk/nhead…
qianfengz Aug 5, 2024
cae1b77
Synchronize with CK to use separate pipeline for kPadHeadDim true of …
qianfengz Aug 5, 2024
e564f5e
Use convertDQ kernel
qianfengz Aug 6, 2024
b043765
Update to use unpadded lse layout
qianfengz Aug 7, 2024
c9e7595
Add explicit headdim256 instances for fmha backward
qianfengz Aug 7, 2024
4a7b7dc
Add leaked headdim256 instance references
qianfengz Aug 7, 2024
1ad9cbe
Change to generate.py and the re-generate the instance files using it
qianfengz Aug 7, 2024
7db2aa4
Change to generate.py to generate instances refences and uses the gen…
qianfengz Aug 7, 2024
73dbf32
Relax the RTOL of ckFwOp from 4e-4 to 3e-3 due to one big result case
qianfengz Aug 8, 2024
0e6d0c3
Change to use .h rather than .hpp as suffix for generated header files
qianfengz Aug 12, 2024
914ccc5
Fix in .gitignore
qianfengz Aug 12, 2024
8503f87
Update to bwd setting to use only IGLP pipeline
qianfengz Aug 12, 2024
bfe164d
Synchronize to latest ck_tile fix and align the headdim64 tile shape …
qianfengz Aug 12, 2024
f75c3b2
Reformat the generated instances cpp files
qianfengz Aug 12, 2024
520e6ed
Merge pull request #18 from ROCm/fa_bwd_opt_test
qianfengz Aug 12, 2024
bc3db99
Fix to the backward Trait
qianfengz Aug 13, 2024
fa6d8b3
Set occupancy to -1 to avoid the compiling warning
qianfengz Aug 13, 2024
c5c7cce
Revert "Set occupancy to -1 to avoid the compiling warning"
qianfengz Aug 13, 2024
d230433
Add environment variable and compiler definition to control the gener…
qianfengz Aug 14, 2024
82a07ae
Add --ignore-hd256 argument to generate_instance.py and some update i…
qianfengz Aug 14, 2024
38593d6
Add environment variable ENABLE_HIP_FMHA_RTN_BF16_CONVERT to enable u…
qianfengz Aug 15, 2024
15dc911
Remove commented lines in test_mem_eff_attention.py
qianfengz Aug 15, 2024
367274c
Synchronize to latest ck_tile commit
qianfengz Aug 15, 2024
f7b28c5
apply black
tenpercent Aug 16, 2024
fd82f20
apply flake8
tenpercent Aug 16, 2024
7d21800
fix mypy
tenpercent Aug 16, 2024
d6b6456
revert disable flash operator on rocm
tenpercent Aug 16, 2024
87188ea
Synchronize to ck_tile latest commit again
qianfengz Aug 16, 2024
5be80a3
Re-position the composable_kernel submodule to the develop branch
qianfengz Aug 17, 2024
cee0980
Merge pull request #20 from tenpercent/develop
qianfengz Aug 17, 2024
2a5c141
Avoid the Async pipeline when khasBias is true
qianfengz Aug 17, 2024
2874842
clang-format for two files
qianfengz Aug 17, 2024
cbb557d
Merge branch 'main' into upstream_pr
qianfengz Aug 17, 2024
1a73f34
Change allocation of grouped mode lse from [H, M] to [1, H, M] to mat…
qianfengz Aug 17, 2024
4440714
Synchronize to the upstream rocm_ci workflows
qianfengz Aug 17, 2024
db2b52e
Re-format tests/test_mem_eff_attention.py
qianfengz Aug 17, 2024
d293caf
Change in generate_instances.py so that this scripts can be called fr…
qianfengz Aug 20, 2024
8eb1bbd
Merge branch 'upstream_pr' of https://github.com/ROCm/xformers into u…
qianfengz Aug 20, 2024
ee9640a
Add GENERATE_INSTANCES.md
qianfengz Aug 20, 2024
3cf5721
clean-up commented codes
qianfengz Aug 20, 2024
01cc08e
Remove un-used test
qianfengz Aug 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ xformers/csrc/attention/hip_fmha/*.hip
xformers/csrc/attention/hip_fmha/*_hip.h
xformers/csrc/attention/hip_fmha/instances/*.cu
xformers/csrc/attention/hip_fmha/instances/*.hip
xformers/csrc/attention/hip_fmha/instances_tiled/*.cu
xformers/csrc/attention/hip_fmha/instances_tiled/*.hip
xformers/csrc/attention/hip_fmha/instances/*.cu
xformers/csrc/attention/hip_fmha/instances/*.hip
xformers/csrc/attention/hip_fmha/instances/*_hip.h

19 changes: 18 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ def get_extensions():
"--ptxas-options=-allow-expensive-optimizations=true",
]
elif torch.cuda.is_available() and torch.version.hip:
disable_hd256_hip_fmha = os.getenv("DISABLE_HD256_HIP_FMHA", "0")
if disable_hd256_hip_fmha == "1":
source_hip_maxk_256 = []
for ff in source_hip:
if ff.endswith("maxk_256.cpp"):
source_hip_maxk_256 += [ff]
source_hip = list(set(source_hip) - set(source_hip_maxk_256))

rename_cpp_cu(source_hip)
rocm_home = os.getenv("ROCM_PATH")
hip_version = get_hip_version(rocm_home)
Expand All @@ -436,9 +444,16 @@ def get_extensions():
Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
]

use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0")

generator_flag = []
if disable_hd256_hip_fmha == "1":
generator_flag += ["-DFMHA_SUPPORT_MAX_HEADDIM_128=1"]

cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
if use_rtn_bf16_convert == "1":
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0"]

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": [
Expand All @@ -447,10 +462,12 @@ def get_extensions():
f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-DCK_FMHA_FWD_FAST_EXP2=1",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-Werror",
"-Woverloaded-virtual",
"-mllvm",
"-enable-post-misched=0",
]
+ generator_flag
+ cc_flag,
Expand Down
51 changes: 39 additions & 12 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
sm70_or_better_only = pytest.mark.skipif(
compute_capability < (7, 0), reason="requires sm70+"
torch.version.cuda is not None and compute_capability < (7, 0),
reason="requires sm70+",
)
sm75_or_better_only = pytest.mark.skipif(
compute_capability < (7, 5), reason="requires sm75+"
torch.version.cuda is not None and compute_capability < (7, 5),
reason="requires sm75+",
)
sm80_or_better_only = pytest.mark.skipif(
compute_capability < (8, 0), reason="requires sm80+"
torch.version.cuda is not None and compute_capability < (8, 0),
reason="requires sm80+",
)
skip_if_rocm = pytest.mark.skipif(
torch.version.hip is not None, reason="not supported on ROCm"
Expand Down Expand Up @@ -667,16 +670,8 @@ def test_backward(

if op_bw == fmha.ck.BwOp:
op_fw = fmha.ck.FwOp
if dtype == torch.bfloat16:
pytest.skip(
"CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
)
if grad_out_contiguous is False:
pytest.skip("CK Fmha does not support contiguous layout for grad_out!")
if k % 2 != 0:
pytest.skip(
"CK Fmha currently requires the headdim size of query input be an even value!"
)

qkv = None

Expand Down Expand Up @@ -1003,6 +998,38 @@ def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p):
)


@cuda_only
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test got here as merge conflict resolution gone bad?

@disable_tf32
@disable_on_rocm
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("kv_len", [3 * 32])
@pytest.mark.parametrize("q_len", [3 * 32])
def test_memory_efficient_attention_full_block_masked(q_len, kv_len, batch_size, k_len):
device = "cuda"
op_fw = fmha.small_k.FwOp
op_bw = fmha.small_k.BwOp

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

# in this case, most of the blocks in a row get masked
attn_bias = torch.full((3, 32), float("-inf"), device=device)
attn_bias[:2, :4] = 0
attn_bias = attn_bias.flatten()[None, None, :].expand(1, q_len, -1)

out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=(op_fw, op_bw)
)
ref = ref_attention_for_test(query, key, value, attn_bias)

assert_allclose(
out, ref, atol=op_fw.ERROR_ATOL[query.dtype], rtol=op_fw.ERROR_RTOL[query.dtype]
)


@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt):
Expand Down Expand Up @@ -1581,7 +1608,7 @@ def test_decoder(
# kv_heads = 1: multiquery
# kv_heads = None: neither MQA nor GQA
# kv_heads > 1: BMGHK
if dtype == "bf16" and compute_capability < (8, 0):
if dtype == "bf16" and torch.version.cuda and compute_capability < (8, 0):
raise pytest.skip("BF16 is only supported on SM80+")
import triton

qianfengz marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ efficient_attention_backward_ck(
int64_t K = query.size(3);
int64_t Kv = value.size(3);

if (K % 2 != 0)
throw std::runtime_error(
"Currently CK Fmha requires the headdim of query/key be an even value!");

auto opts = query.options();

at::Tensor grad_q, grad_k, grad_v, grad_bias;
Expand All @@ -143,7 +139,6 @@ efficient_attention_backward_ck(
grad_q = chunk.select(2, 0);
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
grad_q.fill_(0);
} else if (
key.size(3) == value.size(3) &&
key.storage().is_alias_of(value.storage())) {
Expand All @@ -157,14 +152,24 @@ efficient_attention_backward_ck(
grad_v = chunk.select(2, 1);

grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_q.fill_(0);
} else {
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
grad_q.fill_(0);
}

at::Tensor grad_q_f32;
const bool use_grad_q_f32 =
(query.scalar_type() == at::ScalarType::BFloat16 ||
query.scalar_type() == at::ScalarType::Half);

if (use_grad_q_f32) {
grad_q_f32 = at::empty(grad_q.sizes(), opts.dtype(at::kFloat));
grad_q_f32.fill_(0);
} else {
grad_q.fill_(0);
};

// CK-FlashAttn requires q/k/v to have same shapes with dQ/dK/dV respectively
TORCH_CHECK(query.sizes() == grad_q.sizes());
TORCH_CHECK(query.strides() == grad_q.strides());
Expand Down Expand Up @@ -211,7 +216,7 @@ efficient_attention_backward_ck(

TORCH_CHECK(p.B == logsumexp.size(0));
TORCH_CHECK(p.Hq == logsumexp.size(1));
TORCH_CHECK(p.M <= logsumexp.size(2));
TORCH_CHECK(p.M == logsumexp.size(2));

if (scale.has_value()) {
p.scale = float(*scale);
Expand All @@ -229,6 +234,11 @@ efficient_attention_backward_ck(
p.grad_k_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr();
p.grad_v_ptr = is_mqa_gqa ? tmp_grad_v.data_ptr() : grad_v.data_ptr();

if (use_grad_q_f32)
p.grad_q_f32_ptr = grad_q_f32.data_ptr();
else
p.grad_q_f32_ptr = nullptr;

p.q_strides = {
static_cast<int>(query.stride(0)),
static_cast<int>(query.stride(1)),
Expand Down Expand Up @@ -260,6 +270,14 @@ efficient_attention_backward_ck(
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};

if (use_grad_q_f32) {
p.grad_q_f32_strides = {
static_cast<int>(grad_q_f32.stride(0)),
static_cast<int>(grad_q_f32.stride(1)),
static_cast<int>(grad_q_f32.stride(2)),
static_cast<int>(grad_q_f32.stride(3))};
}

if (is_mqa_gqa) {
p.grad_k_strides = {
static_cast<int>(tmp_grad_k.stride(0)),
Expand Down Expand Up @@ -335,9 +353,9 @@ efficient_attention_backward_ck(
p.max_seqlen_q = *max_seqlen_q_;
p.max_seqlen_k = *max_seqlen_k_;

TORCH_CHECK(p.num_batches == logsumexp.size(0));
// unpadded lse layout required
TORCH_CHECK(p.Hq == logsumexp.size(1));
TORCH_CHECK(p.max_seqlen_q <= logsumexp.size(2));
TORCH_CHECK(p.M == logsumexp.size(2));

if (scale.has_value())
p.scale = float(*scale);
Expand Down Expand Up @@ -366,10 +384,16 @@ efficient_attention_backward_ck(
static_cast<int>(grad_out.stride(3))};

p.lsed_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};

if (use_grad_q_f32) {
p.grad_q_f32_strides = {
static_cast<int>(grad_q_f32.stride(1)),
static_cast<int>(grad_q_f32.stride(2)),
static_cast<int>(grad_q_f32.stride(3))};
}

if (is_mqa_gqa) {
p.grad_k_strides = {
static_cast<int>(tmp_grad_k.stride(1)),
Expand Down Expand Up @@ -480,6 +504,11 @@ efficient_attention_backward_ck(
p.grad_k_ptr = is_mqa_gqa ? tmp_grad_k.data_ptr() : grad_k.data_ptr();
p.grad_v_ptr = is_mqa_gqa ? tmp_grad_v.data_ptr() : grad_v.data_ptr();
p.grad_bias_ptr = bias_requires_grad ? grad_bias.data_ptr() : nullptr;

if (use_grad_q_f32)
p.grad_q_f32_ptr = grad_q_f32.data_ptr();
else
p.grad_q_f32_ptr = nullptr;
};

auto inDataType = query.scalar_type();
Expand Down Expand Up @@ -515,6 +544,14 @@ efficient_attention_backward_ck(
grad_v = tmp_grad_v_view.sum(3);
}

/*
jianyuh marked this conversation as resolved.
Show resolved Hide resolved
if (inDataType == at::ScalarType::Half)
grad_q = grad_q_f32.to(torch::kFloat16);

if (inDataType == at::ScalarType::BFloat16)
grad_q = grad_q_f32.to(torch::kBFloat16);
*/

return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ template <
int32_t ThreadsPerWavefront,
int32_t WavefrontsPerBlock,
int32_t KV_M_MAX = 8192,
int32_t K_MAX = 256>
int32_t K_MAX = K_MAX>
at::Tensor& efficient_attention_forward_decoder_ck_out_impl(
const at::Tensor& XQ, // [B, 1, G, H, D]
const at::Tensor& cache_K, // [B, KV_M_MAX, G, H or 1, D]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,18 +316,14 @@ efficient_attention_forward_ck(
p.dropout_prob = 0.0f;

if (p.compute_logsumexp) {
// align the access of logsumexp by each thread-group in cache-line size
int aligned_seqlen_q = (p.max_seqlen_q + 15) / 16 * 16;
logsumexp = at::empty(
{p.num_batches, Hq, aligned_seqlen_q}, opts.dtype(at::kFloat));
logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
p.logsumexp_ptr = logsumexp.data_ptr();
p.lse_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};
} else {
p.logsumexp_ptr = nullptr;
p.lse_strides = {0, 0, 0};
p.lse_strides = {0, 0};
}
};

Expand Down
Loading
Loading