diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index e0a2a7eb347394..0f4269cdad41c5 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -173,6 +173,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin) elseif(${CUDA_ARCH_NAME} STREQUAL "Turing") set(cuda_arch_bin "75") elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere") + message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE") + add_definitions("-DCUDA_BFLOAT16_AVALIABLE") if(WITH_NV_JETSON) set(cuda_arch_bin "87") else() @@ -183,6 +185,8 @@ function(select_nvcc_arch_flags out_variable out_arch_bin) endif() endif() elseif(${CUDA_ARCH_NAME} STREQUAL "Hopper") + message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE") + add_definitions("-DCUDA_BFLOAT16_AVALIABLE") set(cuda_arch_bin "90") elseif(${CUDA_ARCH_NAME} STREQUAL "All") set(cuda_arch_bin ${paddle_known_gpu_archs}) @@ -196,8 +200,17 @@ function(select_nvcc_arch_flags out_variable out_arch_bin) to get a full wheel package to resolve this warning. While, this version will still work on local GPU architecture.") detect_installed_gpus(cuda_arch_bin) + if(${cuda_arch_bin} MATCHES "[ ]*(8\.0|8\.6|8\.9|9\.0)[ ]*") + message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE") + add_definitions("-DCUDA_BFLOAT16_AVALIABLE") + endif() else() # (${CUDA_ARCH_NAME} STREQUAL "Manual") set(cuda_arch_bin ${CUDA_ARCH_BIN}) + + if(${CUDA_ARCH_BIN} MATCHES "[ ]*(80|86|89|90)[ ]*") + message(STATUS "Add Define CUDA_BFLOAT16_AVALIABLE") + add_definitions("-DCUDA_BFLOAT16_AVALIABLE") + endif() endif() if(NEW_RELEASE_JIT)