diff --git a/src/Makefile b/src/Makefile
index 393de8e..32c08ce 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -34,8 +34,13 @@ NVCC_GENCODE ?= -gencode=arch=compute_35,code=sm_35 \
                 -gencode=arch=compute_70,code=compute_70
 endif
 
+ifeq ($(GLOO), 1)
+NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++17
+CXXFLAGS  := -std=c++17
+else
 NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11
 CXXFLAGS  := -std=c++11
+endif
 
 LDFLAGS   := -L${CUDA_LIB} -lcudart -lrt
 NVLDFLAGS := -L${CUDA_LIB} -l${CUDARTLIB} -lrt
@@ -70,6 +75,13 @@ ifeq ($(MPI_IBM),1)
 NVCUFLAGS += -DMPI_SUPPORT
 NVLDFLAGS += -lmpi_ibm
 endif
+ifeq ($(GLOO), 1)
+PYTHON_CONFIG := python3-config
+PYTHON_INCLUDE := $(shell $(PYTHON_CONFIG) --includes)
+TORCH_HOME ?= /usr/local/libtorch
+NVCUFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_C10D_GLOO $(PYTHON_INCLUDE) -isystem $(TORCH_HOME)/include -isystem $(TORCH_HOME)/include/torch/csrc/api/include
+NVLDFLAGS += -L$(TORCH_HOME)/lib -lc10 -ltorch_cpu
+endif
 
 LIBRARIES += nccl
 NVLDFLAGS += $(LIBRARIES:%=-l%)
diff --git a/src/common.cu b/src/common.cu
index 04e8142..1552624 100644
--- a/src/common.cu
+++ b/src/common.cu
@@ -10,10 +10,21 @@
 #include
 #include
 #include
+#include
+#include
 #include "cuda.h"
 
 #include "../verifiable/verifiable.h"
 
+#ifdef USE_C10D_GLOO
+#include
+#include
+#include
+#include
+#include
+#include
+#endif /* USE_C10D_GLOO */
+
 int test_ncclVersion = 0; // init'd with ncclGetVersion()
 
 #if NCCL_MAJOR >= 2
@@ -55,6 +66,19 @@ extern "C" __attribute__((weak))
 char const* ncclGetLastError(ncclComm_t comm) {
   return "";
 }
 
+// If 'use_c10d_gloo' is true, use pytorch c10d GLOO distributed framework for
+// multi-process multi-node NCCL testing. The following environment variables
+// will be used:
+// - MASTER_ADDR: Master IP address where gloo server is running.
+// - MASTER_PORT: Master port where gloo server is listening.
+// - RANK: Global rank of the process.
+// - WORLD_SIZE: Total number of processes.
+bool use_c10d_gloo = false;
+
+#ifdef USE_C10D_GLOO
+std::shared_ptr<c10d::ProcessGroupGloo> c10d_process_group;
+#endif /* USE_C10D_GLOO */
+
 int is_main_proc = 0;
 thread_local int is_main_thread = 0;
@@ -151,9 +175,15 @@ void Barrier(struct threadArgs *args) {
   if(args->thread+1 == args->nThreads) {
     while(counter[epoch] != args->nThreads)
       pthread_cond_wait(&cond[epoch], &lock[epoch]);
-    #ifdef MPI_SUPPORT
-    MPI_Barrier(MPI_COMM_WORLD);
-    #endif
+    if (!use_c10d_gloo) {
+#ifdef MPI_SUPPORT
+      MPI_Barrier(MPI_COMM_WORLD);
+#endif
+    } else {
+#ifdef USE_C10D_GLOO
+      c10d_process_group->barrier()->wait();
+#endif
+    }
     counter[epoch] = 0;
     pthread_cond_broadcast(&cond[epoch]);
   }
@@ -165,6 +195,28 @@ void Barrier(struct threadArgs *args) {
   epoch ^= 1;
 }
 
+#ifdef USE_C10D_GLOO
+template <typename T>
+struct torch_type;
+
+template <>
+struct torch_type<long long> {
+  static at::ScalarType type() { return at::kLong; }
+  static long long value(const at::Tensor& tensor) { return tensor.item().toLong(); }
+};
+
+template <>
+struct torch_type<double> {
+  static at::ScalarType type() { return at::kDouble; }
+  static double value(const at::Tensor& tensor) { return tensor.item().toDouble(); }
+};
+
+template <typename T>
+at::Tensor create_tensor_from_blob(T* data, int64_t size) {
+  return torch::from_blob(data, {size}, torch_type<T>::type());
+}
+#endif
+
 // Inter-thread/process barrier+allreduce. The quality of the return value
 // for average=0 (which means broadcast from rank=0) is dubious. The returned
 // value will actually be the result of process-local broadcast from the local thread=0.
@@ -196,19 +248,34 @@ void Allreduce(struct threadArgs* args, T* value, int average) {
     while(counter[epoch] != args->nThreads)
       pthread_cond_wait(&cond[epoch], &lock[epoch]);
 
-    #ifdef MPI_SUPPORT
     if(average != 0) {
       static_assert(std::is_same<T, long long>::value || std::is_same<T, double>::value, "Allreduce<T> only for T in {long long, double}");
-      MPI_Datatype ty = std::is_same<T, long long>::value ? MPI_LONG_LONG :
-                        std::is_same<T, double>::value ? MPI_DOUBLE :
-                        MPI_Datatype();
-      MPI_Op op = average == 1 ? MPI_SUM :
-                  average == 2 ? MPI_MIN :
-                  average == 3 ? MPI_MAX :
-                  average == 4 ? MPI_SUM : MPI_Op();
-      MPI_Allreduce(MPI_IN_PLACE, (void*)&accumulator[epoch], 1, ty, op, MPI_COMM_WORLD);
+      if (!use_c10d_gloo) {
+#ifdef MPI_SUPPORT
+        MPI_Datatype ty = std::is_same<T, long long>::value ? MPI_LONG_LONG :
+                          std::is_same<T, double>::value ? MPI_DOUBLE :
+                          MPI_Datatype();
+        MPI_Op op = average == 1 ? MPI_SUM :
+                    average == 2 ? MPI_MIN :
+                    average == 3 ? MPI_MAX :
+                    average == 4 ? MPI_SUM : MPI_Op();
+        MPI_Allreduce(MPI_IN_PLACE, (void*)&accumulator[epoch], 1, ty, op, MPI_COMM_WORLD);
+#endif
+      }
+      else {
+#ifdef USE_C10D_GLOO
+        c10d::AllreduceOptions opts;
+        opts.reduceOp = average == 2 ? c10d::ReduceOp::MIN :
+                        average == 3 ? c10d::ReduceOp::MAX :
+                        c10d::ReduceOp::SUM;
+
+        auto tensor = create_tensor_from_blob(&accumulator[epoch], 1);
+        std::vector<at::Tensor> input_tensors{tensor};
+        c10d_process_group->allreduce(input_tensors, opts)->wait();
+        //accumulator[epoch] = torch_type<T>::value(input_tensors[0]);
+#endif
+      }
     }
-    #endif
 
     if(average == 1) accumulator[epoch] /= args->totalProcs*args->nThreads;
     counter[epoch] = 0;
@@ -870,8 +937,51 @@ int main(int argc, char* argv[]) {
            (unsigned long long)maxBytes);
     return -1;
   }
+
+#ifdef USE_C10D_GLOO
+  {
+    // Parse c10d GLOO distributed framework environment variables.
+    char *str = getenv("MASTER_ADDR");
+    if (str) {
+      std::string master_addr = str;
+      use_c10d_gloo = true;
+
+      str = getenv("MASTER_PORT");
+      uint16_t master_port = str ? static_cast<uint16_t>(std::stoi(str)) : 29500;
+
+      str = getenv("RANK");
+      int rank = str ? std::stoi(str) : 0;
+
+      str = getenv("WORLD_SIZE");
+      int world_size = str ? std::stoi(str) : 1;
+
+      auto options = c10d::ProcessGroupGloo::Options::create();
+      // Create Gloo device that binds to any interface.
+      ::gloo::transport::tcp::attr tcp_attr;
+      str = getenv("GLOO_INTERFACE");
+      tcp_attr.iface = str ? str : "eth0";
+      auto gloo_device = ::gloo::transport::tcp::CreateDevice(tcp_attr);
+      options->devices.push_back(gloo_device);
+
+      c10d::TCPStoreOptions store_opts;
+      store_opts.port = master_port;
+      if (rank == 0) {
+        store_opts.isServer = true;
+      }
+      auto store_ptr = c10::make_intrusive<c10d::TCPStore>(
+          master_addr, store_opts);
+
+      // Create the ProcessGroupGloo
+      c10d_process_group = std::make_shared<c10d::ProcessGroupGloo>(
+          store_ptr, rank, world_size, options);
+    }
+  }
+#endif /* USE_C10D_GLOO */
+
 #ifdef MPI_SUPPORT
-  MPI_Init(&argc, &argv);
+  if (!use_c10d_gloo) {
+    MPI_Init(&argc, &argv);
+  }
 #endif
   TESTCHECK(run());
   return 0;
@@ -884,24 +994,51 @@ testResult_t run() {
   getHostName(hostname, 1024);
 
 #ifdef MPI_SUPPORT
-  MPI_Comm_size(MPI_COMM_WORLD, &totalProcs);
-  MPI_Comm_rank(MPI_COMM_WORLD, &proc);
-  uint64_t hostHashs[totalProcs];
-  hostHashs[proc] = getHostHash(hostname);
-  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
-  for (int p=0; p<totalProcs; p++) {
-    if (p == proc) break;
-    if (hostHashs[p] == hostHashs[proc]) localRank++;
-  }
+  if (!use_c10d_gloo) {
+    MPI_Comm_size(MPI_COMM_WORLD, &totalProcs);
+    MPI_Comm_rank(MPI_COMM_WORLD, &proc);
+    uint64_t hostHashs[totalProcs];
+    hostHashs[proc] = getHostHash(hostname);
+    MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
+    for (int p=0; p<totalProcs; p++) {
+      if (p == proc) break;
+      if (hostHashs[p] == hostHashs[proc]) localRank++;
+    }
+  }
 #endif
+
+  if (use_c10d_gloo) {
+#ifdef USE_C10D_GLOO
+    ncclProcs = totalProcs = c10d_process_group->getSize();
+    ncclProc = proc = c10d_process_group->getRank();
+    uint64_t hostHash = getHostHash(hostname);
+
+    auto tensor = torch::tensor({(int64_t)hostHash}, torch::kLong);
+    std::vector<at::Tensor> input_tensors{tensor};
+    std::vector<std::vector<at::Tensor>> output_tensors;
+    output_tensors.emplace_back();
+    for (const auto ii : c10::irange(totalProcs)) {
+      output_tensors.front().emplace_back(at::empty_like(tensor));
+    }
+
+    c10d_process_group->allgather(output_tensors, input_tensors)->wait();
+
+    for (int p = 0; p < output_tensors[0].size(); p++) {
+      if (p == proc) break;
+      if ((uint64_t)output_tensors[0][p].item().toLong() == hostHash) localRank++;
+    }
+#endif
+  }
+
   is_main_thread = is_main_proc = (proc == 0) ? 1 : 0;
 
   PRINT("# nThread %d nGpus %d minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d agg iters: %d validation: %d graph: %d\n",
@@ -929,6 +1066,7 @@ testResult_t run() {
     maxMem = std::min(maxMem, prop.totalGlobalMem);
   }
 
+  if (!use_c10d_gloo) {
 #if MPI_SUPPORT
   char *lines = (proc == 0) ? (char *)malloc(totalProcs*MAX_LINE) : NULL;
   // Gather all output in rank order to root (0)
@@ -942,6 +1080,39 @@
 #else
   PRINT("%s", line);
 #endif
+  } else {
+#ifdef USE_C10D_GLOO
+    {
+      auto tensor = torch::from_blob((void*)line, {MAX_LINE}, torch::kUInt8);
+      std::vector<at::Tensor> input_tensors{tensor};
+      std::vector<std::vector<at::Tensor>> output_tensors;
+      if (proc == 0) {
+        output_tensors.emplace_back();
+        for (const auto i : c10::irange(totalProcs)) {
+          output_tensors.front().emplace_back(at::empty_like(tensor));
+        }
+      }
+
+      c10d::GatherOptions opts;
+      opts.rootRank = 0;
+      c10d_process_group->gather(output_tensors, input_tensors, opts)->wait();
+      if (proc == 0) {
+        for (int ii = 0; ii < totalProcs; ++ii) {
+          PRINT("%s", output_tensors[0][ii].data_ptr());
+        }
+      }
+    }
+
+    {
+      auto tensor = torch::tensor({(int64_t)maxMem}, torch::kLong);
+      std::vector<at::Tensor> input_tensors{tensor};
+      c10d::AllreduceOptions opts;
+      opts.reduceOp = c10d::ReduceOp::MIN;
+      c10d_process_group->allreduce(input_tensors, opts)->wait();
+      maxMem = (size_t)input_tensors[0].item().toLong();
+    }
+#endif
+  }
 
   // We need sendbuff, recvbuff, expected (when datacheck enabled), plus 1G for the rest.
   size_t memMaxBytes = (maxMem - (1<<30)) / (datacheck ? 3 : 2);
@@ -954,10 +1125,24 @@
   if (ncclProc == 0) {
     NCCLCHECK(ncclGetUniqueId(&ncclId));
   }
+  if (!use_c10d_gloo) {
 #ifdef MPI_SUPPORT
-  MPI_Bcast(&ncclId, sizeof(ncclId), MPI_BYTE, 0, mpi_comm);
-  MPI_Barrier(MPI_COMM_WORLD); // Ensure Bcast is complete for HCOLL
+    MPI_Bcast(&ncclId, sizeof(ncclId), MPI_BYTE, 0, mpi_comm);
+    MPI_Barrier(MPI_COMM_WORLD); // Ensure Bcast is complete for HCOLL
 #endif
+  } else {
+#ifdef USE_C10D_GLOO
+    auto ncclId_tensor = torch::from_blob(ncclId.internal,
+        {static_cast<int64_t>(sizeof(ncclId.internal))}, torch::kByte);
+    std::vector<at::Tensor> ncclId_tensor_vector = {ncclId_tensor};
+    c10d::BroadcastOptions opts;
+    opts.rootRank = 0;
+    c10d_process_group->broadcast(ncclId_tensor_vector, opts)->wait();
+    c10d_process_group->barrier()->wait();
+
+    // Other ranks will receive the 'ncclId' once they reach here.
+#endif
+  }
   int gpus[nGpus*nThreads];
   cudaStream_t streams[nGpus*nThreads];
   void* sendbuffs[nGpus*nThreads];
@@ -1074,9 +1259,20 @@
     }
   }
 
+  if (!use_c10d_gloo) {
 #ifdef MPI_SUPPORT
-  MPI_Allreduce(MPI_IN_PLACE, &errors[0], 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
+    MPI_Allreduce(MPI_IN_PLACE, &errors[0], 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
 #endif
+  } else {
+#ifdef USE_C10D_GLOO
+    auto tensor = torch::tensor({errors[0]}, torch::kLong);
+    std::vector<at::Tensor> input_tensors{tensor};
+    c10d::AllreduceOptions opts;
+    opts.reduceOp = c10d::ReduceOp::SUM;
+    c10d_process_group->allreduce(input_tensors, opts)->wait();
+    errors[0] = input_tensors[0].item().toLong();
+#endif
+  }
 
   if (!parallel_init) {
     for(int i=0; i
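
A minimal standalone sketch of the c10d pattern the patch relies on (not part of the patch): it builds a c10d::ProcessGroupGloo from the same MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment contract used above and allreduces a host scalar wrapped with torch::from_blob, the same trick Allreduce() uses. The file name, the hard-coded "eth0" fallback, and the include paths are assumptions; newer libtorch releases move the c10d headers under torch/csrc/distributed/c10d/, so adjust to the TORCH_HOME toolchain configured in the Makefile.

// gloo_allreduce_demo.cpp -- illustrative sketch only; mirrors the c10d calls used in the patch.
// Link against -lc10 -ltorch_cpu from $TORCH_HOME/lib (see the GLOO branch of the Makefile).
#include <cstdio>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>
#include <c10d/ProcessGroupGloo.hpp>   // assumption: older include layout
#include <c10d/TCPStore.hpp>
#include <gloo/transport/tcp/device.h>

int main() {
  // Same environment contract as the patch: MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
  const char* addr = std::getenv("MASTER_ADDR");
  const char* port = std::getenv("MASTER_PORT");
  const char* rank_s = std::getenv("RANK");
  const char* size_s = std::getenv("WORLD_SIZE");
  std::string master_addr = addr ? addr : "127.0.0.1";
  int rank = rank_s ? std::stoi(rank_s) : 0;
  int world_size = size_s ? std::stoi(size_s) : 1;

  // Rank 0 hosts the TCPStore; everyone else connects to it.
  c10d::TCPStoreOptions store_opts;
  store_opts.port = port ? static_cast<uint16_t>(std::stoi(port)) : 29500;
  store_opts.isServer = (rank == 0);
  auto store = c10::make_intrusive<c10d::TCPStore>(master_addr, store_opts);

  // Gloo TCP transport bound to a fixed interface (assumption: adjust to your network).
  auto options = c10d::ProcessGroupGloo::Options::create();
  ::gloo::transport::tcp::attr tcp_attr;
  tcp_attr.iface = "eth0";
  options->devices.push_back(::gloo::transport::tcp::CreateDevice(tcp_attr));

  auto pg = std::make_shared<c10d::ProcessGroupGloo>(store, rank, world_size, options);

  // Allreduce a host scalar in place: from_blob aliases 'value', so the reduced
  // result lands directly in the local variable, as the patch's Allreduce() assumes.
  double value = static_cast<double>(rank);
  auto tensor = torch::from_blob(&value, {1}, torch::kDouble);
  std::vector<at::Tensor> tensors{tensor};
  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::ReduceOp::SUM;
  pg->allreduce(tensors, opts)->wait();

  std::printf("rank %d sees sum %f\n", rank, value);
  return 0;
}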