diff --git a/src/fastertransformer/kernels/custom_ar_kernels.cu b/src/fastertransformer/kernels/custom_ar_kernels.cu index af8aee128..056ae375c 100644 --- a/src/fastertransformer/kernels/custom_ar_kernels.cu +++ b/src/fastertransformer/kernels/custom_ar_kernels.cu @@ -292,7 +292,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) // use round-robin gathering from other ranks int offset_rank = local_offset + (dst_rank[ii] - params.local_rank) * params.elts_per_rank; reinterpret_cast(¶ms.local_output_buffer_ptr[offset_rank])[0] = - reinterpret_cast(&src_d[dst_rank[ii]][offset_rank])[0]; + reinterpret_cast(&src_d[ii][offset_rank])[0]; } } } @@ -395,4 +395,4 @@ template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<_ cudaStream_t stream); #endif template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer