Skip to content

Commit

Permalink
[release/2.4] [ROCM] Properly disable Flash Attention/Efficient Atten…
Browse files Browse the repository at this point in the history
…tion with environment variables (#1570)

With this change, a build invoked as
`USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` compiles correctly.

This is a cherry-picked version of
pytorch#133866
  • Loading branch information
xinyazhang authored Sep 11, 2024
1 parent 1488da9 commit f4c8ad5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,16 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM" OFF)

#
# For ROCm builds, include cmake/External/aotriton.cmake, but only when at
# least one of USE_FLASH_ATTENTION / USE_MEM_EFF_ATTENTION is enabled.
#
# This include cannot be put into Dependencies.cmake due to a circular
# dependency:
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
Expand Down
14 changes: 13 additions & 1 deletion aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
#include <c10/util/string_view.h>

#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <aotriton/flash.h>
#define USE_AOTRITON 1
#endif
#endif

/**
Expand Down Expand Up @@ -185,6 +188,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
Expand All @@ -194,6 +198,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
return false;
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
Expand All @@ -216,6 +223,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
Expand All @@ -225,6 +233,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
Expand All @@ -238,8 +249,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
#endif
return true;
#endif
return false;
}

bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
Expand Down
1 change: 0 additions & 1 deletion cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,6 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm")
endif()

include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
if(USE_CUDA)
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
endif()
Expand Down

0 comments on commit f4c8ad5

Please sign in to comment.