From 4d2dda61c18bf93fa591cd84a5481ee9dd8ee428 Mon Sep 17 00:00:00 2001 From: iotamudelta Date: Wed, 14 Aug 2024 17:50:54 -0500 Subject: [PATCH] Make CAR ROCm 6.1 compatible. (#137) * remove scoping * while there fix a typo * while there remove unused variable --- csrc/custom_all_reduce.cuh | 25 ++++++++++++------------- csrc/custom_all_reduce_test.cu | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index c293d6b451cd6..27e5c271fd8d2 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -145,18 +145,17 @@ DINLINE O downcast(array_t val) { template #ifdef USE_ROCM DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) { - uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { - __scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0, - __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE); + __atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0, + __ATOMIC_RELAXED); // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], - 1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); + __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1, + __ATOMIC_RELAXED); __atomic_thread_fence(__ATOMIC_ACQ_REL); // wait until we got true from all ranks - while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], - __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE)); + while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED)); } __syncthreads(); } @@ -190,16 +189,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) { // the memory model. 
if (threadIdx.x < ngpus) { // reset flag for next time - __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0, - __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE); + __atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0, + __ATOMIC_RELAXED); // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1, - __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); + __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1, + __ATOMIC_RELAXED); __atomic_thread_fence(__ATOMIC_ACQ_REL); // wait until we got true from all ranks - while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], - __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE)); + while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + __ATOMIC_RELAXED)); } if constexpr (!final_sync) __syncthreads(); } diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index 116b71dce07d2..c0652e875aeff 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -330,7 +330,7 @@ int main(int argc, char** argv) { // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); // } // } -#ifdef USE _ROCM +#ifdef USE_ROCM for (int sz = 512; sz <= (8 << 22); sz *= 2) { run(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test); }