From 51af5572bf8ebf197bac7de8cd6bc7d847339575 Mon Sep 17 00:00:00 2001 From: John Bachan Date: Fri, 19 Aug 2022 15:15:10 -0500 Subject: [PATCH] Resync with NCCL 2.13 * Added "verifiable", a suite of kernels for generating and verifying reduction input and output arrays in a bit-precise way. * Data corruption errors now reported in number of wrong elements instead of max deviation. * Use ncclGetLastError. * Don't run hypercube on non-powers of 2 ranks. * Fix to hypercube data verification. * Use "thread local" as the defaut CUDA capture mode. * Replaced pthread_yield -> sched_yield() * Bugfix to the cpu-side barrier/allreduce implementations. --- src/Makefile | 8 +- src/all_gather.cu | 16 +- src/all_reduce.cu | 14 +- src/alltoall.cu | 17 +- src/broadcast.cu | 16 +- src/common.cu | 515 +++++--------- src/common.h | 30 +- src/gather.cu | 18 +- src/hypercube.cu | 27 +- src/reduce.cu | 14 +- src/reduce_scatter.cu | 14 +- src/scatter.cu | 16 +- src/sendrecv.cu | 16 +- verifiable/Makefile | 24 + verifiable/inexact_regress.cu | 177 +++++ verifiable/verifiable.cu | 1227 +++++++++++++++++++++++++++++++++ verifiable/verifiable.h | 59 ++ verifiable/verifiable.mk | 11 + 18 files changed, 1705 insertions(+), 514 deletions(-) create mode 100644 verifiable/Makefile create mode 100644 verifiable/inexact_regress.cu create mode 100644 verifiable/verifiable.cu create mode 100644 verifiable/verifiable.h create mode 100644 verifiable/verifiable.mk diff --git a/src/Makefile b/src/Makefile index 2a399db..137b9d7 100644 --- a/src/Makefile +++ b/src/Makefile @@ -83,12 +83,16 @@ build: ${BIN_FILES} clean: rm -rf ${DST_DIR} -${DST_DIR}/%.o: %.cu common.h +TEST_VERIFIABLE_SRCDIR := ../verifiable +TEST_VERIFIABLE_BUILDDIR := $(BUILDDIR)/verifiable +include ../verifiable/verifiable.mk + +${DST_DIR}/%.o: %.cu common.h $(TEST_VERIFIABLE_HDRS) @printf "Compiling %-35s > %s\n" $< $@ @mkdir -p ${DST_DIR} $(NVCC) -o $@ $(NVCUFLAGS) -c $< -${DST_DIR}/%_perf:${DST_DIR}/%.o ${DST_DIR}/common.o +${DST_DIR}/%_perf:${DST_DIR}/%.o ${DST_DIR}/common.o $(TEST_VERIFIABLE_OBJS) @printf "Linking %-35s > %s\n" $< $@ @mkdir -p ${DST_DIR} $(NVCC) -o $@ $(NVCUFLAGS) $^ ${NVLDFLAGS} diff --git a/src/all_gather.cu b/src/all_gather.cu index 0b9e0cc..1eaafdd 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s out-of-place in-place \n", "", "", ""); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s", size, count, typeName); -} - void AllGatherGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = count/nranks; *recvcount = (count/nranks)*nranks; @@ -38,9 +26,9 @@ testResult_t AllGatherInitData(struct threadArgs* args, ncclDataType_t type, ncc int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? 
((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, ncclSum, 33*rep + rank, 1, 0)); for (int j=0; jexpected[i])+args->sendBytes*j, sendcount, type, rep, j)); + TESTCHECK(InitData((char*)args->expected[i] + args->sendBytes*j, sendcount, 0, type, ncclSum, 33*rep + j, 1, 0)); } CUDACHECK(cudaDeviceSynchronize()); } diff --git a/src/all_reduce.cu b/src/all_reduce.cu index 9b6b7f0..9c65f25 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "redop", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6s", size, count, typeName, opName); -} - void AllReduceGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = count; *recvcount = count; @@ -38,7 +26,7 @@ testResult_t AllReduceInitData(struct threadArgs* args, ncclDataType_t type, ncc int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, op, rep, nranks, rank)); TESTCHECK(InitDataReduce(args->expected[i], recvcount, 0, type, op, rep, nranks)); CUDACHECK(cudaDeviceSynchronize()); } diff --git a/src/alltoall.cu b/src/alltoall.cu index 8650997..0eae1b0 100644 --- a/src/alltoall.cu +++ b/src/alltoall.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "redop", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6s", size, count, typeName, opName); -} - void AlltoAllGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = (count/nranks)*nranks; *recvcount = (count/nranks)*nranks; @@ -39,9 +27,10 @@ testResult_t AlltoAllInitData(struct threadArgs* args, ncclDataType_t type, nccl int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? 
args->recvbuffs[i] : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, ncclSum, 33*rep + rank, 1, 0)); for (int j=0; jexpected[i])+args->sendBytes/nranks*j, sendcount/nranks, type, rep+rank*sendcount/nranks, j)); + size_t partcount = sendcount/nranks; + TESTCHECK(InitData((char*)args->expected[i] + j*partcount*wordSize(type), partcount, rank*partcount, type, ncclSum, 33*rep + j, 1, 0)); } CUDACHECK(cudaDeviceSynchronize()); } diff --git a/src/broadcast.cu b/src/broadcast.cu index e2b4421..40dcb5d 100644 --- a/src/broadcast.cu +++ b/src/broadcast.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "root", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6i", size, count, typeName, root); -} - void BroadcastGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = count; *recvcount = count; @@ -37,8 +25,8 @@ testResult_t BroadcastInitData(struct threadArgs* args, ncclDataType_t type, ncc int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i]; - if (rank == root) TESTCHECK(InitData(data, sendcount, type, rep, rank)); - TESTCHECK(InitData(args->expected[i], recvcount, type, rep, root)); + if (rank == root) TESTCHECK(InitData(data, sendcount, 0, type, ncclSum, rep, 1, 0)); + TESTCHECK(InitData(args->expected[i], recvcount, 0, type, ncclSum, rep, 1, 0)); CUDACHECK(cudaDeviceSynchronize()); } return testSuccess; diff --git a/src/common.cu b/src/common.cu index 05f814d..eaa3318 100644 --- a/src/common.cu +++ b/src/common.cu @@ -7,10 +7,13 @@ #include "common.h" #include #include +#include #include #include #include "cuda.h" +#include "../verifiable/verifiable.h" + int test_ncclVersion = 0; // init'd with ncclGetVersion() #if NCCL_MAJOR >= 2 @@ -107,362 +110,154 @@ static double parsesize(const char *value) { return size * units; } -double DeltaMaxValue(ncclDataType_t type) { - switch(type) { - case ncclHalf: return 1e-2; -#if defined(__CUDA_BF16_TYPES_EXIST__) - case ncclBfloat16: return 1e-2; -#endif - case ncclFloat: return 1e-5; - case ncclDouble: return 1e-12; - case ncclInt: -#if NCCL_MAJOR >= 2 - case ncclUint8: - //case ncclInt32: - case ncclUint32: -#endif - case ncclInt64: - case ncclUint64: return 1e-200; - } - return 1e-200; -} - -template __device__ -double absDiff(T a, T b) { - return fabs((double)(b - a)); -} - -template<> __device__ -double absDiff(half a, half b) { - float x = __half2float(a); - float y = __half2float(b); - return fabs((double)(y-x)); -} - -template __device__ -float toFloat(T a) { - return (float)a; -} -template<> __device__ -float toFloat(half a) { - return __half2float(a); -} -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> __device__ -float toFloat(__nv_bfloat16 a) { - return __bfloat162float(a); -} -#endif - 
-template __global__ -void deltaKern(void* A_, void* B_, size_t count, double* max) { - const T* A = (const T*)A_; - const T* B = (const T*)B_; - __shared__ double temp[BSIZE]; - int tid = blockIdx.x*blockDim.x + threadIdx.x; - double locmax = 0.0; - for(size_t i=tid; i locmax ) { - locmax = delta; -#ifdef DEBUG_PRINT - if (delta > .1) printf("Error at %ld/%ld(%p) : %f != %f\n", i, count, B+i, toFloat(A[i]), toFloat(B[i])); -#endif - } - } - - tid = threadIdx.x; - temp[tid] = locmax; - for(int stride = BSIZE/2; stride > 1; stride>>=1) { - __syncthreads(); - if( tid < stride ) - temp[tid] = temp[tid] > temp[tid+stride] ? temp[tid] : temp[tid+stride]; - } - __syncthreads(); - if( threadIdx.x == 0) - max[blockIdx.x] = temp[0] > temp[1] ? temp[0] : temp[1]; -} - -testResult_t CheckDelta(void* results, void* expected, size_t count, ncclDataType_t type, double* devmax) { - switch (type) { -#if defined(__CUDA_BF16_TYPES_EXIST__) - case ncclBfloat16: - deltaKern<__nv_bfloat16, 512><<>>(results, expected, count, devmax); break; -#endif - case ncclHalf: - deltaKern<<>>(results, expected, count, devmax); break; - case ncclFloat: - deltaKern<<>>(results, expected, count, devmax); break; - case ncclDouble: - deltaKern<<>>(results, expected, count, devmax); break; - - case ncclChar: -#if NCCL_MAJOR >= 2 - case ncclUint8: -#endif - deltaKern<<>>(results, expected, count, devmax); break; - case ncclInt: -#if NCCL_MAJOR >= 2 - case ncclUint32: -#endif - deltaKern<<>>(results, expected, count, devmax); break; - case ncclInt64: - case ncclUint64: - deltaKern<<>>(results, expected, count, devmax); break; - } +testResult_t CheckDelta(void* results, void* expected, size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks, int64_t *wrongEltN) { + ncclVerifiableVerify(results, expected, count, (int)type, (int)op, nranks, seed, offset, wrongEltN, cudaStreamDefault); CUDACHECK(cudaDeviceSynchronize()); - for (int i=1; i -__device__ T testValue(const size_t offset, const int rep, const int rank) { - uint8_t v = (rep+rank+offset) % 256; - return (T)v; +testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks) { + ncclVerifiablePrepareExpected(data, count, (int)type, (int)op, nranks, seed, offset, cudaStreamDefault); + return testSuccess; } -// For floating point datatype, we use values between 0 and 1 otherwise the -// Product operation will produce NaNs. 
-template<> -__device__ double testValue(const size_t offset, const int rep, const int rank) { - return 1.0/(1.0+(double)testValue(offset, rep, rank)); -} -template<> -__device__ float testValue(const size_t offset, const int rep, const int rank) { - return 1.0/(1.0+(float)testValue(offset, rep, rank)); -} -template<> -__device__ half testValue(const size_t offset, const int rep, const int rank) { - return __float2half(testValue(offset, rep, rank)); -} -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -__device__ __nv_bfloat16 testValue<__nv_bfloat16>(const size_t offset, const int rep, const int rank) { - return __float2bfloat16(testValue(offset, rep, rank)); +testResult_t InitData(void* data, const size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, uint64_t seed, int nranks, int rank) { + ncclVerifiablePrepareInput(data, count, (int)type, (int)op, nranks, rank, seed, offset, cudaStreamDefault); + return testSuccess; } -#endif -// Operations -template -__device__ T ncclOpSum(T a, T b) { return a+b; } -template -__device__ T ncclOpProd(T a, T b) { return a*b; } -template -__device__ T ncclOpMax(T a, T b) { return a>b ? a : b; } -template -__device__ T ncclOpMin(T a, T b) { return a -__device__ half ncclOpSum(half a, half b) { return __float2half(__half2float(a)+__half2float(b)); } -template<> -__device__ half ncclOpProd(half a, half b) { return __float2half(__half2float(a)*__half2float(b)); } -template<> -__device__ half ncclOpMax(half a, half b) { return __half2float(a)>__half2float(b) ? a : b; } -template<> -__device__ half ncclOpMin(half a, half b) { return __half2float(a)<__half2float(b) ? a : b; } - -template -__device__ T ncclPPOpIdent(T x, int arg) { return x; } -template -__device__ T ncclPPOpMul(T x, int arg) { return x*T(arg); } -template -__device__ T ncclPPOpDiv(T x, int arg) { return x/T(arg); } -template<> -__device__ half ncclPPOpMul(half x, int arg) { - return __float2half(__half2float(x)*float(arg)); -} -template<> -__device__ half ncclPPOpDiv(half x, int n) { - return __float2half(__half2float(x)/n); -} -#if defined(__CUDA_BF16_TYPES_EXIST__) -template<> -__device__ __nv_bfloat16 ncclPPOpMul(__nv_bfloat16 x, int arg) { - return __float2bfloat16(__bfloat162float(x)*float(arg)); -} -template<> -__device__ __nv_bfloat16 ncclPPOpDiv(__nv_bfloat16 x, int n) { - return __float2bfloat16(__bfloat162float(x)/n); -} -#endif +void Barrier(struct threadArgs *args) { + thread_local int epoch = 0; + static pthread_mutex_t lock[2] = {PTHREAD_MUTEX_INITIALIZER, PTHREAD_MUTEX_INITIALIZER}; + static pthread_cond_t cond[2] = {PTHREAD_COND_INITIALIZER, PTHREAD_COND_INITIALIZER}; + static int counter[2] = {0, 0}; -__host__ __device__ int preMulScalar(int rank) { - return 1 + rank%2; -} + pthread_mutex_lock(&lock[epoch]); + if(++counter[epoch] == args->nThreads) + pthread_cond_broadcast(&cond[epoch]); -template -__global__ void InitDataReduceKernel(T* data, const size_t N, const size_t offset, const int rep, const int nranks) { - for (size_t o=blockIdx.x*blockDim.x+threadIdx.x; o(o+offset, rep, 0); - val = PreOp(val, preMulScalar(0)); - for (int i=1; i(o+offset, rep, i); - val1 = PreOp(val1, preMulScalar(i)); - val = Op(val, val1); - } - data[o] = PostOp(val, nranks); + if(args->thread+1 == args->nThreads) { + while(counter[epoch] != args->nThreads) + pthread_cond_wait(&cond[epoch], &lock[epoch]); + #ifdef MPI_SUPPORT + MPI_Barrier(MPI_COMM_WORLD); + #endif + counter[epoch] = 0; + pthread_cond_broadcast(&cond[epoch]); } + else { + while(counter[epoch] != 0) + 
pthread_cond_wait(&cond[epoch], &lock[epoch]); + } + pthread_mutex_unlock(&lock[epoch]); + epoch ^= 1; } -#define KERN(type, op, preop, postop) (void*)InitDataReduceKernel, preop, postop > -#if NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) - #define OPS(type) \ - KERN(type, ncclOpSum, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpProd, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMax, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMin, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpSum/*Avg*/, ncclPPOpIdent, ncclPPOpDiv), \ - KERN(type, ncclOpSum/*PreMulSum*/, ncclPPOpMul, ncclPPOpIdent) -#elif NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) - #define OPS(type) \ - KERN(type, ncclOpSum, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpProd, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMax, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMin, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpSum/*Avg*/, ncclPPOpIdent, ncclPPOpDiv) -#else - #define OPS(type) \ - KERN(type, ncclOpSum, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpProd, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMax, ncclPPOpIdent, ncclPPOpIdent), \ - KERN(type, ncclOpMin, ncclPPOpIdent, ncclPPOpIdent) -#endif - -static void* const redInitDataKerns[test_opNumMax*ncclNumTypes] = { - OPS(int8_t), OPS(uint8_t), OPS(int32_t), OPS(uint32_t), OPS(int64_t), OPS(uint64_t), OPS(half), OPS(float), OPS(double), -#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) - OPS(__nv_bfloat16) -#endif -}; - -testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, const int rep, const int nranks) { - dim3 grid = { 32, 1, 1 }; - dim3 block = { 256, 1, 1 }; - void* args[5] = { (void*)&data, (void*)&count, (void*)&offset, (void*)&rep, (void*)&nranks }; - CUDACHECK(cudaLaunchKernel(redInitDataKerns[type*test_opNumMax+op], grid, block, args, 0, cudaStreamDefault)); - return testSuccess; -} - -template -__global__ void InitDataKernel(T* data, const size_t N, const int rep, const int rank) { - for (size_t o=blockIdx.x*blockDim.x+threadIdx.x; o(o, rep, rank); -} - -static void* const initDataKerns[ncclNumTypes] = { - (void*)InitDataKernel< int8_t>, - (void*)InitDataKernel< uint8_t>, - (void*)InitDataKernel< int32_t>, - (void*)InitDataKernel, - (void*)InitDataKernel< int64_t>, - (void*)InitDataKernel, - (void*)InitDataKernel< half>, - (void*)InitDataKernel< float>, - (void*)InitDataKernel< double>, -#if defined(__CUDA_BF16_TYPES_EXIST__) && NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) - (void*)InitDataKernel<__nv_bfloat16> -#endif -}; - +// Inter-thread/process barrier+allreduce. The quality of the return value +// for average=0 (which means broadcast from rank=0) is dubious. The returned +// value will actually be the result of process-local broadcast from the local thread=0. 
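// ----------------------------------------------------------------------------
// Editor's sketch (not part of the diff): a self-contained model of the
// two-epoch ("ping-pong") pthread barrier that Barrier() above implements and
// that the templated Allreduce() below reuses. It is simplified: the MPI step
// is dropped and the last *arrival* performs the reset, whereas the test code
// reserves that role for thread nThreads-1. N_THREADS, lk, cv, cnt and
// barrier_sketch() are hypothetical names standing in for args->nThreads and
// the statics inside Barrier()/Allreduce().
#include <pthread.h>

static const int N_THREADS = 4;  // assumption: fixed number of worker threads
static pthread_mutex_t lk[2]  = {PTHREAD_MUTEX_INITIALIZER, PTHREAD_MUTEX_INITIALIZER};
static pthread_cond_t  cv[2]  = {PTHREAD_COND_INITIALIZER, PTHREAD_COND_INITIALIZER};
static int             cnt[2] = {0, 0};

void barrier_sketch() {
  thread_local int epoch = 0;            // each thread alternates 0,1,0,1,...
  pthread_mutex_lock(&lk[epoch]);
  if (++cnt[epoch] == N_THREADS) {
    cnt[epoch] = 0;                      // last arrival resets the epoch
    pthread_cond_broadcast(&cv[epoch]);  // ...and releases everyone else
  } else {
    while (cnt[epoch] != 0)              // earlier arrivals wait for that reset
      pthread_cond_wait(&cv[epoch], &lk[epoch]);
  }
  pthread_mutex_unlock(&lk[epoch]);
  epoch ^= 1;  // duplicating cnt/cv per epoch lets a thread re-enter the
               // barrier immediately without racing stragglers from the
               // previous round, which is why the code above keeps arrays of 2.
}
// ----------------------------------------------------------------------------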
template -testResult_t InitDataType(void* dest, const size_t N, const int rep, const int rank) { - T* ptr = (T*)dest; - InitDataKernel<<<16, 512>>>(ptr, N, rep, rank); - return testSuccess; -} - -testResult_t InitData(void* data, const size_t count, ncclDataType_t type, const int rep, const int rank) { - dim3 grid = { 32, 1, 1 }; - dim3 block = { 256, 1, 1 }; - void* args[4] = { (void*)&data, (void*)&count, (void*)&rep, (void*)&rank }; - CUDACHECK(cudaLaunchKernel(initDataKerns[type], grid, block, args, 0, cudaStreamDefault)); - return testSuccess; -} - -void Barrier(struct threadArgs* args) { - while (args->barrier[args->barrier_idx] != args->thread) pthread_yield(); - args->barrier[args->barrier_idx] = args->thread + 1; - if (args->thread+1 == args->nThreads) { -#ifdef MPI_SUPPORT - MPI_Barrier(MPI_COMM_WORLD); -#endif - args->barrier[args->barrier_idx] = 0; +void Allreduce(struct threadArgs* args, T* value, int average) { + thread_local int epoch = 0; + static pthread_mutex_t lock[2] = {PTHREAD_MUTEX_INITIALIZER, PTHREAD_MUTEX_INITIALIZER}; + static pthread_cond_t cond[2] = {PTHREAD_COND_INITIALIZER, PTHREAD_COND_INITIALIZER}; + static T accumulator[2]; + static int counter[2] = {0, 0}; + + pthread_mutex_lock(&lock[epoch]); + if(counter[epoch] == 0) { + if(average != 0 || args->thread == 0) accumulator[epoch] = *value; } else { - while (args->barrier[args->barrier_idx]) pthread_yield(); + switch(average) { + case /*r0*/ 0: if(args->thread == 0) accumulator[epoch] = *value; break; + case /*avg*/1: accumulator[epoch] += *value; break; + case /*min*/2: accumulator[epoch] = std::min(accumulator[epoch], *value); break; + case /*max*/3: accumulator[epoch] = std::max(accumulator[epoch], *value); break; + case /*sum*/4: accumulator[epoch] += *value; break; + } } - args->barrier_idx=!args->barrier_idx; -} -// Inter-thread/process barrier+allreduce -void Allreduce(struct threadArgs* args, double* value, int average) { - while (args->barrier[args->barrier_idx] != args->thread) pthread_yield(); - double val = *value; - if (args->thread > 0) { - double val2 = args->reduce[args->barrier_idx]; - if (average == 1) val += val2; - if (average == 2) val = std::min(val, val2); - if (average == 3) val = std::max(val, val2); - } - if (average || args->thread == 0) args->reduce[args->barrier_idx] = val; - args->barrier[args->barrier_idx] = args->thread + 1; - if (args->thread+1 == args->nThreads) { -#ifdef MPI_SUPPORT - if (average != 0) { - MPI_Op op = average == 1 ? MPI_SUM : average == 2 ? MPI_MIN : MPI_MAX; - MPI_Allreduce(MPI_IN_PLACE, (void*)&args->reduce[args->barrier_idx], 1, MPI_DOUBLE, op, MPI_COMM_WORLD); + if(++counter[epoch] == args->nThreads) + pthread_cond_broadcast(&cond[epoch]); + + if(args->thread+1 == args->nThreads) { + while(counter[epoch] != args->nThreads) + pthread_cond_wait(&cond[epoch], &lock[epoch]); + + #ifdef MPI_SUPPORT + if(average != 0) { + static_assert(std::is_same::value || std::is_same::value, "Allreduce only for T in {long long, double}"); + MPI_Datatype ty = std::is_same::value ? MPI_LONG_LONG : + std::is_same::value ? MPI_DOUBLE : + MPI_Datatype(); + MPI_Op op = average == 1 ? MPI_SUM : + average == 2 ? MPI_MIN : + average == 3 ? MPI_MAX : + average == 4 ? 
MPI_SUM : MPI_Op(); + MPI_Allreduce(MPI_IN_PLACE, (void*)&accumulator[epoch], 1, ty, op, MPI_COMM_WORLD); } -#endif - if (average == 1) args->reduce[args->barrier_idx] /= args->nProcs*args->nThreads; - args->reduce[1-args->barrier_idx] = 0; - args->barrier[args->barrier_idx] = 0; - } else { - while (args->barrier[args->barrier_idx]) pthread_yield(); + #endif + + if(average == 1) accumulator[epoch] /= args->nProcs*args->nThreads; + counter[epoch] = 0; + pthread_cond_broadcast(&cond[epoch]); } - *value = args->reduce[args->barrier_idx]; - args->barrier_idx=!args->barrier_idx; + else { + while(counter[epoch] != 0) + pthread_cond_wait(&cond[epoch], &lock[epoch]); + } + pthread_mutex_unlock(&lock[epoch]); + + *value = accumulator[epoch]; + epoch ^= 1; } -testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, double *delta) { +testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int64_t *wrongElts) { + int nranks = args->nProcs*args->nGpus*args->nThreads; size_t count = args->expectedBytes/wordSize(type); - double maxDelta = 0.0; + + int64_t *wrongPerGpu = nullptr; + CUDACHECK(cudaHostAlloc((void**)&wrongPerGpu, args->nGpus*sizeof(int64_t), cudaHostAllocMapped)); + for (int i=0; inGpus; i++) { int device; int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); NCCLCHECK(ncclCommCuDevice(args->comms[i], &device)); CUDACHECK(cudaSetDevice(device)); void *data = in_place ? ((void *)((uintptr_t)args->recvbuffs[i] + args->recvInplaceOffset*rank)) : args->recvbuffs[i]; - TESTCHECK(CheckDelta(data , args->expected[i], count, type, args->deltaHost)); - maxDelta = std::max(*(args->deltaHost), maxDelta); - -#ifdef DEBUG_PRINT - if (rank == 0) { - int *expectedHost = (int *)malloc(args->expectedBytes); - int *dataHost = (int *)malloc(args->expectedBytes); - - cudaMemcpy(expectedHost, args->expected[0], args->expectedBytes, cudaMemcpyDeviceToHost); - printf("\n Expected: "); - for(int j=0; jexpectedBytes/sizeof(int); j++) { - printf("%d:%d ", j, expectedHost[j]); - } - printf("\n"); - cudaMemcpy(dataHost, data, args->expectedBytes, cudaMemcpyDeviceToHost); - printf("\n Actual: "); - for (int j=0; jexpectedBytes/sizeof(int); j++) { - printf("%d:%d ", j, dataHost[j]); - } - printf("\n"); - free(expectedHost); - free(dataHost); + TESTCHECK(CheckDelta(data, args->expected[i], count, 0, type, op, 0, nranks, wrongPerGpu+i)); + +#if 1 && DEBUG_PRINT + if (args->reportErrors && wrongPerGpu[i] != 0) { + printf("rank=%d #wrong=%d\n", rank, (int)wrongPerGpu[i]); + char *expectedHost = (char*)malloc(args->expectedBytes); + char *dataHost = (char*)malloc(args->expectedBytes); + int eltsz = wordSize(type); + cudaMemcpy(expectedHost, args->expected[i], args->expectedBytes, cudaMemcpyDeviceToHost); + cudaMemcpy(dataHost, data, args->expectedBytes, cudaMemcpyDeviceToHost); + + for(int j=0; jexpectedBytes/eltsz; j++) { + unsigned long long want, got; + want = 0; + memcpy(&want, expectedHost + j*eltsz, eltsz); + got = 0; + memcpy(&got, dataHost + j*eltsz, eltsz); + if(want != got) { + printf(" rank=%d elt[%d]: want=0x%llx got=0x%llx\n", rank, j, want, got); + } + } + free(expectedHost); + free(dataHost); } #endif } - double nranks = args->nProcs*args->nThreads*args->nGpus; - if (args->reportErrors && maxDelta > DeltaMaxValue(type)*(nranks - 1)) args->errors[0]++; - *delta = maxDelta; + + *wrongElts = 0; + for (int i=0; i < args->nGpus; i++) *wrongElts += wrongPerGpu[i]; + cudaFree(wrongPerGpu); + + if 
(args->reportErrors && *wrongElts) args->errors[0]++; return testSuccess; } @@ -503,7 +298,7 @@ testResult_t testStreamSynchronize(int ngpus, cudaStream_t* streams, ncclComm_t* } // We might want to let other threads (including NCCL threads) use the CPU. - if (idle) pthread_yield(); + if (idle) sched_yield(); } free(done); return testSuccess; @@ -541,19 +336,18 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t __nv_bfloat16 bf16; #endif }; - int scalar = preMulScalar(rank); switch(type) { - case ncclInt8: i8 = int8_t(scalar); break; - case ncclUint8: u8 = uint8_t(scalar); break; - case ncclInt32: i32 = int32_t(scalar); break; - case ncclUint32: u32 = uint32_t(scalar); break; - case ncclInt64: i64 = int32_t(scalar); break; - case ncclUint64: u64 = uint32_t(scalar); break; - case ncclFloat16: f16 = __float2half(float(scalar)); break; - case ncclFloat32: f32 = float(scalar); break; - case ncclFloat64: f64 = double(scalar); break; + case ncclInt8: i8 = ncclVerifiablePremulScalar<int8_t>(rank); break; + case ncclUint8: u8 = ncclVerifiablePremulScalar<uint8_t>(rank); break; + case ncclInt32: i32 = ncclVerifiablePremulScalar<int32_t>(rank); break; + case ncclUint32: u32 = ncclVerifiablePremulScalar<uint32_t>(rank); break; + case ncclInt64: i64 = ncclVerifiablePremulScalar<int64_t>(rank); break; + case ncclUint64: u64 = ncclVerifiablePremulScalar<uint64_t>(rank); break; + case ncclFloat16: f16 = ncclVerifiablePremulScalar<half>(rank); break; + case ncclFloat32: f32 = ncclVerifiablePremulScalar<float>(rank); break; + case ncclFloat64: f64 = ncclVerifiablePremulScalar<double>(rank); break; #if defined(__CUDA_BF16_TYPES_EXIST__) - case ncclBfloat16: bf16 = __float2bfloat16(float(scalar)); break; + case ncclBfloat16: bf16 = ncclVerifiablePremulScalar<__nv_bfloat16>(rank); break; #endif } NCCLCHECK(ncclRedOpCreatePreMulSum(&op, &u64, type, ncclScalarHostImmediate, args->comms[i])); @@ -607,9 +401,10 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t if (cudaGraphLaunches >= 1) { // Begin cuda graph capture for (int i=0; i<args->nGpus; i++) { - // Thread local mode is needed for: - // - Multi-thread mode - // - P2P pre-connect + // Thread local mode is needed for: + // - Multi-thread mode: where graph capture and instantiation can happen concurrently across threads + // - P2P pre-connect: when there is no warm-up, P2P pre-connect is done during graph capture. 
+ // Since pre-connect calls cudaMalloc, we cannot use global capture mode CUDACHECK(cudaStreamBeginCapture(args->streams[i], cudaStreamCaptureModeThreadLocal)); } } @@ -669,7 +464,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t Barrier(args); - double maxDelta = 0; + int64_t wrongElts = 0; static __thread int rep = 0; rep++; if (datacheck) { @@ -717,10 +512,12 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t } #endif - TESTCHECK(CheckData(args, type, op, root, in_place, &maxDelta)); + TESTCHECK(CheckData(args, type, op, root, in_place, &wrongElts)); //aggregate delta from all threads and procs - Allreduce(args, &maxDelta, 3); + long long wrongElts1 = wrongElts; + Allreduce(args, &wrongElts1, /*sum*/4); + wrongElts = wrongElts1; } double timeUsec = deltaSec*1.0E6; @@ -733,9 +530,9 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t sprintf(timeStr, "%7.2f", timeUsec); } if (datacheck) { - PRINT(" %7s %6.2f %6.2f %5.0le", timeStr, algBw, busBw, maxDelta); + PRINT(" %7s %6.2f %6.2f %5g", timeStr, algBw, busBw, (double)wrongElts); } else { - PRINT(" %7s %6.2f %6.2f %5s", timeStr, algBw, busBw, "N/A"); + PRINT(" %7s %6.2f %6.2f %5s", timeStr, algBw, busBw, "N/A"); } args->bw[0] += busBw; @@ -775,7 +572,9 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* // Benchmark for (size_t size = args->minbytes; size<=args->maxbytes; size = ((args->stepfactor > 1) ? size*args->stepfactor : size+args->stepbytes)) { setupArgs(size, type, args); - print_line_header(max(args->sendBytes, args->expectedBytes), args->nbytes / wordSize(type), typeName, opName, root); + char rootName[100]; + sprintf(rootName, "%6i", root); + PRINT("%12li %12li %8s %6s %6s", max(args->sendBytes, args->expectedBytes), args->nbytes / wordSize(type), typeName, opName, rootName); TESTCHECK(BenchTime(args, type, op, root, 0)); TESTCHECK(BenchTime(args, type, op, root, 1)); PRINT("\n"); @@ -828,7 +627,7 @@ testResult_t threadLaunch(struct testThread* thread) { return testSuccess; } -testResult_t AllocateBuffs(void **sendbuff, size_t sendBytes, void **recvbuff, size_t recvBytes, void **expected, size_t nbytes, int nranks) { +testResult_t AllocateBuffs(void **sendbuff, size_t sendBytes, void **recvbuff, size_t recvBytes, void **expected, size_t nbytes) { CUDACHECK(cudaMalloc(sendbuff, nbytes)); CUDACHECK(cudaMalloc(recvbuff, nbytes)); if (datacheck) CUDACHECK(cudaMalloc(expected, recvBytes)); @@ -1027,8 +826,10 @@ testResult_t run() { #endif is_main_thread = (proc == 0) ? 
1 : 0; - PRINT("# nThread %d nGpus %d minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d validation: %d \n", nThreads, nGpus, minBytes, maxBytes, - (stepFactor > 1)?stepFactor:stepBytes, (stepFactor > 1)?"factor":"bytes", warmup_iters, iters, datacheck); + PRINT("# nThread %d nGpus %d minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d agg iters: %d validation: %d graph: %d\n", + nThreads, nGpus, minBytes, maxBytes, + (stepFactor > 1)?stepFactor:stepBytes, (stepFactor > 1)?"factor":"bytes", + warmup_iters, iters, agg_iters, datacheck, cudaGraphLaunches); if (blocking_coll) PRINT("# Blocking Enabled: wait for completion and barrier after each collective \n"); if (parallel_init) PRINT("# Parallel Init Enabled: threads call into NcclInitRank concurrently \n"); PRINT("#\n"); @@ -1087,7 +888,7 @@ testResult_t run() { for (int i=0; i= NCCL_VERSION(2,12,10) +#define NCCLCHECK(cmd) do { \ + ncclResult_t res = cmd; \ + if (res != ncclSuccess) { \ + char hostname[1024]; \ + getHostName(hostname, 1024); \ + printf("%s: Test NCCL failure %s:%d " \ + "'%s / %s'\n", \ + hostname,__FILE__,__LINE__, \ + ncclGetErrorString(res), \ + ncclGetLastError(NULL)); \ + return testNcclError; \ + } \ +} while(0) +#else #define NCCLCHECK(cmd) do { \ ncclResult_t res = cmd; \ if (res != ncclSuccess) { \ @@ -39,6 +54,7 @@ return testNcclError; \ } \ } while(0) +#endif typedef enum { testSuccess = 0, @@ -111,14 +127,6 @@ struct threadArgs { void** expected; size_t expectedBytes; - volatile int* sync; - int sync_idx; - volatile int* barrier; - int barrier_idx; - volatile double* reduce; - int syncRank; - int syncNranks; - double* deltaHost; int* errors; double* bw; int* bw_count; @@ -141,8 +149,8 @@ struct testThread { // Provided by common.cu extern void Barrier(struct threadArgs* args); extern testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName, int root); -extern testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, const int rep, const int nranks); -extern testResult_t InitData(void* data, const size_t count, ncclDataType_t type, const int rep, const int rank); +extern testResult_t InitDataReduce(void* data, const size_t count, const size_t offset, ncclDataType_t type, ncclRedOp_t op, const uint64_t seed, const int nranks); +extern testResult_t InitData(void* data, const size_t count, size_t offset, ncclDataType_t type, ncclRedOp_t op, const uint64_t seed, const int nranks, const int rank); extern void AllocateBuffs(void **sendbuff, void **recvbuff, void **expected, void **expectedHost, size_t nbytes, int nranks); // Provided by each coll @@ -228,7 +236,7 @@ static size_t wordSize(ncclDataType_t type) { case ncclInt64: case ncclUint64: case ncclDouble: - //case ncclFloat64: + //case ncclFloat64: return 8; default: return 0; } diff --git a/src/gather.cu b/src/gather.cu index d0cfa5d..9908852 100644 --- a/src/gather.cu +++ b/src/gather.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "root", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header 
(size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6i", size, count, typeName, root); -} - void GatherGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = count/nranks; *recvcount = (count/nranks)*nranks; @@ -38,12 +26,10 @@ testResult_t GatherInitData(struct threadArgs* args, ncclDataType_t type, ncclRe int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? ((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, rank*sendcount, type, ncclSum, rep, 1, 0)); CUDACHECK(cudaMemcpy(args->expected[i], args->recvbuffs[i], args->expectedBytes, cudaMemcpyDefault)); if (rank == root) { - for (int j=0; jexpected[i])+args->sendBytes*j, sendcount, type, rep, j)); - } + TESTCHECK(InitData(args->expected[i], nranks*sendcount, 0, type, ncclSum, rep, 1, 0)); } CUDACHECK(cudaDeviceSynchronize()); } diff --git a/src/hypercube.cu b/src/hypercube.cu index 142f1a6..ae9fbd0 100644 --- a/src/hypercube.cu +++ b/src/hypercube.cu @@ -9,18 +9,6 @@ #define ALIGN 4 -void print_header() { - PRINT("# %10s %12s %8s out-of-place in-place \n", "", "", ""); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s", size, count, typeName); -} - void HyperCubeGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { size_t base = (count/(ALIGN*nranks))*ALIGN; *sendcount = base; @@ -41,9 +29,9 @@ testResult_t HyperCubeInitData(struct threadArgs* args, ncclDataType_t type, ncc int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? ((char*)args->recvbuffs[i])+rank*args->sendBytes : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, ncclSum, 33*rep + rank, 1, 0)); for (int j=0; jexpected[i])+args->sendBytes*j, sendcount, type, rep, j)); + TESTCHECK(InitData((char*)args->expected[i] + args->sendBytes*j, sendcount, 0, type, ncclSum, 33*rep + j, 1, 0)); } CUDACHECK(cudaDeviceSynchronize()); } @@ -110,9 +98,16 @@ testResult_t HyperCubeRunTest(struct threadArgs* args, int root, ncclDataType_t run_typenames = test_typenames; } - for (int i=0; inProcs*args->nThreads*args->nGpus; + if (nRanks && !(nRanks & (nRanks - 1))) { + for (int i=0; iproc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? 
args->recvbuffs[i] : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, op, rep, nranks, rank)); CUDACHECK(cudaMemcpy(args->expected[i], args->recvbuffs[i], args->expectedBytes, cudaMemcpyDefault)); if (rank == root) TESTCHECK(InitDataReduce(args->expected[i], recvcount, 0, type, op, rep, nranks)); CUDACHECK(cudaDeviceSynchronize()); diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index b0c4fab..e4a59dc 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "redop", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6s", size, count, typeName, opName); -} - void ReduceScatterGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = (count/nranks)*nranks; *recvcount = count/nranks; @@ -38,7 +26,7 @@ testResult_t ReduceScatterInitData(struct threadArgs* args, ncclDataType_t type, int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, 0, type, op, rep, nranks, rank)); CUDACHECK(cudaMemcpy(args->expected[i], args->recvbuffs[i], args->expectedBytes, cudaMemcpyDefault)); TESTCHECK(InitDataReduce(args->expected[i], recvcount, rank*recvcount, type, op, rep, nranks)); CUDACHECK(cudaDeviceSynchronize()); diff --git a/src/scatter.cu b/src/scatter.cu index 93ab2e6..d244b2b 100644 --- a/src/scatter.cu +++ b/src/scatter.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s %6s out-of-place in-place \n", "", "", "", ""); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", "root", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s %6i", size, count, typeName, root); -} - void ScatterGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = (count/nranks)*nranks; *recvcount = count/nranks; @@ -37,8 +25,8 @@ testResult_t ScatterInitData(struct threadArgs* args, ncclDataType_t type, ncclR int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? 
args->recvbuffs[i] : args->sendbuffs[i]; - if (rank == root) TESTCHECK(InitData(data, sendcount, type, rep, rank)); - TESTCHECK(InitData(args->expected[i], recvcount, type, rep+rank*recvcount, root)); + if (rank == root) TESTCHECK(InitData(data, sendcount, 0, type, ncclSum, rep, 1, 0)); + TESTCHECK(InitData(args->expected[i], recvcount, rank*recvcount, type, ncclSum, rep, 1, 0)); CUDACHECK(cudaDeviceSynchronize()); } return testSuccess; diff --git a/src/sendrecv.cu b/src/sendrecv.cu index 8bebc48..e73a92b 100644 --- a/src/sendrecv.cu +++ b/src/sendrecv.cu @@ -7,18 +7,6 @@ #include "cuda_runtime.h" #include "common.h" -void print_header() { - PRINT("# %10s %12s %8s out-of-place in-place \n", "", "", ""); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "size", "count", "type", - "time", "algbw", "busbw", "error", "time", "algbw", "busbw", "error"); - PRINT("# %10s %12s %8s %7s %6s %6s %5s %7s %6s %6s %5s\n", "(B)", "(elements)", "", - "(us)", "(GB/s)", "(GB/s)", "", "(us)", "(GB/s)", "(GB/s)", ""); -} - -void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) { - PRINT("%12li %12li %8s", size, count, typeName); -} - void SendRecvGetCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t count, int nranks) { *sendcount = count; *recvcount = count; @@ -38,9 +26,9 @@ testResult_t SendRecvInitData(struct threadArgs* args, ncclDataType_t type, nccl int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i); CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes)); void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i]; - TESTCHECK(InitData(data, sendcount, type, rep, rank)); + TESTCHECK(InitData(data, sendcount, rank*sendcount, type, ncclSum, rep, 1, 0)); int peer = (rank-1+nranks)%nranks; - TESTCHECK(InitData(args->expected[i], recvcount, type, rep, peer)); + TESTCHECK(InitData(args->expected[i], recvcount, peer*recvcount, type, ncclSum, rep, 1, 0)); CUDACHECK(cudaDeviceSynchronize()); } // We don't support in-place sendrecv diff --git a/verifiable/Makefile b/verifiable/Makefile new file mode 100644 index 0000000..b141a2a --- /dev/null +++ b/verifiable/Makefile @@ -0,0 +1,24 @@ +include ../../makefiles/common.mk + +.PHONY: all clean + +BUILDDIR := $(abspath ../../build) +NCCLDIR := $(BUILDDIR) +NVCUFLAGS += -I$(NCCLDIR)/include/ -I../include +DST_DIR := $(BUILDDIR)/test/verifiable + +all: $(DST_DIR)/self_test $(DST_DIR)/verifiable.o + +clean: + rm -rf $(DST_DIR) + +TEST_VERIFIABLE_SRCDIR := . +TEST_VERIFIABLE_BUILDDIR := $(DST_DIR) +include verifiable.mk + +self_test: $(DST_DIR)/self_test + +$(DST_DIR)/self_test: verifiable.cu verifiable.h + @printf "Linking %s\n" $@ + @mkdir -p $(DST_DIR) + $(NVCC) -o $@ $(NVCUFLAGS) -DSELF_TEST=1 verifiable.cu $(NVLDFLAGS) diff --git a/verifiable/inexact_regress.cu b/verifiable/inexact_regress.cu new file mode 100644 index 0000000..d7bd545 --- /dev/null +++ b/verifiable/inexact_regress.cu @@ -0,0 +1,177 @@ +/* Generate parameters for our error bound model of floating point average + * (sum of scaled values) by sampling sums of random sequences for each + * floating point type. + * + * The model has parameters "coef" and "power", where for two floats a & b, + * they are close enough if and only if: + * abs(intBits(a) - intBits(b)) <= 1 + coef*pow(rank_n, power); + * + * Where intBits(x) is the reinterpretation of the float bitpattern as an integer. 
+ * + * Compile with: + * nvcc -gencode=arch=compute_80,code=sm_80 + */ + +#include +#include +#include +#include +#include +#include + +using std::uint64_t; +using std::uint32_t; +using bfloat16 = __nv_bfloat16; + +template +struct float_traits; + +template<> +struct float_traits { + static constexpr int mantissa_bits = 23; + static constexpr int exponent_bits = 8; + using uint_t = uint32_t; + __device__ static float make(double x) { return (float)x; } + __device__ static float make(uint64_t x) { return (float)x; } + __device__ static double todouble(float x) { return x; } + __device__ static float add(float a, float b) { return a+b; } + __device__ static float mul(float a, float b) { return a*b; } +}; +template<> +struct float_traits { + static constexpr int mantissa_bits = 52; + static constexpr int exponent_bits = 11; + using uint_t = uint64_t; + __device__ static double make(double x) { return x; } + __device__ static double make(uint64_t x) { return (double)x; } + __device__ static double todouble(double x) { return x; } + __device__ static double add(double a, double b) { return a+b; } + __device__ static double mul(double a, double b) { return a*b; } +}; +template<> +struct float_traits { + static constexpr int mantissa_bits = 10; + static constexpr int exponent_bits = 5; + using uint_t = uint16_t; + __device__ static half make(double x) { return __double2half(x); } + __device__ static half make(uint64_t x) { return __int2half_rn(x); } + __device__ static double todouble(half x) { return __half2float(x); } + __device__ static half add(half a, half b) { return __hadd(a, b); } + __device__ static half mul(half a, half b) { return __hmul(a, b); } +}; +template<> +struct float_traits { + static constexpr int mantissa_bits = 7; + static constexpr int exponent_bits = 8; + using uint_t = uint16_t; + __device__ static bfloat16 make(double x) { return __double2bfloat16(x); } + __device__ static bfloat16 make(uint64_t x) { return __int2bfloat16_rn(x); } + __device__ static double todouble(bfloat16 x) { return __bfloat162float(x); } + __device__ static bfloat16 add(bfloat16 a, bfloat16 b) { return __hadd(a, b); } + __device__ static bfloat16 mul(bfloat16 a, bfloat16 b) { return __hmul(a, b); } +}; + +template +__device__ int compare(F a, F b) { + union { typename float_traits::uint_t ua; F fa; }; + union { typename float_traits::uint_t ub; F fb; }; + ua=0; ub=0; + fa=a; fb=b; + //std::printf("bits(%1.10f)=%x bits(%1.10f)=%x\n", fa, ua, fb, ub); + return ua < ub ? 
ub-ua : ua-ub; +} + +struct xoshiro256ss { + uint64_t s[4]; + __device__ xoshiro256ss(int seed) { + constexpr uint64_t src[4] = {0xbb99e851d1f545cc, 0xbfc4022389ca40cb, 0xe84aff5cb1914af5, 0x845999858284de77}; + for(int i=0; i < 4; i++) + s[i] = src[i] + (seed + i)*0xb45de8a52fdb65d3; + } + __device__ uint64_t operator()() { + auto rol64 = [](uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); + }; + uint64_t const result = rol64(s[1] * 5, 7) * 9; + uint64_t const t = s[1] << 17; + s[2] ^= s[0]; + s[3] ^= s[1]; + s[1] ^= s[2]; + s[0] ^= s[3]; + s[2] ^= t; + s[3] = rol64(s[3], 45); + return result; + } +}; + +template +__global__ void kernel() { + using traits = float_traits; + constexpr int samps = 4<<10; + __shared__ F accf[samps]; + __shared__ double accd[samps]; + + xoshiro256ss rng(threadIdx.x); + float expo_avg = 1; + for(int pass=0; pass < 2; pass++) { + F scalar = traits::make(1.0/(3.14159 + .5*threadIdx.x)); + int err_max = 0; + float coef = 0; + double expo_sum = 0; + int expo_n = 0; + int max_ranks = std::is_same::value ? 16<<10 : 1<::value ? double(rng() & m) : 1.0; + F f = traits::make(d); + accf[i] = traits::add(accf[i], traits::mul(scalar, f)); + accd[i] += traits::todouble(f); + //if(threadIdx.x==0 && std::is_same::value) std::printf(" r=%d f=%f\n", r, traits::todouble(accf[i])); + int e = compare(accf[i], traits::mul(scalar, traits::make(accd[i]))); + err = err > e ? err : e; + } + err = __reduce_max_sync(-1u, err); + err_max = err_max > err ? err_max : err; + if (r >= 2) { + // err = 1 + coef*pow(r,expo) + float c = float(err-1)/powf(float(r), expo_avg); + coef = coef > c ? coef : c; + } + if (r >= 2) { + double expo = log2f(1+err_max)/log2f(r); + expo_sum += expo; + expo_n++; + //if(threadIdx.x==0 && std::is_same::value) std::printf(" r=%d err=%d errmax=%d expo=%f sum=%f n=%d\n", r, err, err_max, expo, expo_sum, expo_n); + } + } + } + if(pass==0) + expo_avg = expo_sum/expo_n; + else if(threadIdx.x == 0) + std::printf(" coef=%1.10f expo=%1.10f\n", coef, expo_avg); + } +} + +int main() { + std::printf("type=float:\n"); + kernel<<<1,32>>>(); + cudaDeviceSynchronize(); + + std::printf("\ntype=half:\n"); + kernel<<<1,32>>>(); + cudaDeviceSynchronize(); + + std::printf("\ntype=bfloat16:\n"); + kernel<<<1,32>>>(); + cudaDeviceSynchronize(); + return 0; +} diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu new file mode 100644 index 0000000..5f617ee --- /dev/null +++ b/verifiable/verifiable.cu @@ -0,0 +1,1227 @@ +#pragma nv_diag_suppress declared_but_not_referenced + +#include "verifiable.h" +#include + +#include +#include +#if CUDART_VERSION >= 11000 +#include +#endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) && defined(__CUDA_BF16_TYPES_EXIST__) + #define HAVE_ncclBfloat16 1 +#else + #define HAVE_ncclBfloat16 0 +#endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) + #define HAVE_ncclAvg 1 +#else + #define HAVE_ncclAvg 0 +#endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0) + #define HAVE_ncclPreMulSum 1 +#else + #define HAVE_ncclPreMulSum 0 +#endif + +#include +#include +#include +#include +#include +#include + +using std::size_t; +using std::int8_t; +using std::int16_t; +using std::int32_t; +using std::int64_t; +using std::uint8_t; +using std::uint16_t; +using std::uint32_t; +using std::uint64_t; + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +template +__device__ unsigned long long bitsOf(T x) { + union { unsigned long long ull; T val; } u; + u.ull = 0; + u.val = x; + return u.ull; +} + 
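// ----------------------------------------------------------------------------
// Editor's sketch (not part of the diff): a host-side model of the "close
// enough" test described in inexact_regress.cu, i.e. two floats agree when
// their bit patterns, read as integers, differ by at most
// 1 + coef*pow(rank_n, power). It mirrors compare()/bitsOf() above for the
// float case and assumes both values have the same sign (true for the data the
// generator produces). COEF and POWER are placeholders, not the fitted values
// the regression harness prints; intBitsOfFloat() and closeEnough() are
// hypothetical helper names.
#include <cstdint>
#include <cstring>
#include <cmath>

static std::uint32_t intBitsOfFloat(float x) {
  std::uint32_t u;
  std::memcpy(&u, &x, sizeof u);  // reinterpret the bit pattern, no numeric conversion
  return u;
}

static bool closeEnough(float a, float b, int rank_n) {
  const double COEF = 0.5, POWER = 1.0;               // hypothetical model parameters
  std::uint32_t ua = intBitsOfFloat(a), ub = intBitsOfFloat(b);
  std::uint64_t dist = ua < ub ? ub - ua : ua - ub;   // distance in representable steps
  return dist <= 1 + (std::uint64_t)(COEF * std::pow((double)rank_n, POWER));
}
// ----------------------------------------------------------------------------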
+__host__ __device__ uint64_t mixBits(uint64_t x) { + union { uint32_t u32[2]; uint64_t u64; }; + u64 = x; + u32[1] += 1; + u32[0] ^= u32[1]; + u64 *= 0x9e3779b97f4a7c13u; + u32[0] ^= u32[1]<<16 ^ u32[1]>>16; + return u64; +} + +__host__ __device__ uint64_t hashOf(uint64_t a, uint64_t b=0) { + a += uint64_t(1)<<32; + a += b; + a ^= a>>32; + a *= 0x9e3779b97f4a7c13u; + a += b>>16 ^ b<<48; + a ^= a>>32; + a *= 0xc4ceb9fe1a85ec53u; + return a; +} +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +template +struct IsIntegral: std::is_integral {}; +template<> +struct IsIntegral: std::false_type {}; +#ifdef __CUDA_BF16_TYPES_EXIST__ +template<> +struct IsIntegral<__nv_bfloat16>: std::false_type {}; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// + +// Hide a value from arithmetic optimizations. Hopefully compiler cannot detect +// that this is equivalent to the identity function. +template +__host__ __device__ T inhibit(T x) { + union { uint64_t u64; T val; }; + u64 = 0; + val = x; + u64 *= 0x0000000100000001u; + u64 *= 0xffffffff00000001u; + return val; +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + template + __host__ __device__ Y castTo(X x) { + return Y(x); + } + template + __host__ __device__ Y castTo(float x) { + return Y(x); + } + template<> + __host__ __device__ half castTo(float x) { + return __float2half(x); + } + #ifdef __CUDA_BF16_TYPES_EXIST__ + template<> + __host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(float x) { + return __float2bfloat16(x); + } + #endif +} + +//////////////////////////////////////////////////////////////////////////////// +// The reduction functions + +namespace { +struct ReduceNil { + template + __host__ __device__ T preOp(T x, int /*rank_me*/) const { return x; } + template + __host__ __device__ T operator()(T a, T /*b*/) const { return a; } + template + __host__ __device__ T postOp(T x) const { return x; } +}; +struct ReduceSum { + template + __host__ __device__ T preOp(T x, int /*rank_me*/) const { return x; } + template + __host__ __device__ T operator()(T a, T b) const { return a + b; } + __host__ __device__ half operator()(half a, half b) const { + #if __CUDA_ARCH__ >= 530 + return __hadd(a, b); + #else + return __float2half(__half2float(a) + __half2float(b)); + #endif + } + #ifdef __CUDA_BF16_TYPES_EXIST__ + __host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const { + #if __CUDA_ARCH__ >= 800 + return __hadd(a, b); + #else + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + #endif + } + #endif + template + __host__ __device__ T postOp(T x) const { return x; } +}; +struct ReduceProd { + template + __host__ __device__ T preOp(T x, int /*rank_me*/) const { return x; } + template + __host__ __device__ T operator()(T a, T b) const { return a * b; } + __host__ __device__ half operator()(half a, half b) const { + #if __CUDA_ARCH__ >= 530 + return __hmul(a, b); + #else + return __float2half(__half2float(a) * __half2float(b)); + #endif + } + #ifdef __CUDA_BF16_TYPES_EXIST__ + __host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const { + #if __CUDA_ARCH__ >= 800 + return __hmul(a, b); + #else + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); + #endif + } + #endif + template + __host__ __device__ T postOp(T x) const { return x; } +}; +struct ReduceMin { + template + __host__ __device__ T preOp(T x, int 
/*rank_me*/) const { return x; } + template + __host__ __device__ T operator()(T a, T b) const { return a < b ? a : b; } + __host__ __device__ half operator()(half a, half b) const { + #if __CUDA_ARCH__ >= 800 + return __hmin(a, b); + #elif __CUDA_ARCH__ >= 530 + return __hlt(a, b) ? a : b; + #else + return __half2float(a) < __half2float(b) ? a : b; + #endif + } + #ifdef __CUDA_BF16_TYPES_EXIST__ + __host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const { + #if __CUDA_ARCH__ >= 800 + return __hmin(a, b); + //#elif __CUDA_ARCH__ >= 530 + // return __hlt(a, b) ? a : b; + #else + return __bfloat162float(a) < __bfloat162float(b) ? a : b; + #endif + } + #endif + template + __host__ __device__ T postOp(T x) const { return x; } +}; +struct ReduceMax { + template + __host__ __device__ T preOp(T x, int /*rank_me*/) const { return x; } + templateT())> + __host__ __device__ T operator()(T a, T b) const { return a > b ? a : b; } + __host__ __device__ half operator()(half a, half b) const { + #if __CUDA_ARCH__ >= 800 + return __hmax(a, b); + #elif __CUDA_ARCH__ >= 530 + return __hgt(a, b) ? a : b; + #else + return __half2float(a) > __half2float(b) ? a : b; + #endif + } + #ifdef __CUDA_BF16_TYPES_EXIST__ + __host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const { + #if __CUDA_ARCH__ >= 800 + return __hmax(a, b); + //#elif __CUDA_ARCH__ >= 530 + // return __hgt(a, b) ? a : b; + #else + return __bfloat162float(a) > __bfloat162float(b) ? a : b; + #endif + } + #endif + template + __host__ __device__ T postOp(T x) const { return x; } +}; +struct ReducePreMulSum { + template + __host__ __device__ T preOp(T x, int rank_me) const { + return ReduceProd()(x, ncclVerifiablePremulScalar(rank_me)); + } + template + __host__ __device__ T operator()(T a, T b) const { return ReduceSum()(a, b); } + template + __host__ __device__ T postOp(T x) const { return x; } +}; + +template::value> +struct ReduceAvg_Base; + +template +struct ReduceAvg_Base { + int rank_n; + __host__ __device__ T preOp(T x, int /*rank_me*/) const { return x; } + __host__ __device__ T operator()(T a, T b) const { return ReduceSum()(a, b); } + __host__ __device__ T postOp(T x) const { return x/rank_n; } +}; + +template +struct ReduceAvg_Base { + int rank_n; + __host__ __device__ T preOp(T x, int /*rank_me*/) const { + using T1 = typename std::conditional<(sizeof(T)::type; + return ReduceProd()(inhibit(castTo(T1(1)/T1(rank_n))), inhibit(x)); + } + __host__ __device__ T operator()(T a, T b) const { return ReduceSum()(a, b); } + __host__ __device__ T postOp(T x) const { return x; } +}; + +struct ReduceAvg { + int rank_n; + template + __host__ __device__ T preOp(T x, int rank_me) const { + return ReduceAvg_Base{rank_n}.preOp(x, rank_me); + } + template + __host__ __device__ T operator()(T a, T b) const { + return ReduceAvg_Base{rank_n}(a, b); + } + template + __host__ __device__ T postOp(T x) const { + return ReduceAvg_Base{rank_n}.postOp(x); + } +}; +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +template +struct FloatLayout; +template<> +struct FloatLayout { + static constexpr int exponent_bits = 8, mantissa_bits = 23; + static constexpr int exponent_bias = (1<<(exponent_bits-1))-1; +}; +template<> +struct FloatLayout { + static constexpr int exponent_bits = 11, mantissa_bits = 52; + static constexpr int exponent_bias = (1<<(exponent_bits-1))-1; +}; +template<> +struct FloatLayout { + static constexpr int exponent_bits = 5, mantissa_bits = 10; 
+ static constexpr int exponent_bias = (1<<(exponent_bits-1))-1; +}; +#ifdef __CUDA_BF16_TYPES_EXIST__ +template<> +struct FloatLayout<__nv_bfloat16> { + static constexpr int exponent_bits = 8, mantissa_bits = 7; + static constexpr int exponent_bias = (1<<(exponent_bits-1))-1; +}; +#endif + +template +__host__ __device__ T makeFloat(int sign, int exp, uint64_t mant) { + union { T ans; uint64_t bits; }; + bits = sign; + bits <<= FloatLayout::exponent_bits; + bits |= exp; + bits <<= FloatLayout::mantissa_bits; + bits |= mant; + return ans; +} +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +// High bits of multiplcation are useful for generating bounded random values +// from unbounded random values. For instance, given X a totally random 32-bit +// integer, `umul32hi(X,n)` will be totally random within [0,n). +__host__ __device__ uint64_t umul32hi(uint32_t a, uint32_t b) { +#ifdef __CUDA_ARCH__ + return __umulhi(a, b); +#else + return uint64_t(a)*b >> 32; +#endif +} +__host__ __device__ uint64_t umul64hi(uint64_t a, uint64_t b) { +#ifdef __CUDA_ARCH__ + return __umul64hi(a, b); +#else + return uint64_t(__uint128_t(a)*__uint128_t(b) >> 64); +#endif +} + +__host__ __device__ int clz32(int x) { +#ifdef __CUDA_ARCH__ + return __clz(x); +#else + return x==0 ? 32 : __builtin_clz(x); +#endif +} +__host__ __device__ int clz64(long long x) { +#ifdef __CUDA_ARCH__ + return __clzll(x); +#else + return x==0 ? 64 : __builtin_clzll(x); +#endif +} +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +// Returns a wildly permuted rank index. Useful when we know we want exactly N +// random ranks to exhibit some behavior, we can just test if: +// `shuffleRank(rank_n, rank_me, rng) < N`. Note that rank_n > 0 must be true +// for well defined results. This mixes the bits of rng. +__host__ __device__ int shuffleRank(int rank_n, int rank_me, uint64_t &rng) { + uint32_t a = uint32_t(rng); + uint32_t b = uint32_t(rng>>32); + rng = mixBits(rng); + + uint32_t r = rank_me; + // round down rank_n to largest pow2, then subtract 1 + uint32_t n2 = (~uint32_t(0)>>1) >> clz32(rank_n); + + // These are 1:1 functions modulo 2^n: + // f(x) = x*a + b : for odd a, any b + // f(x) = (x*x + x)/2 + // So we apply both to the bottom n2+1 ranks, then rotate the top + // (rank_n-n2-1) to the bottom and apply both again. + + if(r <= n2) { + // shuffle bottom n2+1 ranks + r = (r*(a|1) + b) & n2; + r = (r*r + r)/2 & n2; + // rotate top to bottom + r += rank_n - (n2+1); + } + else + r -= n2+1; // rotate top to bottom + + if(r <= n2) { + // shuffle bottom n2+1 again + r = (r*(b|1) + a) & n2; + r = (r*r + r)/2 & n2; + } + return r; +} +} + +namespace { +// Generate wild integers x and y such that if every rank submits its x into a +// summation the result will be y with y <= y_max. Ranks should be shuffled +// before calling. +template +__host__ __device__ void genSumXY( + int rank_n, int rank_me, uint64_t &rng, Uint y_max, Uint &x, Uint &y, + bool avoid_y=false // if true then returned y will not equal given y + ) { + static_assert(std::is_unsigned::value, "Type must be unsigned integral."); + + { // Pick y as a random value in [y_max/2, y_max] + Uint d, y_min = (y_max+1)/2; + if(8*sizeof(Uint) > 32) + d = umul64hi(rng, y_max/2 + (avoid_y ? 0 : 1)); + else + d = umul32hi(uint32_t(rng), y_max/2 + (avoid_y ? 0 : 1)); + Uint y1 = (avoid_y ? y+1 : y_min) + d; + y = y1 - (avoid_y && (y1 < y_min || y_max < y1) ? 
y_max/2 : 0); + } + rng = mixBits(rng); + + unsigned r = unsigned(rank_me); + unsigned rn = unsigned(rank_n); + // Partition our rn ranks into pn distinct subsets each of size rn/pn. If each + // rank submits 1+p (where p is 0-based partition index) then the sum be: + // (rn/pn) * pn*(pn+1)/2 + // So set this equal to our desired sum y and solve for pn. + // (rn/pn) * pn*(pn+1)/2 = y + // rn*(pn+1)/2 = y + // pn = 2*(y/rn)-1 + Uint pn = rn == 1 ? 1 : 2*(y/rn) - 1; + // In the case where rn is huge (compared to y) use only one partition meaning + // that all rn ranks will submit 1 (since p=0). + pn = pn == 0 ? 1 : pn; + // Can't have more partitions than ranks. + pn = rn < pn ? rn : pn; + // Compute sum of contribution from pn partitions where each submits p+1. + Uint p_sum; + if(y_max <= ~uint32_t(0)>>1) // compile time known + p_sum = Uint(uint32_t(pn)*uint32_t(pn+1)/2); + else + p_sum = Uint(uint64_t(pn)*uint64_t(pn+1)/2); + // Let s be the number of ranks per partition. This is either rn/pn as we + // intended, or y/p_sum if that's smaller to prevent overshooting our target y. + uint32_t s = y/p_sum < rn/pn ? y/p_sum : rn/pn; + x = r/s < pn ? 1 + r/s : 0; // First s*pn ranks contribute partition index +1. + x += r == rn-1 ? y - s*p_sum : 0; // Last rank contributes discrepancy. +} +} + +namespace { +template +__host__ __device__ T genInOutFloatSum( + bool input_not_output, int rank_n, int rank_me, uint64_t seed, intptr_t index, + bool same_sign + ) { + constexpr int exp_lo = 1 + FloatLayout::mantissa_bits; + constexpr int exp_hi = (1<::exponent_bits)-1; + using uintmant_t = typename std::conditional<(8*sizeof(T) > 32), uint64_t, uint32_t>::type; + constexpr uintmant_t mant_mask = (uintmant_t(1) << FloatLayout::mantissa_bits)-1; + constexpr uintmant_t max_mant = 2*mant_mask + 1; // add implicit leading 1 + uint64_t rng = hashOf(seed, index); + + int y_sign = rng & 1; + int x_sign = y_sign; + int xy_exp = exp_lo + umul32hi(uint32_t(rng>>32), exp_hi-exp_lo); + rng = mixBits(rng); + rank_me = shuffleRank(rank_n, rank_me, rng); + + // If we're using mixed signs then partition into evens and odds. + int subrank_n = same_sign ? rank_n : (rank_n+1)/2; + int subrank_me = same_sign ? rank_me : rank_me/2; + uintmant_t x0_mant, y0_mant; + genSumXY(subrank_n, subrank_me, rng, max_mant, x0_mant, y0_mant); + + if (!same_sign && (rank_n+0)/2 != 0) { + uintmant_t x1_mant, y1_mant = y0_mant; + // Avoid generating y1_mant == y0_mant so we don't have to worry about + // signed zero as the result. + genSumXY((rank_n+0)/2, rank_me/2, rng, max_mant, x1_mant, y1_mant, /*avoid_y=*/true); + y_sign ^= y0_mant < y1_mant ? 1 : 0; + y0_mant = (y0_mant < y1_mant ? -1 : 1)*(y0_mant - y1_mant); + x_sign ^= rank_me%2; + x0_mant = rank_me%2 == 0 ? x0_mant : x1_mant; + } + + uintmant_t ans_mant = input_not_output ? x0_mant : y0_mant; + if(ans_mant == 0) + return T(0.0f); + else { + int shift = clz64(ans_mant) - (64-FloatLayout::mantissa_bits-1); + int ans_sign = input_not_output ? 
x_sign : y_sign; + int ans_exp = xy_exp - shift; + ans_mant <<= shift; + return makeFloat(ans_sign, ans_exp, ans_mant & mant_mask); + } +} +} + +namespace { +template +__host__ __device__ T genInOutFloatPreMulSum( + bool input_not_output, int rank_n, int rank_me, uint64_t seed, intptr_t index + ) { + constexpr int exp_lo = 1 + FloatLayout::mantissa_bits; + constexpr int exp_hi = (1<::exponent_bits)-1; + using uintmant_t = typename std::conditional<(8*sizeof(T) > 32), uint64_t, uint32_t>::type; + constexpr uintmant_t mant_mask = (uintmant_t(1) << FloatLayout::mantissa_bits)-1; + constexpr uintmant_t max_mant = 2*mant_mask + 1; // add implicit leading 1 + uint64_t rng = hashOf(seed, index); + + int y_sign = rng & 1; + int y_exp = exp_lo + umul32hi(uint32_t(rng>>32), exp_hi-exp_lo); + rng = mixBits(rng); + int subrank_me0 = shuffleRank((rank_n+1)/2, rank_me/2, rng); + int subrank_me1 = shuffleRank((rank_n+0)/2, rank_me/2, rng); + + // when ncclVerifiablePremulScalar() = 1.0 (rank_me%2 == 0) + uintmant_t x0_mant, y0_mant; + genSumXY((rank_n+1)/2, subrank_me0, rng, max_mant>>1, x0_mant, y0_mant); + + // when ncclVerifiablePremulScalar() = 2.0 (rank_me%2 == 1) + uintmant_t x1_mant=0, y1_mant=0; + if((rank_n+0)/2 != 0) + genSumXY((rank_n+0)/2, subrank_me1, rng, max_mant>>2, x1_mant, y1_mant); + + uintmant_t x_mant = rank_me%2 == 0 ? x0_mant : x1_mant; + uintmant_t y_mant = y0_mant + 2*y1_mant; + uintmant_t ans_mant = input_not_output ? x_mant : y_mant; + + if(ans_mant == 0) + return T(0.0f); + else { + int shift = clz64(ans_mant) - (64-FloatLayout::mantissa_bits-1); + int ans_sign = y_sign; + int ans_exp = y_exp - shift; + ans_mant <<= shift; + return makeFloat(ans_sign, ans_exp, ans_mant & mant_mask); + } +} +} + +namespace { +template +__host__ __device__ T genInOutFloatProd( + bool input_not_output, int rank_n, int rank_me, uint64_t seed, intptr_t index + ) { + // Three kinds of contributions (values for x): + // 1) x = random value: only one rank does this + // 2) x = 2^n: random positive n + // 3) x = 1 + // Since only one rank submits a random value, the result of the product + // will have the same mantissa as that value but with an exponent incorporating + // the sum of the exponents from case (2) + + uint64_t rng = hashOf(seed, index); + rank_me = shuffleRank(rank_n, rank_me, rng); + int y_sign = (rank_n/2)%2; + int x_sign = rank_me%2; + + constexpr unsigned max_exp = -1 + (1<<(FloatLayout::exponent_bits-1)); + unsigned x_exp=0, y_exp=0; + genSumXY(rank_n, rank_me, rng, max_exp, x_exp, y_exp); + x_exp += FloatLayout::exponent_bias; + y_exp += FloatLayout::exponent_bias; + + constexpr uint64_t mant_mask = (uint64_t(1)<::mantissa_bits)-1; + uint64_t y_mant = rng & mant_mask; + if (y_mant == 0) y_mant = 1; + + return makeFloat( + input_not_output ? x_sign : y_sign, + input_not_output ? x_exp : y_exp, + !input_not_output || rank_me==0 ? 
y_mant : 0 + ); +} +} + +//////////////////////////////////////////////////////////////////////////////// +// What follows is lots of overloads for genInput/genOutput to generate data + +namespace { +// General case for integral data for all ops but ReduceNil/premulsum +template::value + >::type> +__host__ __device__ void genInput( + T &ans, ReduceFn, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::true_type /*integral*/ + ) { + (void)rank_n; // silence unused warnings + union { uint64_t bits; T tmp; }; + bits = uint64_t(-1)>>(64 - 8*sizeof(T)); + bits &= hashOf(index ^ index<<16 ^ rank_me, seed); + // make sure we never return 0 in products + ans = std::is_same::value && bits == 0 ? T(1) : tmp; +} +} + +//////////////////////////////////////////////////////////////////////////////// +// Dumb/generic case for genOutput just reduces results of genInput + +namespace { +template +__host__ __device__ void genOutput( + T &ans, ReduceFn op, int rank_n, uint64_t seed, intptr_t index, + std::integral_constant + ) { + T acc = genInput(op, rank_n, 0, seed, index); + acc = op.preOp(acc, 0); + for(int r=1; r < rank_n; r++) + acc = op(acc, op.preOp(genInput(op, rank_n, r, seed, index), r)); + ans = op.postOp(acc); +} +} + +//////////////////////////////////////////////////////////////////////////////// +// Nil reduction (byte copy functions). Optimized to assume rank_n=1 + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReduceNil, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::integral_constant + ) { + (void)rank_n, (void)rank_me; // silence unused warnings + union { uint64_t bits; T tmp; }; + bits = mixBits(seed ^ index); + bits >>= 64 - 8*sizeof(T); + bits &= uint64_t(-1)>>(64 - 8*sizeof(T)); + ans = tmp; +} + +template +__host__ __device__ void genOutput( + T &ans, ReduceNil op, int rank_n, uint64_t seed, intptr_t index, + std::integral_constant + ) { + ans = genInput(op, rank_n, 0, seed, index); +} +} + +//////////////////////////////////////////////////////////////////////////////// +// Sum of float + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReduceSum, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatSum(/*input_not_output=*/true, rank_n, rank_me, seed, index, /*same_sign=*/false); +} + +template +__host__ __device__ void genOutput( + T &ans, ReduceSum, int rank_n, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatSum(/*input_not_output=*/false, rank_n, 0, seed, index, /*same_sign=*/false); +} +} + +//////////////////////////////////////////////////////////////////////////////// +// Product of float + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReduceProd, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatProd(/*input_not_output=*/true, rank_n, rank_me, seed, index); +} + +template +__host__ __device__ void genOutput( + T &ans, ReduceProd, int rank_n, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatProd(/*input_not_output=*/false, rank_n, 0, seed, index); +} +} + +//////////////////////////////////////////////////////////////////////////////// +// PreMulSum of int/float + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReducePreMulSum, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::true_type integral + ) { + genInput(ans, ReduceSum(), rank_n, rank_me, 
seed, index, integral); +} + +// No genOutput overload specific to premulsum(int), just use generic case. + +template +__host__ __device__ void genInput( + T &ans, ReducePreMulSum, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatPreMulSum(/*input_not_output=*/true, rank_n, rank_me, seed, index); +} + +template +__host__ __device__ void genOutput( + T &ans, ReducePreMulSum, int rank_n, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatPreMulSum(/*input_not_output=*/false, rank_n, 0, seed, index); +} +} + +///////////////////////////////////////////////////////////////////////////////// +// Average of float + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReduceAvg, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatSum(/*input_not_output=*/true, rank_n, rank_me, seed, index, /*same_sign=*/true); +} + +template +__host__ __device__ void genOutput( + T &ans, ReduceAvg, int rank_n, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + ans = genInOutFloatSum(/*input_not_output=*/false, rank_n, 0, seed, index, /*same_sign=*/true); + using T1 = typename std::conditional<(sizeof(T)::type; + ans = ReduceProd()(ans, T1(1)/T1(rank_n)); +} +} + +///////////////////////////////////////////////////////////////////////////////// +// min/max of float + +namespace { +template +__host__ __device__ void genInput( + T &ans, ReduceMin, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type integral + ) { + genInput(ans, ReduceMax(), rank_n, rank_me, seed, index, integral); +} +template +__host__ __device__ void genInput( + T &ans, ReduceMax, int rank_n, int rank_me, uint64_t seed, intptr_t index, + std::false_type /*integral*/ + ) { + (void)rank_n; // silence unused warnings + constexpr uint64_t mant_mask = (uint64_t(1) << FloatLayout::mantissa_bits)-1; + uint64_t rng = hashOf(index ^ index<<16 ^ rank_me, seed); + int sign = rng & 1; + rng ^= rng>>1; + int exp = rng & ((1<<(FloatLayout::exponent_bits-1))-1); + exp += 1<<(FloatLayout::exponent_bits-2); + rng ^= rng >> FloatLayout::exponent_bits; + uint64_t mant = rng & mant_mask; + ans = makeFloat(sign, exp, mant); +} + +// No genOutput overload specific to floating point min/max, just use generic case. +} + +/////////////////////////////////////////////////////////////////////////////// +// Entry API for genInput/genOutput + +namespace { +template +__host__ __device__ T genInput( + ReduceFn op, int rank_n, int rank_me, uint64_t seed, intptr_t index + ) { + T ans; + genInput(ans, op, rank_n, rank_me, seed, index, + std::integral_constant::value>()); + return ans; +} + +template +__host__ __device__ T genOutput( + ReduceFn op, int rank_n, uint64_t seed, intptr_t index + ) { + T ans; + genOutput(ans, op, rank_n, seed, index, + std::integral_constant::value>()); + return ans; +} +} + +//////////////////////////////////////////////////////////////////////////////// + +#if !SELF_TEST +namespace { +template +__global__ void prepareInput2( + T *elts, intptr_t elt_n, ReduceFn op, int rank_n, int rank_me, + uint64_t seed, intptr_t elt_ix0 + ) { + intptr_t i0 = blockIdx.x*(elt_n/gridDim.x); + i0 += blockIdx.x < elt_n%gridDim.x ? blockIdx.x : elt_n%gridDim.x; + intptr_t i1 = (blockIdx.x+1)*(elt_n/gridDim.x); + i1 += blockIdx.x+1 < elt_n%gridDim.x ? 
blockIdx.x+1 : elt_n%gridDim.x; + intptr_t i = i0 + threadIdx.x; + while(i < i1) { + elts[i] = genInput(op, rank_n, rank_me, seed, elt_ix0+i); + #if 0 + T output = genOutput(op, rank_n, seed, elt_ix0+i); + printf("prepareInput2 T=%d seed=0x%llx r=%d ix=%lld x=%g output=%g elts=%p\n", + std::is_same::value, (long long)seed, int(rank_me), (long long)i, (float)elts[i], (float)output, elts); + #endif + i += blockDim.x; + } +} + +template +void prepareInput1( + void *elts, intptr_t elt_n, int elt_ty, ReduceOp op, int rank_n, int rank_me, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream + ) { + int block_n = std::min(32, (elt_n + 4*512-1)/(4*512)); + #define CASE_TY(T) prepareInput2<<>>((T*)elts, elt_n, op, rank_n, rank_me, seed, elt_ix0); break; + switch(elt_ty) { + case ncclInt8: CASE_TY(int8_t) + case ncclUint8: CASE_TY(uint8_t) + case ncclInt32: CASE_TY(int32_t) + case ncclUint32: CASE_TY(uint32_t) + case ncclInt64: CASE_TY(int64_t) + case ncclUint64: CASE_TY(uint64_t) + case ncclFloat16: CASE_TY(half) + #if HAVE_ncclBfloat16 + case ncclBfloat16: CASE_TY(__nv_bfloat16) + #endif + case ncclFloat32: CASE_TY(float) + case ncclFloat64: CASE_TY(double) + default: assert(0); + } + #undef CASE_TY +} +} + +void ncclVerifiablePrepareInput( + void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, int rank_me, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream + ) { + #define CASE_OP(op) \ + if(rank_n == 1) \ + prepareInput1(elts, elt_n, elt_ty, ReduceNil(), rank_n, rank_me, seed, elt_ix0, stream); \ + else \ + prepareInput1(elts, elt_n, elt_ty, op, rank_n, rank_me, seed, elt_ix0, stream); \ + break; + switch(red_op) { + case ncclSum: CASE_OP(ReduceSum()) + case ncclMin: CASE_OP(ReduceMin()) + case ncclMax: CASE_OP(ReduceMax()) + case ncclProd: CASE_OP(ReduceProd()) + #if HAVE_ncclAvg + case ncclAvg: CASE_OP(ReduceAvg{rank_n}) + #endif + #if HAVE_ncclPreMulSum + default: CASE_OP(ReducePreMulSum()) + #endif + } + #undef CASE_OP +} +#endif + +//////////////////////////////////////////////////////////////////////////////// + +#if !SELF_TEST +namespace { +template +__global__ void prepareExpected2( + T *elts, intptr_t elt_n, ReduceFn op, int rank_n, + uint64_t seed, intptr_t elt_ix0 + ) { + intptr_t i0 = blockIdx.x*(elt_n/gridDim.x); + i0 += blockIdx.x < elt_n%gridDim.x ? blockIdx.x : elt_n%gridDim.x; + intptr_t i1 = (blockIdx.x+1)*(elt_n/gridDim.x); + i1 += blockIdx.x+1 < elt_n%gridDim.x ? 
blockIdx.x+1 : elt_n%gridDim.x; + intptr_t i = i0 + threadIdx.x; + while(i < i1) { + elts[i] = genOutput(op, rank_n, seed, elt_ix0+i); + #if 0 + printf("prepareExpected2 seed=0x%llx ix=%lld x=%g elts=%p\n", + (long long)seed, (long long)(elt_ix0+i), (float)elts[i], elts); + #endif + i += blockDim.x; + } +} + +template +void prepareExpected1( + void *elts, intptr_t elt_n, int elt_ty, ReduceOp op, int rank_n, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream + ) { + int block_n = std::min(32, (elt_n + 4*512-1)/(4*512)); + #define CASE_TY(T) prepareExpected2<<>>((T*)elts, elt_n, op, rank_n, seed, elt_ix0); break; + switch(elt_ty) { + case ncclInt8: CASE_TY(int8_t) + case ncclUint8: CASE_TY(uint8_t) + case ncclInt32: CASE_TY(int32_t) + case ncclUint32: CASE_TY(uint32_t) + case ncclInt64: CASE_TY(int64_t) + case ncclUint64: CASE_TY(uint64_t) + case ncclFloat16: CASE_TY(half) + #if HAVE_ncclBfloat16 + case ncclBfloat16: CASE_TY(__nv_bfloat16) + #endif + case ncclFloat32: CASE_TY(float) + case ncclFloat64: CASE_TY(double) + default: assert(0); + } + #undef CASE_TY +} +} + +void ncclVerifiablePrepareExpected( + void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream + ) { + #define CASE_OP(op) \ + if(rank_n == 1) \ + prepareExpected1(elts, elt_n, elt_ty, ReduceNil(), rank_n, seed, elt_ix0, stream); \ + else \ + prepareExpected1(elts, elt_n, elt_ty, op, rank_n, seed, elt_ix0, stream); \ + break; + switch(red_op) { + case ncclSum: CASE_OP(ReduceSum()) + case ncclMin: CASE_OP(ReduceMin()) + case ncclMax: CASE_OP(ReduceMax()) + case ncclProd: CASE_OP(ReduceProd()) + #if HAVE_ncclAvg + case ncclAvg: CASE_OP(ReduceAvg{rank_n}) + #endif + #if HAVE_ncclPreMulSum + default: CASE_OP(ReducePreMulSum()) + #endif + } + #undef CASE_OP +} +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +/* How we compare floating point values when exactness is impossible is interesting. + * First, we take note that simply reinterpreting integer bits as floating point + * gives us a monotonic mapping which exponentially spaces out floats. Thus + * consecutive integers encode consecutive floats. In general, using integer + * subraction on the bitpatterns of two floats gives us an integer which is the + * logarithm of their relative difference. But, if the floats always have similar + * exponents, than the integer difference is actually proportional to the + * relative error (this is because we are counting hops in the mantissa bits only, + * not the exponent bits). So a cheap way to compare if two floats are relatively + * close is: abs(intBits(a), intBits(b)) < tolerance. The following formula + * calculates such a tolerance for a summation of n floats. This formula + * was derived by inspecting the maximum observed integer difference over many + * random runs of summation. The parameter values were computed by the + * companion program "inexact_regress.cu". 
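+ *
+ * As a rough worked example of the scale involved: with the float32 parameters
+ * below (power=.51, coef=1.25), a sum over rank_n=256 ranks is allowed a slack
+ * of about 1 + 1.25*256^.51 ~= 22 steps in the integer encoding, while with the
+ * half parameters (power=.91, coef=.75) the same rank count tolerates roughly
+ * 117 steps.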
+ */ +__host__ __device__ unsigned calcSumFloatTolerance(int rank_n, int elt_ty) { + float power, coef; + switch(elt_ty) { + case ncclFloat32: + case ncclFloat64: + power = .51f; + coef = 1.25f; + break; + case ncclFloat16: + power = .91f; + coef = .75f; + break; + #if HAVE_ncclBfloat16 + case ncclBfloat16: + power = .91f; + coef = .66f; + break; + #endif + } + #if __CUDA_ARCH__ + return 1 + unsigned(coef*powf(float(rank_n), power)); + #else + return 1 + unsigned(coef*std::pow(float(rank_n), power)); + #endif +} + +template +__host__ __device__ uint64_t calcDelta(T a, T b) { + union { T t; uint8_t i1; uint16_t i2; uint32_t i4; uint64_t i8; } x, y; + x.t = a; + y.t = b; + switch(sizeof(T)) { + case 1: return x.i1 < y.i1 ? y.i1 - x.i1 : x.i1 - y.i1; + case 2: return x.i2 < y.i2 ? y.i2 - x.i2 : x.i2 - y.i2; + case 4: return x.i4 < y.i4 ? y.i4 - x.i4 : x.i4 - y.i4; + default: return x.i8 < y.i8 ? y.i8 - x.i8 : x.i8 - y.i8; + } +} +} + +//////////////////////////////////////////////////////////////////////////////// + +#if !SELF_TEST +namespace { +template +__global__ void verifyPrepared( + T const *results, T const *expected, intptr_t elt_n, unsigned tolerance, int64_t *bad_elt_n + ) { + intptr_t i0 = blockIdx.x*(elt_n/gridDim.x); + i0 += blockIdx.x < elt_n%gridDim.x ? blockIdx.x : elt_n%gridDim.x; + intptr_t i1 = (blockIdx.x+1)*(elt_n/gridDim.x); + i1 += blockIdx.x+1 < elt_n%gridDim.x ? blockIdx.x+1 : elt_n%gridDim.x; + intptr_t i = i0 + threadIdx.x; + int64_t bad = 0; + + while(i < i1) { + T a = results[i], b = expected[i]; + T delta = a < b ? b - a : a - b; + bad += tolerance < delta ? 1 : 0; + #if 0 + if(tolerance < delta) { + printf("verifyPrepared ix=%lld got=%g exp=%g\n", (long long)i, (float)results[i], (float)expected[i]); + } + #endif + i += blockDim.x; + } + asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad)); +} + +template +__global__ void verifyInline2( + T const *results, intptr_t elt_n, ReduceFn op, int rank_n, uint64_t seed, + intptr_t elt_ix0, unsigned tolerance, int64_t *bad_elt_n + ) { + intptr_t i0 = blockIdx.x*(elt_n/gridDim.x); + i0 += blockIdx.x < elt_n%gridDim.x ? blockIdx.x : elt_n%gridDim.x; + intptr_t i1 = (blockIdx.x+1)*(elt_n/gridDim.x); + i1 += blockIdx.x+1 < elt_n%gridDim.x ? blockIdx.x+1 : elt_n%gridDim.x; + intptr_t i = i0 + threadIdx.x; + int64_t bad = 0; + + while(i < i1) { + union { T t; Uint u; } a, b; + a.t = results[i]; + b.t = genOutput(op, rank_n, seed, elt_ix0+i); + Uint delta = a.u < b.u ? b.u - a.u : a.u - b.u; + bad += tolerance < delta ? 
1 : 0; + #if 0 + T input = genInput(op, rank_n, 0, seed, elt_ix0+i); + if(tolerance < delta) { + printf("verifyInline2 fail T=%d ix=%lld got=%g exp=%g input=%g\n", + std::is_same::value, (long long)i, (float)a.t, (float)b.t, (float)input); + } else { + printf("verifyInline2 pass T=%d ix=%lld got=%g exp=%g input=%g\n", + std::is_same::value, (long long)i, (float)a.t, (float)b.t, (float)input); + } + #endif + i += blockDim.x; + } + asm volatile("red.global.add.u64 [%0],%1;" :: "l"(bad_elt_n), "l"(bad)); +} + +template +void verifyInline1( + T const *results, intptr_t elt_n, int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0, + unsigned tolerance, int64_t *bad_elt_n, cudaStream_t stream, int block_n + ) { + #define CASE_OP(op) \ + if(rank_n == 1) \ + verifyInline2<<>> \ + ((T const*)results, elt_n, ReduceNil(), rank_n, seed, elt_ix0, tolerance, bad_elt_n); \ + else \ + verifyInline2<<>> \ + ((T const*)results, elt_n, op, rank_n, seed, elt_ix0, tolerance, bad_elt_n); \ + break; + switch(red_op) { + case ncclSum: CASE_OP(ReduceSum()) + case ncclMin: CASE_OP(ReduceMin()) + case ncclMax: CASE_OP(ReduceMax()) + case ncclProd: CASE_OP(ReduceProd()) + #if HAVE_ncclAvg + case ncclAvg: CASE_OP(ReduceAvg{rank_n}) + #endif + #if HAVE_ncclPreMulSum + default: CASE_OP(ReducePreMulSum()) + #endif + } + #undef CASE_OP +} +} + +void ncclVerifiableVerify( + void const *results, void const *expected, intptr_t elt_n, int elt_ty, + int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0, + int64_t *bad_elt_n, cudaStream_t stream + ) { + bool floating = elt_ty == ncclFloat16 || elt_ty == ncclFloat32 || elt_ty == ncclFloat64; + #if HAVE_ncclBfloat16 + floating |= elt_ty == ncclBfloat16; + #endif + + unsigned tolerance = 0; + #if HAVE_ncclAvg + if (floating && red_op == ncclAvg) + tolerance = calcSumFloatTolerance(rank_n, elt_ty); + #endif + + int block_n = std::min(32, (elt_n + 4*512-1)/(4*512)); + + *bad_elt_n = 0; + #define CASE_TY(T, Uint) { \ + if(expected != nullptr) { \ + verifyPrepared<<>>((Uint const*)results, (Uint const*)expected, elt_n, tolerance, bad_elt_n); \ + } else { \ + verifyInline1((T const*)results, elt_n, red_op, rank_n, seed, elt_ix0, tolerance, bad_elt_n, stream, block_n); \ + } \ + } break; + switch(elt_ty) { + case ncclInt8: CASE_TY(int8_t, uint8_t) + case ncclUint8: CASE_TY(uint8_t, uint8_t) + case ncclInt32: CASE_TY(int32_t, uint32_t) + case ncclUint32: CASE_TY(uint32_t, uint32_t) + case ncclInt64: CASE_TY(int64_t, uint64_t) + case ncclUint64: CASE_TY(uint64_t, uint64_t) + case ncclFloat16: CASE_TY(half, uint16_t) + #if HAVE_ncclBfloat16 + case ncclBfloat16: CASE_TY(__nv_bfloat16, uint16_t) + #endif + case ncclFloat32: CASE_TY(float, uint32_t) + case ncclFloat64: CASE_TY(double, uint64_t) + default: assert(0); + } + #undef CASE_TY +} +#endif + +//////////////////////////////////////////////////////////////////////////////// + +#if SELF_TEST +#include + +template +__device__ void sweep2(int ty, char const *tyname, Op op, char const *opname, int rank_n) { + //if(!std::is_same::value) return; + //if(!std::is_same::value) return; + //if(rank_n!=3) return; + + unsigned tolerance = !IsIntegral::value && std::is_same::value ? calcSumFloatTolerance(rank_n, ty) : 0; + uint64_t seed = 0xc8e2bed69766d533; + + for(int ix=threadIdx.x; ix < 10000; ix+=blockDim.x) { + //if(ix!=387) continue; + T y = genOutput(op, rank_n, seed, ix); + T sum; + for(int r=0; r < rank_n; r++) { + T x = genInput(op, rank_n, r, seed, ix); + x = op.preOp(x, r); + sum = r==0 ? 
x : op(sum, inhibit(x)); + //std::printf("x = %llx, sum = %llx\n", bitsOf(x), bitsOf(sum)); + } + sum = op.postOp(sum); + if(tolerance < calcDelta(sum, y)) { + std::printf( + //"%10g != %10g : T=%-8s op=%-9s rank_n=%-1d ix=%-1d\n", + "%llx != %llx : T=%-8s op=%-9s rank_n=%-1d ix=%-1d\n", + *(long long*)&sum, *(long long*)&y, tyname, opname, rank_n, ix + ); + } + } +} + +template +__device__ void sweep1(int ty, char const *tyname) { + for(int i=0; i < 10; i++) { + int rank_n = (1<(ty, tyname, ReduceSum(), "sum", rank_n); + sweep2(ty, tyname, ReduceProd(), "prod", rank_n); + sweep2(ty, tyname, ReduceMin(), "min", rank_n); + sweep2(ty, tyname, ReduceMax(), "max", rank_n); + sweep2(ty, tyname, ReducePreMulSum(), "premulsum", rank_n); + sweep2(ty, tyname, ReduceAvg{rank_n}, "avg", rank_n); + } +} + +__global__ void sweep() { + sweep1(ncclInt8, "int8"); + sweep1(ncclUint8, "uint8"); + sweep1(ncclInt32, "int32"); + sweep1(ncclUint32, "uint32"); + sweep1(ncclInt64, "int64"); + sweep1(ncclUint64, "uint64"); + sweep1(ncclFloat16, "half"); + #if HAVE_ncclBfloat16 + sweep1<__nv_bfloat16>(ncclBfloat16, "bfloat16"); + #endif + sweep1(ncclFloat32, "float"); + sweep1(ncclFloat64, "double"); +} + +int main(int arg_n, char **args) { + std::cerr<<"You are hoping to see no output beyond this line."<>>(); + cudaDeviceSynchronize(); + return 0; +} +#endif diff --git a/verifiable/verifiable.h b/verifiable/verifiable.h new file mode 100644 index 0000000..aca0565 --- /dev/null +++ b/verifiable/verifiable.h @@ -0,0 +1,59 @@ +#ifndef _d41d8cd98f00b204e9800998ecf8427e +#define _d41d8cd98f00b204e9800998ecf8427e + +#include + +#include + +/* Routines for launching kernels that verify reduction results. A significant + * feature of these routines is they carefully craft floating point input + * to produce exactly predictable output. + * + * int elt_ty: actually just a ncclDataType_t + * + * int red_op: mostly just a ncclRedOp_t. Since PreMulSum ops are dynamically + * created, these are encoded as the value ncclNumOps and their scalar is + * assumed to be `ncclVerifiablePremulScalar(rank_me)` + * + * uint64_t seed: arbitrary 64-bits to use in seeding the random values + * + * intptr_t elt_ix0: index of first element pointed to by elts when generating + * random values. This makes it possible to generate subsequences independently + * as well as in aggregate. + * + * int rank_n: Number of contributions into the reduction. Non-reduction + * collectives like broadcast, gather, etc will always set this to one. + * + * int rank_me: Index of this contribution + */ + +// Use this as the local scalar for PreMulSum ops +template +__host__ __device__ T ncclVerifiablePremulScalar(int rank_me) { + return T(rank_me%2 == 0 ? 1.0f : 2.0f); +} + +// Enqueue kernel to generate data which is to be reduced. +void ncclVerifiablePrepareInput( + void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, int rank_me, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream +); + +// Enqueue kernel to generate expected results of reduction. +void ncclVerifiablePrepareExpected( + void *elts, intptr_t elt_n, int elt_ty, int red_op, int rank_n, + uint64_t seed, intptr_t elt_ix0, cudaStream_t stream +); + +// Enqueue kernel to verify reduced data matches expectation. The number of +// failed elements is written to bad_elt_n which must be in cudaHost memory. +// If `expected == nullptr` then the expected results are generated on-the-fly +// which can be costly. 
Thus if you plan to run the same reduction multiple
+// times it is advantageous to precompute the expected values with
+// ncclVerifiablePrepareExpected and pass them as `expected` here.
+void ncclVerifiableVerify(
+  void const *results, void const *expected, intptr_t elt_n, int elt_ty,
+  int red_op, int rank_n, uint64_t seed, intptr_t elt_ix0,
+  int64_t *bad_elt_n, cudaStream_t stream
+);
+#endif
diff --git a/verifiable/verifiable.mk b/verifiable/verifiable.mk
new file mode 100644
index 0000000..225c32a
--- /dev/null
+++ b/verifiable/verifiable.mk
@@ -0,0 +1,11 @@
+# We require both of the following paths to be set before including this makefile:
+# TEST_VERIFIABLE_SRCDIR =
+# TEST_VERIFIABLE_BUILDDIR =
+
+TEST_VERIFIABLE_HDRS = $(TEST_VERIFIABLE_SRCDIR)/verifiable.h
+TEST_VERIFIABLE_OBJS = $(TEST_VERIFIABLE_BUILDDIR)/verifiable.o
+
+$(TEST_VERIFIABLE_BUILDDIR)/verifiable.o: $(TEST_VERIFIABLE_SRCDIR)/verifiable.cu $(TEST_VERIFIABLE_HDRS)
+	@printf "Compiling %s\n" $@
+	@mkdir -p $(TEST_VERIFIABLE_BUILDDIR)
+	$(NVCC) -o $@ $(NVCUFLAGS) -c $(TEST_VERIFIABLE_SRCDIR)/verifiable.cu
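
For reference (not part of the patch), here is a minimal usage sketch of the three
ncclVerifiable* entry points declared in verifiable.h, as a test harness might call
them around an fp32 sum allreduce. The helper name runVerifiedAllReduce, the
communicator, stream, and buffers are hypothetical, and error checking of CUDA/NCCL
return codes is omitted for brevity; only the verifiable API and ncclAllReduce come
from the sources above.

#include <nccl.h>
#include <cuda_runtime.h>
#include "verifiable.h"   // the new header added by this patch

// Hypothetical helper: run one verifiable fp32 allreduce on this rank and
// return the number of wrong output elements (0 means the result matched
// the bit-exact expectation).
int64_t runVerifiedAllReduce(ncclComm_t comm, int nranks, int rank,
                             intptr_t count, uint64_t seed, cudaStream_t stream) {
  float *sendbuf, *recvbuf, *expected;
  int64_t *bad_count;                        // must live in cudaHost (pinned) memory
  cudaMalloc(&sendbuf, count*sizeof(float));
  cudaMalloc(&recvbuf, count*sizeof(float));
  cudaMalloc(&expected, count*sizeof(float));
  cudaMallocHost(&bad_count, sizeof(int64_t));

  // Each rank generates its deterministic contribution...
  ncclVerifiablePrepareInput(sendbuf, count, ncclFloat32, ncclSum,
                             nranks, rank, seed, /*elt_ix0=*/0, stream);
  // ...and the expected result of summing all nranks contributions.
  ncclVerifiablePrepareExpected(expected, count, ncclFloat32, ncclSum,
                                nranks, seed, /*elt_ix0=*/0, stream);

  ncclAllReduce(sendbuf, recvbuf, count, ncclFloat32, ncclSum, comm, stream);

  // Compare the collective's output against the precomputed expectation; the
  // count of mismatching elements is accumulated into *bad_count.
  ncclVerifiableVerify(recvbuf, expected, count, ncclFloat32, ncclSum,
                       nranks, seed, /*elt_ix0=*/0, bad_count, stream);
  cudaStreamSynchronize(stream);

  int64_t bad = *bad_count;
  cudaFree(sendbuf); cudaFree(recvbuf); cudaFree(expected);
  cudaFreeHost(bad_count);
  return bad;
}

Passing expected = nullptr to ncclVerifiableVerify would skip the precomputation and
regenerate the expected values on the fly, which the header notes is more costly when
the same reduction is verified repeatedly.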