Skip to content

Commit

Permalink
[bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVI…
Browse files Browse the repository at this point in the history
…DIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.
  • Loading branch information
rkindi authored and azahed98 committed Jul 20, 2023
1 parent c6e8f60 commit 0db54f9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/fastertransformer/kernels/custom_ar_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams<T> params)
// use round-robin gathering from other ranks
int offset_rank = local_offset + (dst_rank[ii] - params.local_rank) * params.elts_per_rank;
reinterpret_cast<PackedType*>(&params.local_output_buffer_ptr[offset_rank])[0] =
reinterpret_cast<PackedType*>(&src_d[dst_rank[ii]][offset_rank])[0];
reinterpret_cast<PackedType*>(&src_d[ii][offset_rank])[0];
}
}
}
Expand Down Expand Up @@ -395,4 +395,4 @@ template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<_
cudaStream_t stream);
#endif
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams<uint32_t>& param, cudaStream_t stream);
} // namespace fastertransformer
} // namespace fastertransformer

0 comments on commit 0db54f9

Please sign in to comment.