Skip to content

Commit

Permalink
Make CAR ROCm 6.1 compatible. (vllm-project#137)
Browse files Browse the repository at this point in the history
* remove scoping
* while there fix a typo
* while there remove unused variable
  • Loading branch information
iotamudelta authored Aug 14, 2024
1 parent 4132cbe commit 4d2dda6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
25 changes: 12 additions & 13 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,17 @@ DINLINE O downcast(array_t<float, O::size> val) {
template <int ngpus>
#ifdef USE_ROCM
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
__scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED));
}
__syncthreads();
}
Expand Down Expand Up @@ -190,16 +189,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
// the memory model.
if (threadIdx.x < ngpus) {
// reset flag for next time
__scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
__ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED));
}
if constexpr (!final_sync) __syncthreads();
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/custom_all_reduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ int main(int argc, char** argv) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
#ifdef USE _ROCM
#ifdef USE_ROCM
for (int sz = 512; sz <= (8 << 22); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
}
Expand Down

0 comments on commit 4d2dda6

Please sign in to comment.