From 41f667934017dfbef24e632bd3c33c9746668901 Mon Sep 17 00:00:00 2001 From: Stanley Tsang Date: Tue, 9 Jul 2024 16:55:27 -0600 Subject: [PATCH] gfx12 support (#9) (#384) Add gfx12 support --- CMakeLists.txt | 3 +- .../backend/rocprim/thread/thread_load.hpp | 61 +++++++++++-------- .../backend/rocprim/thread/thread_store.hpp | 58 ++++++++++-------- test/hipcub/test_hipcub_iterators.cpp | 18 +++--- 4 files changed, 76 insertions(+), 64 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 91c62814..27e16114 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,10 +85,9 @@ if(NOT (CMAKE_CXX_COMPILER MATCHES ".*nvcc$" OR "${CMAKE_CXX_COMPILER_ID}" STREQ ) else() rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS - TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" + TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" ) endif() - set(GPU_TARGETS "${DEFAULT_AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for" FORCE) endif() endif() diff --git a/hipcub/include/hipcub/backend/rocprim/thread/thread_load.hpp b/hipcub/include/hipcub/backend/rocprim/thread/thread_load.hpp index d3fc38b4..ef0e2d64 100644 --- a/hipcub/include/hipcub/backend/rocprim/thread/thread_load.hpp +++ b/hipcub/include/hipcub/backend/rocprim/thread/thread_load.hpp @@ -60,49 +60,56 @@ HIPCUB_DEVICE __forceinline__ T AsmThreadLoad(void * ptr) interim_type, \ asm_operator, \ output_modifier, \ + wait_inst, \ wait_cmd) \ template<> \ HIPCUB_DEVICE __forceinline__ type AsmThreadLoad(void * ptr) \ { \ interim_type retval; \ - asm volatile( \ - #asm_operator " %0, %1 " llvm_cache_modifier "\n" \ - "\ts_waitcnt " wait_cmd "(0)" : "=" #output_modifier(retval) : "v"(ptr) \ - ); \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier "\n\t" \ + wait_inst wait_cmd "(%2)" \ + : "=" #output_modifier(retval) \ + : "v"(ptr), "I"(0x00)); \ return retval; \ } // TODO Add specialization for custom larger data types -#define HIPCUB_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_cmd); \ - HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_cmd); +#define HIPCUB_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); + #if defined(__gfx940__) || defined(__gfx941__) -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc1", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0 sc1", "vmcnt"); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0 sc1", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc1", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0 sc1", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0 sc1", "s_waitcnt", "vmcnt"); #elif defined(__gfx942__) -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc0 nt", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0", "vmcnt"); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc0 nt", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0", "s_waitcnt", "vmcnt"); +#elif defined(__gfx1200__) || defined(__gfx1201__) +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", ""); #else -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "glc", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "glc slc", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "glc", "vmcnt"); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "glc", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "glc", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "glc slc", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "glc", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "glc", "s_waitcnt", "vmcnt"); #endif // TODO find correct modifiers to match these -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_LDG, "", ""); -HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CS, "", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_LDG, "", "", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CS, "", "", ""); #endif diff --git a/hipcub/include/hipcub/backend/rocprim/thread/thread_store.hpp b/hipcub/include/hipcub/backend/rocprim/thread/thread_store.hpp index 7e7e00ef..1cbb29d4 100644 --- a/hipcub/include/hipcub/backend/rocprim/thread/thread_store.hpp +++ b/hipcub/include/hipcub/backend/rocprim/thread/thread_store.hpp @@ -62,46 +62,52 @@ HIPCUB_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val) interim_type, \ asm_operator, \ output_modifier, \ + wait_inst, \ wait_cmd) \ template<> \ HIPCUB_DEVICE __forceinline__ void AsmThreadStore(void * ptr, type val) \ { \ - interim_type temp_val = val; \ - asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : : "v"(ptr), #output_modifier(temp_val)); \ - asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \ + interim_type temp_val = val; \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier "\n\t" \ + wait_inst wait_cmd "(%2)" \ + : : "v"(ptr), #output_modifier(temp_val), "I"(0x00)); \ } // TODO fix flat_store_ubyte and flat_store_sbyte issues // TODO Add specialization for custom larger data types -#define HIPCUB_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_cmd); \ - HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_cmd); +#define HIPCUB_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); #if defined(__gfx940__) || defined(__gfx941__) -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0 sc1", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 sc1", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0 sc1", "vmcnt"); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0 sc1", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0 sc1", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 sc1", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0 sc1", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0 sc1", "s_waitcnt", "vmcnt"); #elif defined(__gfx942__) -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 nt", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0", "vmcnt"); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 nt", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0", "s_waitcnt", "vmcnt"); +#elif defined(__gfx1200__) || defined(__gfx1201__) +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_storecnt_dscnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", ""); #else -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "glc", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "glc slc", ""); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "glc", "vmcnt"); -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "glc", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "glc", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "glc slc", "s_waitcnt", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "glc", "s_waitcnt", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "glc", "s_waitcnt", "vmcnt"); #endif - // TODO find correct modifiers to match these -HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CS, "", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CS, "", "", ""); #endif diff --git a/test/hipcub/test_hipcub_iterators.cpp b/test/hipcub/test_hipcub_iterators.cpp index 85888240..73c51ae9 100644 --- a/test/hipcub/test_hipcub_iterators.cpp +++ b/test/hipcub/test_hipcub_iterators.cpp @@ -344,9 +344,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexObj) hipDeviceProp_t props; HIP_CHECK(hipGetDeviceProperties(&props, device_id)); std::string deviceName = std::string(props.gcnArchName); - if (deviceName.rfind("gfx94", 0) == 0) { - // This is a gfx94x device, so skip this test - GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported"; + if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) { + // This is a gfx94x or gfx120x device, so skip this test + GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported"; } HIP_CHECK(hipSetDevice(device_id)); @@ -411,9 +411,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexRef) hipDeviceProp_t props; HIP_CHECK(hipGetDeviceProperties(&props, device_id)); std::string deviceName = std::string(props.gcnArchName); - if (deviceName.rfind("gfx94", 0) == 0) { - // This is a gfx94x device, so skip this test - GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported"; + if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) { + // This is a gfx94x or gfx120x device, so skip this test + GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported"; } HIP_CHECK(hipSetDevice(device_id)); @@ -482,9 +482,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexTransform) hipDeviceProp_t props; HIP_CHECK(hipGetDeviceProperties(&props, device_id)); std::string deviceName = std::string(props.gcnArchName); - if (deviceName.rfind("gfx94", 0) == 0) { - // This is a gfx94x device, so skip this test - GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported"; + if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) { + // This is a gfx94x or gfx120x device, so skip this test + GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported"; } HIP_CHECK(hipSetDevice(device_id));