Fast Multi-head Attention support on AMD ROCm #978

Merged · 540 commits · Mar 4, 2024

Commits (540)
bc23333
add option to build a standalone runner for splitk decoder; debugging…
tenpercent Dec 5, 2023
2c7b9bb
fix a few bugs
tenpercent Dec 6, 2023
709727f
fix an indexing bug
tenpercent Dec 6, 2023
785481c
stash changes
tenpercent Dec 6, 2023
ff0ebdb
Add benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark mqa/gqa p…
qianfengz Dec 8, 2023
9a8baf7
Synchronize with latest update in composable_kernel_tiled feature/fmh…
qianfengz Dec 8, 2023
959ae7f
Tiny fix in benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
qianfengz Dec 8, 2023
cc2f487
Synchronize with latest update in composable_kernel_tiled and make al…
qianfengz Dec 11, 2023
2162b45
Switch to new branch for composable_kernel_tiled submodule
qianfengz Dec 11, 2023
d6cf545
Add bfp16 instances for ck-tiled inference
qianfengz Dec 11, 2023
5cfda98
Update to test and benchmark scripts to include bfloat16
qianfengz Dec 11, 2023
ab60547
Tiny update to ck_tiled kernel
qianfengz Dec 11, 2023
a2af789
Change to benchmark_mem_eff_attn_mqa_gqa_ck_tiled benchmark cases
qianfengz Dec 11, 2023
d957dd9
stash changes
tenpercent Dec 11, 2023
40aa884
Use Async pipeline for no M/N0K1 padding cases
qianfengz Dec 11, 2023
73e97d8
Add CF_FMHA_FWD_FAST_EXP2 to building
qianfengz Dec 11, 2023
b0c7023
Add Triton FA2 forward op
sgrigory Dec 12, 2023
63c3523
Add Triton Flash Attention 2 to benchmarks
sgrigory Dec 12, 2023
fbd836a
Synchronize with latest third_party/composable_kernel and remove the …
qianfengz Dec 12, 2023
0d15f1b
stash split attention testing wip
tenpercent Dec 13, 2023
5c1bc54
Synchronize with latest third_party/composable_kernel again
qianfengz Dec 13, 2023
0172147
Merge branch 'develop' into ck-tiled-fa
qianfengz Dec 13, 2023
a018550
Synchronize with latest third_party/composable_kernel_tiled
qianfengz Dec 13, 2023
31da32e
Change to make ck decoder buildable with both ck tiled or non-tiled f…
qianfengz Dec 13, 2023
22c8d6f
Change to make ck decoder buildable with both ck tiled or non-tiled f…
qianfengz Dec 13, 2023
6428374
fix gqa for split-k=1
tenpercent Dec 13, 2023
f21e39a
Skip backward tests, fix import
sgrigory Dec 17, 2023
6c5540c
fix the mask for decoding; row max and lse are computed correctly; de…
tenpercent Dec 18, 2023
5225eef
make libtorch split-1 decoder implementation pass numerical correctness
tenpercent Dec 19, 2023
45727d6
Disable CK kernel for large shapes, better catch OOMs
sgrigory Dec 20, 2023
de5098e
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/xfo…
qianfengz Dec 24, 2023
402ee91
Actually remove submodule composable_kernel_tiled from the branch
qianfengz Dec 24, 2023
7904096
Change the domain for the repo of composable_kernel submodule to ROCm
qianfengz Dec 24, 2023
defb8d9
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/xfo…
qianfengz Dec 24, 2023
388a5ca
Merge pull request #5 from ROCmSoftwarePlatform/merge-upstream-merge
qianfengz Dec 26, 2023
b068558
Merge branch 'develop' into ck-tiled-fa
qianfengz Dec 26, 2023
44f6160
Update to validate_inputs() in common.py to support 4d mqa/gqa
qianfengz Dec 26, 2023
e03f67a
synchronize test_mem_eff_attention_ck.py with test_mem_eff_attention.py
qianfengz Dec 27, 2023
6aef46d
Tiny update in benchmark_mem_eff_attn_decoder_ck.py
qianfengz Dec 28, 2023
4a1cea0
Synchronize benchmark_mem_eff_attention_ck.py with benchmark_mem_eff_…
qianfengz Dec 28, 2023
ad024e4
Merge branch 'develop' into ck-tiled-fa
qianfengz Dec 28, 2023
c5ca494
Remove benchmark_mem_eff_attn_decoder_ck_tiled.py
qianfengz Dec 28, 2023
a74ee16
Merge branch 'develop' into decoder-splitk
tenpercent Jan 3, 2024
8ebfd5f
Support for Generic Attention Mask Coordinate
qianfengz Jan 3, 2024
43e7797
Merge pull request #6 from sgrigory/add-triton-fa2
qianfengz Jan 5, 2024
ba5fd52
Add ck.FwOp and ck.BwOp to dispatched operations
qianfengz Jan 5, 2024
6533aca
Add ck.FwOp and ck.BwOp to ALL_FW_OPS and ALL_BW_OPS
qianfengz Jan 5, 2024
7fc3620
Update in tests/readme_test_on_rocm.txt
qianfengz Jan 5, 2024
23e191a
Add ckF and ck_decoder to benchmark_mem_eff_attn_decoder.py
qianfengz Jan 5, 2024
b077cfc
Merge branch 'develop' into ck-tiled-fa
qianfengz Jan 5, 2024
45287b7
Synchronize with the latest ck-tiled commits
qianfengz Jan 8, 2024
1a74675
Add is_ck_tiled_used() c++ extension interface for judging if ck-tile…
qianfengz Jan 8, 2024
cbcc196
Remove composable_kernel_tiled submodule
qianfengz Jan 9, 2024
b4539f7
inner_product removed from splitk kernel code
tenpercent Jan 3, 2024
9c52e0e
remove some commented out debug code
tenpercent Jan 3, 2024
0a1aa5d
comment out debug code calling libtorch instead of hip implementation
tenpercent Jan 3, 2024
153d722
remove commented out old and incorrect code fragments
tenpercent Jan 3, 2024
eea5fef
add python version override to cmakelists
tenpercent Jan 3, 2024
d442fbe
add conversion from Argument struct to string; fix split1 test crash
tenpercent Jan 4, 2024
38c5e90
add f32 support in the python op
tenpercent Jan 5, 2024
b805813
refactor out input generation in cpp standalone
tenpercent Jan 5, 2024
03aed21
set loop unrolls to 1 in order to avoid index errors (will need to be…
tenpercent Jan 6, 2024
930dda1
fix output splits allocation
tenpercent Jan 8, 2024
bd50cf4
fix bug in split attention: sumexp needs timestep bounds in each split
tenpercent Jan 9, 2024
60c997d
clang-format-10
tenpercent Jan 9, 2024
b655ded
Merge remote-tracking branch 'origin/develop' into decoder-splitk
tenpercent Jan 9, 2024
588b3a0
Enable support of attn-bias types with LocalAttention
qianfengz Jan 10, 2024
04cf84b
Enable support of attn-bias types with LocalAttention
qianfengz Jan 10, 2024
a27403c
Synchronize submodule composable_kernel to the latest commits
qianfengz Jan 10, 2024
dfc2618
Make the efficient_attention_forward_ck() C++ interface consistent wi…
qianfengz Jan 10, 2024
5421612
Tiny fix in ck.py to make test_backward pass
qianfengz Jan 10, 2024
248efe1
Merge remote-tracking branch 'origin/develop' into decoder-splitk
tenpercent Jan 10, 2024
7948fe6
some refactorings for standalone tests
tenpercent Jan 11, 2024
e7ffe68
cleanup testing
tenpercent Jan 11, 2024
4953101
Make the efficient_attention_forward_ck() C++ interface consistent wi…
qianfengz Jan 10, 2024
e99fc1a
Tiny fix in ck.py to make test_backward pass
qianfengz Jan 10, 2024
d7721d2
fix split1 attention csrc test
tenpercent Jan 11, 2024
902910a
Enable support of flexible head-dim size (but <= 128) for ck-tiled fm…
qianfengz Jan 12, 2024
d1ef4bc
Use Async pipeline when no any padding used
qianfengz Jan 12, 2024
6cb0f60
implement general split-k split-attention in libtorch, use for testing
tenpercent Jan 12, 2024
0e04b17
fix split-max and split-sumexp shapes for split attention in libtorch
tenpercent Jan 12, 2024
e4d6b88
implement generic reduce split attention with libtorch
tenpercent Jan 13, 2024
17ec430
implement testing split reduce hip vs libtorch; tbd debug split-k=2 n…
tenpercent Jan 13, 2024
69f2f0a
refactor repetitive testing code
tenpercent Jan 15, 2024
2d54085
address code review: rearrange loops
tenpercent Jan 15, 2024
f937f06
address code review: add comment about number of iterations per split
tenpercent Jan 15, 2024
7f6b01f
address code review: remove comments
tenpercent Jan 15, 2024
187a4bc
address code review: possibly eliminate a bug by using correct timest…
tenpercent Jan 15, 2024
b157cba
address code review: add todo
tenpercent Jan 15, 2024
8581811
address code review: shift LDS access by tt_low to avoid smem overboo…
tenpercent Jan 16, 2024
b1638ad
address code review: simplify reduction loops in split attention
tenpercent Jan 16, 2024
10e76ab
Tiny update in ck-tiled forward kernel
qianfengz Jan 17, 2024
67009e0
address code review: merge for loops
tenpercent Jan 17, 2024
8673fa9
address code review: simplify coefficient pick
tenpercent Jan 17, 2024
3427dcc
fix runtime error message in testing code
tenpercent Jan 17, 2024
2e11d32
fix split reduce test
tenpercent Jan 17, 2024
dabc771
address code review: fix smem offsets
tenpercent Jan 17, 2024
6f1d5df
remove redundant comment
tenpercent Jan 17, 2024
8ee60d7
address code review: initialize split attention workspace as empty
tenpercent Jan 18, 2024
ff985d2
address code review: rename local vars
tenpercent Jan 18, 2024
d7132b9
address code review: remove unused _rand_seqlens
tenpercent Jan 18, 2024
f4d5263
address code review: cleanup python tests
tenpercent Jan 18, 2024
d81285a
remove redundant new_max local var
tenpercent Jan 18, 2024
eba46f1
address code review: rename seq_acc
tenpercent Jan 18, 2024
7f9ce55
re-enable loop unroll; adjust tests to handle splits with size divisi…
tenpercent Jan 18, 2024
f888b88
test a wider range of split-k in cpp tests; fix torch implementation …
tenpercent Jan 18, 2024
88afcea
Merge pull request #8 from ROCmSoftwarePlatform/decoder-splitk
qianfengz Jan 19, 2024
bad053f
Synchronize with ck-tiled update to support head-dim-256 and LSE storing
qianfengz Jan 19, 2024
391af2b
Add definition of FMHA_FWD_HEADDIM_SWITCH
qianfengz Jan 19, 2024
53719f9
Split the ck-tiled inference instances based on head-dim sizes to imp…
qianfengz Jan 19, 2024
92e088e
Setting k0n1_need_padding according to pipeline kQLoadOnce implementa…
qianfengz Jan 20, 2024
60a8e4a
Add fmha forward c++ extension for ck-tiled
qianfengz Jan 21, 2024
9357a24
Set SUPPORTED_MAX_K=256 in ck.py
qianfengz Jan 22, 2024
df479b5
Merge branch 'ck-tiled-fa' into develop
qianfengz Jan 22, 2024
04ddd4c
fix index in split-k attention
tenpercent Jan 24, 2024
c922d73
fix index in softmax reduce and complete fixing wavefronts per block …
tenpercent Jan 24, 2024
f666965
clang-format-10
tenpercent Jan 24, 2024
ecaf623
Fix v_dram_transposed transpose transform in the kernel
qianfengz Jan 24, 2024
8b337bd
Skip triton_splitk for test_forward in test_mem_eff_attention.py
qianfengz Jan 24, 2024
ee577e2
cleanup commented dead code
tenpercent Jan 24, 2024
a21ac03
enable ck split-k in benchmark_attn_decoding
tenpercent Jan 24, 2024
52dde22
Merge pull request #9 from ROCmSoftwarePlatform/decoder-splitk-opt
tenpercent Jan 25, 2024
5e3213f
add rocm_ci workflow
tenpercent Jan 24, 2024
0e47337
move scipy import from file level under function similar to _vec_bino…
tenpercent Jan 25, 2024
0bf3546
Merge pull request #11 from ROCmSoftwarePlatform/tests-imports
qianfengz Jan 28, 2024
1e1dca8
Merge branch 'develop' into ck-tiled-fa
qianfengz Jan 28, 2024
360201f
Add including of math_v2.hpp to ck_attention_forward_decoder_splitk.h
qianfengz Jan 28, 2024
faf1b16
move forward_splitk to ck_splitk; make dispatch aware of ck_splitk an…
tenpercent Jan 29, 2024
323ebae
Synchronize to latest ck-tiled and update accordingly
qianfengz Jan 30, 2024
9d2be4f
fix benchmark_attn_decoding
tenpercent Jan 30, 2024
7c3c766
Remove third_party/composable_kernel_tiled
qianfengz Jan 30, 2024
708c047
[Fix] use kK0BlockLength for HeadDim256 padding judging
qianfengz Jan 30, 2024
a0f2643
Tiny type change for custom_mask_type in param class
qianfengz Jan 31, 2024
96f3027
Change to use ROCm repo for ck-tiled submodule
qianfengz Feb 1, 2024
f3f2be4
Remove tests/test_forward_ck_tiled.py
qianfengz Feb 1, 2024
34466be
Update to test_mqa_forward_ck_tiled.py to use common create_attn_bias…
qianfengz Feb 1, 2024
2f92cde
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 1, 2024
351c766
Add ck-tiled checking in test_mqa_forward_ck_tiled.py
qianfengz Feb 1, 2024
ed26f5b
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 1, 2024
b58b4ed
rearrange smem access in softmax reduction
tenpercent Feb 2, 2024
5a026c0
Merge pull request #14 from ROCm/perf-adjustment-1
qianfengz Feb 2, 2024
5bbbe8f
Merge pull request #13 from ROCm/dispatcher
qianfengz Feb 2, 2024
8a40a31
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/xfo…
qianfengz Feb 2, 2024
21062d1
Add test_decoder and test_splitk_decoder for ROCM into test_mem_eff_a…
qianfengz Feb 2, 2024
df7d523
Add ref_attention_splitk and its test to tests/test_mem_eff_attention.py
qianfengz Feb 2, 2024
ee633c8
Rename test_mem_eff_attention_ck.py as discarded
qianfengz Feb 2, 2024
2df5ed3
Add test_mqa_forward and ref_attention_mqa (for BMHK format mqa/gqa v…
qianfengz Feb 2, 2024
7d1219b
Rename test_mqa_forward_ck_tiled.py as discarded
qianfengz Feb 2, 2024
fe6f96e
Remove CK specific script benchmark_mem_eff_attn_decoder_ck.py
qianfengz Feb 2, 2024
5af967c
Refine benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py
qianfengz Feb 2, 2024
3f46c2f
Rename benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark_mem_ef…
qianfengz Feb 2, 2024
2c27aac
Remove the runtime_error with using logsumexp in attention_forward_ge…
qianfengz Feb 2, 2024
4b8ce7c
Add ck-tiled checking in ck.py
qianfengz Feb 2, 2024
0d311f5
Remove CK-specific benchmark scripts
qianfengz Feb 2, 2024
d57a5db
Don't require is_cpu_tensor for seqstart_q/seqstart_k/seqlen_k in att…
qianfengz Feb 3, 2024
b25c239
Remove seqlen_cpu from _PaddedSeqLenInfo in attn_bias.py
qianfengz Feb 3, 2024
1a3ce52
Change the branch for composable_kernel_tiled submodule and update to…
qianfengz Feb 4, 2024
f7bf9b4
Remove the using of seqlen_cpu in BwOp of ck.py
qianfengz Feb 4, 2024
15d2a72
Remove the using of seqlen_cpu in BwOp of ck.py
qianfengz Feb 4, 2024
bcd1936
Align .clang_format with main branch and re-format c++ files
qianfengz Feb 4, 2024
52ae8a3
Synchronize to latest ck-tiled commit
qianfengz Feb 4, 2024
af2aa86
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 4, 2024
7dd3aee
Add checking of IS_CK_TILED into some testing scripts
qianfengz Feb 4, 2024
5eb1235
Update to test_mem_eff_attention.py and ck.py
qianfengz Feb 5, 2024
dc0e67a
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
58e6101
Building xformers using ck-tiled as default
qianfengz Feb 5, 2024
1276abc
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
389dfb4
ensure ck_decoder does not dispatch
tenpercent Feb 5, 2024
f8d9043
Add disable_on_rocm on some test scripts
qianfengz Feb 5, 2024
78df6a9
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
6dae63c
Update to test_mem_eff_attention.py
qianfengz Feb 5, 2024
a7ed88c
Merge branch 'ck-tiled-fa' into develop
qianfengz Feb 5, 2024
20e178a
Merge pull request #16 from ROCm/fix_test_attn_bias_padded
qianfengz Feb 6, 2024
0624c92
apply isort
tenpercent Feb 6, 2024
b8ebf08
apply black
tenpercent Feb 6, 2024
3b33c5d
fix flake8 suggestions
tenpercent Feb 6, 2024
0a9c933
add license headers and reapply black
tenpercent Feb 6, 2024
47367a4
Merge pull request #17 from ROCm/linters
qianfengz Feb 6, 2024
fb46611
Merge pull request #10 from ROCm/enable-ci
qianfengz Feb 6, 2024
28d3672
Tiny update to rocm_ci.yml
qianfengz Feb 6, 2024
12fb41c
Add conditional compiling for cuda-depending codes in ROCM
qianfengz Feb 6, 2024
a9d83c6
Update to benchmark scripts
qianfengz Feb 7, 2024
9ab3831
Rename the one script file
qianfengz Feb 7, 2024
243dc6a
Revert "Add conditional compiling for cuda-depending codes in ROCM"
qianfengz Feb 7, 2024
3240ba1
Update to scripts
qianfengz Feb 7, 2024
0c51af1
Change and add readme for tests and benchmarks
qianfengz Feb 7, 2024
f36c93b
Remove the stuffs for supporting old ck
qianfengz Feb 7, 2024
9e4582d
Remove old composable_kernel from submodule list
qianfengz Feb 7, 2024
356cafd
Remove folder third_party/composable_kernel
qianfengz Feb 7, 2024
8415b00
Merge branch 'develop' into dev_to_upstream
qianfengz Feb 7, 2024
79c554c
Rename the folder
qianfengz Feb 8, 2024
2be6c04
Remove unused script file
qianfengz Feb 8, 2024
61d875a
apply black
tenpercent Feb 9, 2024
4616121
pacify mypy
tenpercent Feb 9, 2024
832e223
fix clang-format
tenpercent Feb 9, 2024
2b2967e
reapply black
tenpercent Feb 9, 2024
89fb7d6
Merge pull request #3 from tenpercent/lints
tenpercent Feb 12, 2024
3c9d4e5
fix lints
tenpercent Feb 13, 2024
1d474c5
make test_splitk_reference run on cpu
tenpercent Feb 13, 2024
d38a684
add ck modules to docs
tenpercent Feb 13, 2024
eccbf54
try fixing nvidia build by re-including sparse24 cpp folder into exte…
tenpercent Feb 13, 2024
1ef6c20
update cutlass to upstream commit
tenpercent Feb 13, 2024
9dfec0d
update flash-attention to upstream commit
tenpercent Feb 13, 2024
9fcda18
simplify setup.py
tenpercent Feb 13, 2024
01c2bfd
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 13, 2024
58d38d4
remove duplicate run_batched_infer_causalmask_attnbias_dispatched<f16…
tenpercent Feb 13, 2024
07183f0
add hip version and pytorch hip arch list to xformers build info
tenpercent Feb 14, 2024
993a90c
fix build
tenpercent Feb 14, 2024
d4a374b
patch around the unhappy path in get_hip_version
tenpercent Feb 14, 2024
ff59f19
skip test_grad_checkpointing for triton_splitk since it doesn't have …
tenpercent Feb 15, 2024
81bcfd5
re-enable test_mqa_forward since ck tiled is the current implementation
tenpercent Feb 15, 2024
a0f7f27
make skip test_wrong_alignment more generic
tenpercent Feb 15, 2024
a0d8dcc
reapply black
tenpercent Feb 15, 2024
bc7035c
simplify test_decoder
tenpercent Feb 15, 2024
f02d0d4
put python version check inside triton_splitk op
tenpercent Feb 15, 2024
77a6c13
fix logic
tenpercent Feb 15, 2024
a7cd678
cleanup python3.9 checks in tests
tenpercent Feb 15, 2024
dea783d
cleanup test_attentions
tenpercent Feb 15, 2024
acd6b7a
cleanup test_checkpoint as test running on cpu does not depend on gpu…
tenpercent Feb 16, 2024
f467a1d
fix lints
tenpercent Feb 16, 2024
d758eac
try fixing win build by conditional import of triton in triton op
tenpercent Feb 16, 2024
21f1904
re-enable test_triton_layernorm as it passes
tenpercent Feb 17, 2024
d880c36
re-enable test_triton_blocksparse as it passes
tenpercent Feb 17, 2024
059c84f
cleanup test_sparse_tensors
tenpercent Feb 17, 2024
8aa0bdc
cleanup test_custom_ops
tenpercent Feb 17, 2024
5bc7bbe
reapply black
tenpercent Feb 17, 2024
5b4ebe4
cleanup test_core_attention
tenpercent Feb 17, 2024
473ebc7
benchmark ck ops on rocm only
tenpercent Feb 17, 2024
2a7272e
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 19, 2024
5d3247f
fix mypy
tenpercent Feb 19, 2024
9be7f8d
Merge branch 'dev_upstream' of https://github.com/ROCm/xformers into …
tenpercent Feb 20, 2024
58b0f75
fix lint: black
tenpercent Feb 21, 2024
03b7294
fix lints: mypy
tenpercent Feb 21, 2024
a02ab9b
Rename HDim/headdim to MaxK/maxk
qianfengz Feb 22, 2024
fd36725
Move some headers files to ck examples for later reusing
qianfengz Feb 22, 2024
41f5ada
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Feb 22, 2024
d8384c1
Replace using qs_ks_vs pipeline by qr_ks_vs pipeline while HeadDim is…
qianfengz Feb 22, 2024
e5d4a76
rm test_ck_7
tenpercent Feb 22, 2024
7d43238
fix lints
tenpercent Feb 26, 2024
6fbb383
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Feb 26, 2024
1db3a5a
unskip test_unsupported_alignment
tenpercent Feb 28, 2024
57d7e96
move out test_splitk_reference
tenpercent Feb 28, 2024
14c831e
add license header to file created in prev commit
tenpercent Feb 28, 2024
d5a26a6
roll back fmha/common.py
tenpercent Feb 28, 2024
3560806
fix lint
tenpercent Feb 28, 2024
f654b3a
remove unused ref_attention_mqa
tenpercent Feb 28, 2024
99947ff
Merge pull request #5 from ROCm/roll-back-fmha-common
qianfengz Feb 29, 2024
c5ea221
resolve error in triton_splitk on rocm
tenpercent Mar 1, 2024
b585563
Merge branch 'main' of https://github.com/facebookresearch/xformers i…
tenpercent Mar 1, 2024
6752f07
disable partial attention tests on rocm
tenpercent Mar 1, 2024
Files changed
71 changes: 71 additions & 0 deletions .github/workflows/rocm_ci.yml
@@ -0,0 +1,71 @@
name: ROCM_CI

on:
  pull_request:
    types: [labeled, synchronize, reopened]

jobs:
  build:
    if: contains(github.event.label.name, 'rocm')
    runs-on: rocm

    steps:
    - uses: actions/checkout@v2
    - name: Get CPU info on Ubuntu
      if: contains(runner.os, 'linux')
      run: |
        cat /proc/cpuinfo
    - name: Get env vars
      run: |
        echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW
        echo HOME = $HOME
        echo PWD = $PWD
        echo GITHUB_ACTION = $GITHUB_ACTION
        echo GITHUB_ACTIONS = $GITHUB_ACTIONS
        echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY
        echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME
        echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH
        echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE
        echo GITHUB_SHA = $GITHUB_SHA
        echo GITHUB_REF = $GITHUB_REF
        export GIT_BRANCH=${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}}
        echo GIT_BRANCH = $GIT_BRANCH
        export ROCM_PATH=/opt/rocm
        echo ROCM_PATH = $ROCM_PATH
        export MAX_JOBS=64
        echo MAX_JOBS = $MAX_JOBS
        hipcc --version
        rocm-smi
        rocminfo | grep "gfx"
    - name: Build XFormers
      run: |
        git clone --recursive -b $GIT_BRANCH $GITHUB_REPOSITORY
        docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G -v $PWD/xformers:/xformers rocm/pytorch-nightly:latest
        pip3 install --upgrade pip
        pip3 uninstall -y xformers
        MAX_JOBS=$MAX_JOBS pip3 install -e /xformers --verbose
        pip3 install scipy==1.10
        python3 -c "import torch; print(torch.__version__)"
        python3 -m xformers.info
    - name: Run python tests
      run: |
        pytest -rpfs /xformers/tests/test_mem_eff_attention.py | tee test_mem_eff_attention.log
    - name: Archive logs
      uses: actions/upload-artifact@v3
      with:
        name: test results
        path: test_mem_eff_attention_ck.log

    - name: Process test results
      run: |
        echo "Processing test results TBD"
10 changes: 10 additions & 0 deletions .gitignore
@@ -60,3 +60,13 @@ outputs
xformers/_flash_attn
xformers/version.py
xformers/cpp_lib.json

## temporary files
xformers/csrc/attention/hip_fmha/*.cu
xformers/csrc/attention/hip_fmha/*.hip
xformers/csrc/attention/hip_fmha/*_hip.h
xformers/csrc/attention/hip_fmha/instances/*.cu
xformers/csrc/attention/hip_fmha/instances/*.hip
xformers/csrc/attention/hip_fmha/instances_tiled/*.cu
xformers/csrc/attention/hip_fmha/instances_tiled/*.hip

4 changes: 4 additions & 0 deletions .gitmodules
@@ -4,3 +4,7 @@
[submodule "third_party/flash-attention"]
  path = third_party/flash-attention
  url = https://github.com/Dao-AILab/flash-attention.git
[submodule "third_party/composable_kernel_tiled"]
  path = third_party/composable_kernel_tiled
  url = https://github.com/ROCm/composable_kernel.git
  branch = ck_tile/dev
14 changes: 13 additions & 1 deletion docs/source/components/ops.rst
@@ -22,13 +22,25 @@ Available implementations
   :member-order: bysource
Contributor comment: (Looks like decoder and triton_splitk should have been added here months ago. 🫢)


.. automodule:: xformers.ops.fmha.triton
   :members: FwOp, BwOp
   :members: FwOp
   :member-order: bysource

.. automodule:: xformers.ops.fmha.small_k
   :members: FwOp, BwOp
   :member-order: bysource

.. automodule:: xformers.ops.fmha.ck
   :members: FwOp, BwOp
   :member-order: bysource

.. automodule:: xformers.ops.fmha.ck_decoder
   :members: FwOp
   :member-order: bysource

.. automodule:: xformers.ops.fmha.ck_splitk
   :members: FwOp
   :member-order: bysource

Attention biases
~~~~~~~~~~~~~~~~~~~~

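For context on how the newly documented CK ops are used, here is a minimal call-site sketch (not part of this diff; shapes, dtype and device string are illustrative, and on ROCm the dispatcher can also be left to pick between ck, ck_decoder and ck_splitk automatically):

# Hedged sketch, assuming a ROCm build of xformers and fp16 tensors on an AMD GPU.
import torch
import xformers.ops as xops
from xformers.ops import fmha

# (batch, seqlen, heads, head_dim) layout expected by memory_efficient_attention
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Inference-only path through the CK fused kernel documented above.
out = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)

# Or let xformers dispatch automatically.
out_auto = xops.memory_efficient_attention(q, k, v)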
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -25,7 +25,7 @@ hydra-core >= 1.1

# Dependency for Mixture of Experts
fairscale >= 0.4.5
scipy
scipy >= 1.7

# Dependency for fused layers, optional
cmake
82 changes: 82 additions & 0 deletions setup.py
@@ -125,6 +125,23 @@ def get_cuda_version(cuda_dir) -> int:
    return bare_metal_major * 100 + bare_metal_minor


def get_hip_version(rocm_dir) -> str:
    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
    try:
        raw_output = subprocess.check_output(
            [hipcc_bin, "--version"], universal_newlines=True
        )
    except Exception as e:
        print(
            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
        )
        return None
    for line in raw_output.split("\n"):
        if "HIP version" in line:
            return line.split()[-1]
    return None


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
    # XXX: Not supported on windows for cuda<12
    # https://github.com/Dao-AILab/flash-attention/issues/345
@@ -223,11 +240,27 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
    ]


def rename_cpp_cu(cpp_files):
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")


def get_extensions():
    extensions_dir = os.path.join("xformers", "csrc")

    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
    source_hip = glob.glob(
        os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"),
        recursive=True,
    )
    source_hip_generated = glob.glob(
        os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cu"),
        recursive=True,
    )
    # avoid the temporary .cu files generated under xformers/csrc/attention/hip_fmha
    source_cuda = list(set(source_cuda) - set(source_hip_generated))
    sources = list(set(sources) - set(source_hip))

    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
@@ -253,6 +286,7 @@ def get_extensions():
    include_dirs = [extensions_dir]
    ext_modules = []
    cuda_version = None
    hip_version = None
    flash_version = "0.0.0"

    if (
@@ -294,6 +328,7 @@ def get_extensions():
        flash_extensions = get_flash_attention_extensions(
            cuda_version=cuda_version, extra_compile_args=extra_compile_args
        )

        if flash_extensions:
            flash_version = get_flash_version()
            ext_modules += flash_extensions
@@ -306,6 +341,51 @@ def get_extensions():
                "--ptxas-options=-O2",
                "--ptxas-options=-allow-expensive-optimizations=true",
            ]
    elif torch.cuda.is_available() and torch.version.hip:
        rename_cpp_cu(source_hip)
        rocm_home = os.getenv("ROCM_PATH")
        hip_version = get_hip_version(rocm_home)

        source_hip_cu = []
        for ff in source_hip:
            source_hip_cu += [ff.replace(".cpp", ".cu")]

        extension = CUDAExtension
        sources += source_hip_cu
        include_dirs += [
            Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha"
        ]

        include_dirs += [
            Path(this_dir)
            / "third_party"
            / "composable_kernel_tiled"
            / "example"
            / "91_tile_program"
            / "xformers_fmha"
        ]

        include_dirs += [
            Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
        ]

        generator_flag = []

        cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
        extra_compile_args = {
            "cxx": ["-O3", "-std=c++17"] + generator_flag,
            "nvcc": [
                "-O3",
                "-std=c++17",
                f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-DCK_FMHA_FWD_FAST_EXP2=1",
                "-fgpu-flush-denormals-to-zero",
            ]
            + generator_flag
            + cc_flag,
        }

    ext_modules.append(
        extension(
@@ -320,6 +400,7 @@ def get_extensions():
    return ext_modules, {
        "version": {
            "cuda": cuda_version,
            "hip": hip_version,
            "torch": torch.__version__,
            "python": platform.python_version(),
            "flash": flash_version,
@@ -328,6 +409,7 @@ def get_extensions():
            k: os.environ.get(k)
            for k in [
                "TORCH_CUDA_ARCH_LIST",
                "PYTORCH_ROCM_ARCH",
                "XFORMERS_BUILD_TYPE",
                "XFORMERS_ENABLE_DEBUG_ASSERTIONS",
                "NVCC_FLAGS",
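To see which branch of get_extensions() a given machine would take, here is a standalone sketch of the same probing logic (not part of this diff; the hipcc location and environment variables are assumptions, and the parsing mirrors get_hip_version() above):

# Hedged sketch: probe for a ROCm toolchain the way setup.py's get_hip_version() does.
import os
import subprocess

import torch


def detect_hip_version(rocm_dir=None):
    hipcc = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
    try:
        out = subprocess.check_output([hipcc, "--version"], universal_newlines=True)
    except Exception:
        return None
    for line in out.split("\n"):
        if "HIP version" in line:
            return line.split()[-1]  # e.g. "5.7.xxxxx-..."
    return None


if torch.cuda.is_available() and torch.version.hip:
    print("ROCm/HIP build, HIP version:", detect_hip_version(os.getenv("ROCM_PATH")))
elif torch.cuda.is_available() and torch.version.cuda:
    print("CUDA build, CUDA version:", torch.version.cuda)
else:
    print("CPU-only build")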
13 changes: 13 additions & 0 deletions tests/readme_test_on_rocm.txt
@@ -0,0 +1,13 @@

1. #> pip install -e ./

2. verify testing for generic fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_forward

3. verify testing for decoder fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_decoder
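
The same checks can be driven from a Python session; a small sketch using the standard pytest API (test node ids taken from the readme above, working directory assumed to be the repository root):

# Hedged sketch: run the ROCm fmha tests programmatically instead of from the shell.
import pytest

exit_code = pytest.main([
    "-rpfs",
    "tests/test_mem_eff_attention.py::test_forward",
    "tests/test_mem_eff_attention.py::test_decoder",
    "tests/test_mem_eff_attention.py::test_splitk_decoder",
])
print("pytest exit code:", exit_code)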


7 changes: 7 additions & 0 deletions tests/test_attentions.py
@@ -107,6 +107,13 @@ def test_order_invariance(
    causal: bool,
    device: torch.device,
):
    if (
        torch.version.hip
        and device == torch.device("cuda")
        and attention_name == "local"
    ):
        # Backend calls into Sputnik library which isn't built on ROCm
        device = torch.device("cpu")

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
15 changes: 14 additions & 1 deletion tests/test_checkpoint.py
@@ -111,7 +111,11 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
    "op",
    [
        xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
        xformers.ops.MemoryEfficientAttentionCutlassOp,
        (
            xformers.ops.MemoryEfficientAttentionCutlassOp
            if torch.version.cuda
            else xformers.ops.MemoryEfficientAttentionCkOp
        ),
    ],
)
def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast, op):
@@ -121,6 +125,15 @@ def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast,
    ):
        pytest.skip("skipping operator not supported in this arch")

    if (
        op is xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        and torch.version.hip
    ):
        pytest.skip("FlashAttentionOp is not supported on ROCM!")

    if op is xformers.ops.MemoryEfficientAttentionCkOp:
        pytest.skip("Gradience is currently not supported by ck-tiled!")

    class Attn(nn.Module):
        def forward(self, x):
            out = xformers.ops.memory_efficient_attention(x, x, x, op=op)
5 changes: 4 additions & 1 deletion tests/test_core_attention.py
@@ -31,7 +31,9 @@ def fn_and_catch_oor(*args, **kwargs):
    return fn_and_catch_oor


_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
_devices = (
    ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
)


def test_core_attention():
@@ -144,6 +146,7 @@ def test_amp_attention_sparsecs(device):
@pytest.mark.skipif(
    not _is_blocksparse_available, reason="Blocksparse is not available"
)
@pytest.mark.skipif(not torch.version.cuda, reason="Sparse ops not supported on ROCm")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("data_type", [torch.float16, torch.float32])
@catch_oor
9 changes: 7 additions & 2 deletions tests/test_custom_ops.py
@@ -16,8 +16,13 @@
    _sparse_bmm,
)

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
cuda_only = pytest.mark.skipif(
    not torch.cuda.is_available() or not torch.version.cuda, reason="requires CUDA"
)

_devices = (
    ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
)


def _baseline_matmul_with_sparse_mask(