From d262b0f39137e2e94f42978413bf5c2ee31f6958 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 5 Dec 2022 19:52:09 +0000 Subject: [PATCH] guard all2all from empty transfer --- src/cunumeric/sort/sort.cu | 44 +++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/cunumeric/sort/sort.cu b/src/cunumeric/sort/sort.cu index af931c807..c303eb1ab 100644 --- a/src/cunumeric/sort/sort.cu +++ b/src/cunumeric/sort/sort.cu @@ -1557,32 +1557,36 @@ void sample_sort_nccl_nd(SortPiece> local_sorted, // communicate all2all (in sort dimension) CHECK_NCCL(ncclGroupStart()); for (size_t r = 0; r < num_sort_ranks; r++) { - CHECK_NCCL(ncclSend(val_send_buffers[r].ptr(0), - size_send_total[r] * sizeof(VAL), - ncclInt8, - sort_ranks[r], - *comm, - stream)); - CHECK_NCCL(ncclRecv(merge_buffers[r].values.ptr(0), - merge_buffers[r].size * sizeof(VAL), - ncclInt8, - sort_ranks[r], - *comm, - stream)); + if (size_send_total[r] > 0) + CHECK_NCCL(ncclSend(val_send_buffers[r].ptr(0), + size_send_total[r] * sizeof(VAL), + ncclInt8, + sort_ranks[r], + *comm, + stream)); + if (merge_buffers[r].size > 0) + CHECK_NCCL(ncclRecv(merge_buffers[r].values.ptr(0), + merge_buffers[r].size * sizeof(VAL), + ncclInt8, + sort_ranks[r], + *comm, + stream)); } CHECK_NCCL(ncclGroupEnd()); if (argsort) { CHECK_NCCL(ncclGroupStart()); for (size_t r = 0; r < num_sort_ranks; r++) { - CHECK_NCCL(ncclSend( - idc_send_buffers[r].ptr(0), size_send_total[r], ncclInt64, sort_ranks[r], *comm, stream)); - CHECK_NCCL(ncclRecv(merge_buffers[r].indices.ptr(0), - merge_buffers[r].size, - ncclInt64, - sort_ranks[r], - *comm, - stream)); + if (size_send_total[r] > 0) + CHECK_NCCL(ncclSend( + idc_send_buffers[r].ptr(0), size_send_total[r], ncclInt64, sort_ranks[r], *comm, stream)); + if (merge_buffers[r].size > 0) + CHECK_NCCL(ncclRecv(merge_buffers[r].indices.ptr(0), + merge_buffers[r].size, + ncclInt64, + sort_ranks[r], + *comm, + stream)); } CHECK_NCCL(ncclGroupEnd()); }