Skip to content

Commit

Permalink
add CUDA_BFLOAT16_AVALIABLE macro definition (PaddlePaddle#64372)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored and co63oc committed May 18, 2024
1 parent 8321099 commit 806936f
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
if(WITH_NV_JETSON)
set(cuda_arch_bin "87")
else()
Expand All @@ -183,6 +185,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
endif()
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "Hopper")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
set(cuda_arch_bin "90")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
Expand All @@ -196,8 +200,17 @@ function(select_nvcc_arch_flags out_variable out_arch_bin)
to get a full wheel package to resolve this warning.
While, this version will still work on local GPU architecture.")
detect_installed_gpus(cuda_arch_bin)
if(${cuda_arch_bin} MATCHES "[ ]*(8\.0|8\.6|8\.9|9\.0)[ ]*")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
endif()
else() # (${CUDA_ARCH_NAME} STREQUAL "Manual")
set(cuda_arch_bin ${CUDA_ARCH_BIN})

if(${CUDA_ARCH_BIN} MATCHES "[ ]*(80|86|89|90)[ ]*")
message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE")
add_definitions("-DCUDA_BFLOAT16_AVALIABLE")
endif()
endif()

if(NEW_RELEASE_JIT)
Expand Down

0 comments on commit 806936f

Please sign in to comment.