diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp
index b63be9d96..4acc040e8 100644
--- a/include/mscclpp/nvls.hpp
+++ b/include/mscclpp/nvls.hpp
@@ -17,7 +17,7 @@ class NvlsConnection {
   std::vector<char> serialize();
 
   // the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
-  static const int DefaultNvlsBufferSize = (1 << 29);
+  static const int DefaultNvlsBufferSize;
 
   // Everyone needs to synchronize after creating a NVLS connection before adding devices
   void addDevice();
diff --git a/python/test/executor_test.py b/python/test/executor_test.py
index b0e4342dd..d744a4c1a 100644
--- a/python/test/executor_test.py
+++ b/python/test/executor_test.py
@@ -1,19 +1,18 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from os import path
+import argparse
 
 from mscclpp import (
     DataType,
     Executor,
     ExecutionPlan,
+    PacketType,
 )
 import mscclpp.comm as mscclpp_comm
 
 import cupy as cp
 from mpi4py import MPI
 
-MSCCLPP_ROOT_PATH = "/root/mscclpp"
-
 
 def bench_time(niters: int, ngraphIters: int, func):
     # capture cuda graph for niters of the kernel launch
@@ -40,36 +39,118 @@ def bench_time(niters: int, ngraphIters: int, func):
     return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters
 
 
-if __name__ == "__main__":
+def parse_size(size_str):
+    """Convert a human-readable buffer size string to an integer."""
+    size_str = size_str.strip()
+    if not size_str:
+        raise ValueError("Size string cannot be empty")
+    units = {"K": 1024, "M": 1024**2, "G": 1024**3}
+    if size_str[-1].upper() in units:
+        return int(size_str[:-1]) * units[size_str[-1].upper()]
+    else:
+        return int(size_str)
+
+
+def parse_dtype(dtype_str):
+    """Convert a human-readable data type string to a numpy data type."""
+    dtype_str = dtype_str.strip().lower()
+    if dtype_str == "float16":
+        return cp.float16
+    elif dtype_str == "float32":
+        return cp.float32
+    elif dtype_str == "int32":
+        return cp.int32
+    else:
+        raise ValueError(f"Unknown data type: {dtype_str}")
+
+
+def dtype_to_mscclpp_dtype(dtype):
+    if dtype == cp.float16:
+        return DataType.float16
+    elif dtype == cp.float32:
+        return DataType.float32
+    elif dtype == cp.int32:
+        return DataType.int32
+    else:
+        raise ValueError(f"Unknown data type: {dtype}")
+
+
+def main(
+    execution_plan_name: str,
+    execution_plan_path: str,
+    size: int,
+    nthreads_per_block: int,
+    dtype: cp.dtype = cp.float16,
+    packet_type: PacketType = PacketType.LL16,
+    seed: int = 42,
+):
     mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
-    cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
+    cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
     executor = Executor(mscclpp_group.communicator)
-    execution_plan = ExecutionPlan(
-        "allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
-    )
+    execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)
 
-    nelems = 1024 * 1024
-    cp.random.seed(42)
-    buffer = cp.random.random(nelems).astype(cp.float16)
+    cp.random.seed(seed)
+    nelems = size // cp.dtype(dtype).itemsize
+    buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
     sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
     sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
+    expected = cp.zeros_like(sendbuf)
+    for i in range(mscclpp_group.nranks):
+        expected += sub_arrays[i]
     mscclpp_group.barrier()
 
-    execution_time = bench_time(
-        100,
-        10,
-        lambda stream: executor.execute(
-            MPI.COMM_WORLD.rank,
-            sendbuf.data.ptr,
-            sendbuf.data.ptr,
-            sendbuf.nbytes,
-            sendbuf.nbytes,
-            DataType.float16,
-            512,
-            execution_plan,
-            stream.ptr,
-        ),
+    executor_func = lambda stream: executor.execute(
+        MPI.COMM_WORLD.rank,
+        sendbuf.data.ptr,
+        sendbuf.data.ptr,
+        sendbuf.nbytes,
+        sendbuf.nbytes,
+        dtype_to_mscclpp_dtype(dtype),
+        nthreads_per_block,
+        execution_plan,
+        stream.ptr,
+        packet_type,
+    )
+    # check correctness
+    stream = cp.cuda.Stream(non_blocking=True)
+    executor_func(stream)
+    stream.synchronize()
+    assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks)
+
+    mscclpp_group.barrier()
+    execution_time = bench_time(100, 10, executor_func)
+    print(
+        f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
+        f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
+        f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
     )
-    print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
     executor = None
     mscclpp_group = None
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
+    parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
+    parser.add_argument("--size", type=str, required=True)
+    parser.add_argument("--nthreads_per_block", type=int, required=True)
+    parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
+    parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
+    parser.add_argument("--seed", type=int, default=42)
+    args = parser.parse_args()
+
+    packet_type = PacketType.LL16
+    if args.packet_type == "LL8":
+        packet_type = PacketType.LL8
+
+    buffer_size = parse_size(args.size)
+    dtype = parse_dtype(args.dtype)
+    main(
+        args.execution_plan_name,
+        args.execution_plan_path,
+        buffer_size,
+        args.nthreads_per_block,
+        dtype,
+        packet_type,
+        args.seed,
+    )
diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc
index 4c896f860..0d427fbe9 100644
--- a/src/executor/execution_plan.cc
+++ b/src/executor/execution_plan.cc
@@ -290,11 +290,17 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
 }
 
 size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment) const {
-  assert(inputSize % alignment == 0 && "inputSize must be a multiple of alignment");
+  if (inputSize % alignment != 0) {
+    throw Error("inputSize must be a multiple of alignment", ErrorCode::ExecutorError);
+  }
 
   const int nGroups = this->chunkGroups.at(rank);
   uint32_t nInputChunks = this->inputChunks.at(rank);
   uint32_t nelems = inputSize / (alignment * sizeof(uint8_t));
+  if (nelems % nGroups != 0) {
+    throw Error("Input size must be a multiple of nGroups", ErrorCode::ExecutorError);
+  }
+
   int nelemsPerGroup = nelems / nGroups;
   int nChunksPerGroup = nInputChunks / nGroups;
   uint32_t minNelems = nelemsPerGroup / nChunksPerGroup;
diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp
index 406e69eb8..9b2b77f4d 100644
--- a/src/include/execution_kernel.hpp
+++ b/src/include/execution_kernel.hpp
@@ -234,13 +234,13 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
 }
 
 template <typename PacketType>
-MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, size_t scratchSize,
-                                           DeviceHandle<SmChannel>* smChannels, uint8_t* dstChannelIndexes,
-                                           uint32_t* dstOffsets, int nDstChannels, uint32_t size, uint32_t flag) {
+MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
+                                           uint8_t* dstChannelIndexes, uint32_t* dstOffsets, uint32_t* srcOffsets,
+                                           int nDstChannels, uint32_t size, uint32_t flag) {
   const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
   for (int index = 0; index < nDstChannels; ++index) {
-    smChannels[dstChannelIndexes[index]].putPackets(
-        scratchBaseOffset + dstOffsets[index] * 2, inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag);
+    smChannels[dstChannelIndexes[index]].putPackets(scratchBaseOffset + dstOffsets[index] * 2,
+                                                    srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
   }
 }
 
@@ -376,7 +376,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
                                op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
                                op.size, false);
     } else if (op.type == OperationType::PUT_PACKET) {
-      handlePutPacket(op.srcOffset, scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets,
-                      op.nOutputs, op.size, flag);
+      handlePutPacket(scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
+                      op.nOutputs, op.size, flag);
     } else if (op.type == OperationType::REDUCE_SEND_PACKET) {
       T* dst = getBuffer(input, output, scratch, op.dstBufferType);
diff --git a/src/nvls.cc b/src/nvls.cc
index c4a7c7ec8..7e3f2a41d 100644
--- a/src/nvls.cc
+++ b/src/nvls.cc
@@ -236,6 +236,8 @@ class NvlsConnection::Impl {
 };
 #endif  // !(USE_NVLS)
 
+const int NvlsConnection::DefaultNvlsBufferSize = (1 << 29);
+
 NvlsConnection::NvlsConnection(size_t bufferSize, int numDevices)
     : pimpl_(std::make_shared<Impl>(bufferSize, numDevices)) {}
 
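[Editor's note, not part of the patch] The nvls.hpp and nvls.cc hunks above pair a declaration-only DefaultNvlsBufferSize in the header with a single definition in one translation unit. This is the usual pre-C++17 split for a static const int member that gets odr-used (for example, bound to a const reference): the initializer moves out of the class so exactly one .cc file provides the definition. A minimal, hypothetical sketch of the same pattern (illustrative names only, not mscclpp code):

    // header: declaration without an initializer
    struct Config {
      static const int kDefaultBufferSize;
    };
    // one .cc file: the single out-of-class definition
    const int Config::kDefaultBufferSize = (1 << 29);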
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index b5043b4bf..ec4fbff4e 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -24,6 +24,7 @@ endfunction()
 add_test_executable(allgather_test_cpp allgather_test_cpp.cu)
 add_test_executable(allgather_test_host_offloading allgather_test_host_offloading.cu)
 add_test_executable(nvls_test nvls_test.cu)
+add_test_executable(executor_test executor_test.cc)
 
 configure_file(run_mpi_test.sh.in run_mpi_test.sh)
 
diff --git a/test/execution-files/allreduce_packet.json b/test/execution-files/allreduce_packet.json
index 5bcdf1cdd..b0df82c91 100644
--- a/test/execution-files/allreduce_packet.json
+++ b/test/execution-files/allreduce_packet.json
@@ -14,24 +14,6 @@
         {
           "id": 0,
           "ops": [
-            {
-              "name": "ppkt",
-              "o_buff": {
-                "src": "i",
-                "dst": "s"
-              },
-              "o_cids": [
-                {
-                  "id": 0,
-                  "off": 0
-                }
-              ],
-              "src": 0,
-              "srcbuff": "i",
-              "srcoff": 2,
-              "ctype": "sm",
-              "cnt": 1
-            },
             {
               "name": "rspkt",
               "o_buff": {
@@ -58,17 +40,6 @@
               "dstoff": 0,
               "ctype": "sm",
               "cnt": 1
-            },
-            {
-              "name": "cpkt",
-              "src": 0,
-              "srcbuff": "s",
-              "srcoff": 6,
-              "dst": 0,
-              "dstbuff": "i",
-              "dstoff": 2,
-              "ctype": "none",
-              "cnt": 1
             }
           ],
           "channels": [
@@ -94,14 +65,17 @@
               "o_cids": [
                 {
                   "id": 0,
-                  "off": 1
+                  "off": 0
+                }
+              ],
+              "srcs": [
+                {
+                  "buff": "i",
+                  "off": 2
                 }
               ],
-              "src": 0,
-              "srcbuff": "i",
-              "srcoff": 3,
               "ctype": "sm",
-              "cnt": 1
+              "cnt": 2
             },
             {
              "name": "rspkt",
@@ -134,12 +108,12 @@
              "name": "cpkt",
              "src": 0,
              "srcbuff": "s",
-              "srcoff": 7,
+              "srcoff": 6,
              "dst": 0,
              "dstbuff": "i",
-              "dstoff": 3,
+              "dstoff": 2,
              "ctype": "none",
-              "cnt": 1
+              "cnt": 2
            }
          ],
          "channels": [
@@ -188,11 +162,14 @@
                  "off": 2
                }
              ],
-              "src": 1,
-              "srcbuff": "i",
-              "srcoff": 0,
+              "srcs": [
+                {
+                  "buff": "i",
+                  "off": 0
+                }
+              ],
              "ctype": "sm",
-              "cnt": 1
+              "cnt": 2
            },
            {
              "name": "rspkt",
@@ -230,7 +207,7 @@
              "dstbuff": "i",
              "dstoff": 0,
              "ctype": "none",
-              "cnt": 1
+              "cnt": 2
            }
          ],
          "channels": [
@@ -247,24 +224,6 @@
        {
          "id": 1,
          "ops": [
-            {
-              "name": "ppkt",
-              "o_buff": {
-                "src": "i",
-                "dst": "s"
-              },
-              "o_cids": [
-                {
-                  "id": 0,
-                  "off": 3
-                }
-              ],
-              "src": 1,
-              "srcbuff": "i",
-              "srcoff": 1,
-              "ctype": "sm",
-              "cnt": 1
-            },
            {
              "name": "rspkt",
              "o_buff": {
@@ -291,17 +250,6 @@
              "dstoff": 3,
              "ctype": "sm",
              "cnt": 1
-            },
-            {
-              "name": "cpkt",
-              "src": 1,
-              "srcbuff": "s",
-              "srcoff": 5,
-              "dst": 1,
-              "dstbuff": "i",
-              "dstoff": 1,
-              "ctype": "none",
-              "cnt": 1
            }
          ],
          "channels": [
diff --git a/test/executor_test.cc b/test/executor_test.cc
new file mode 100644
index 000000000..a30691dde
--- /dev/null
+++ b/test/executor_test.cc
@@ -0,0 +1,114 @@
+#include <cuda_runtime.h>
+#include <mpi.h>
+
+#include <mscclpp/executor.hpp>
+#include <mscclpp/gpu_utils.hpp>
+#include <mscclpp/utils.hpp>
+#include <sstream>
+
+double parseSize(const char* value) {
+  std::string valueStr(value);
+  std::istringstream iss(valueStr);
+  long long int units;
+  double size;
+  char size_lit = 0;
+
+  if (iss >> size) {
+    iss >> std::ws;  // eat whitespace
+    iss >> size_lit;
+  } else {
+    return -1.0;
+  }
+
+  if (size_lit != 0 && !std::isspace(size_lit)) {
+    switch (size_lit) {
+      case 'G':
+      case 'g':
+        units = 1024 * 1024 * 1024;
+        break;
+      case 'M':
+      case 'm':
+        units = 1024 * 1024;
+        break;
+      case 'K':
+      case 'k':
+        units = 1024;
+        break;
+      default:
+        return -1.0;
+    };
+  } else {
+    units = 1;
+  }
+  return size * units;
+}
+
+double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::shared_ptr<mscclpp::Executor> executor,
+                 const mscclpp::ExecutionPlan& plan, std::shared_ptr<char> sendbuff, size_t bufferSize,
+                 int nthreadsPerBlock, int niters, int ngraphIters) {
+  mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);
+  cudaGraph_t graph;
+  cudaGraphExec_t graphExec;
+  mscclpp::Timer timer;
+  MSCCLPP_CUDATHROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
+  for (int i = 0; i < niters; i++) {
+    executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16,
+                      nthreadsPerBlock, plan, stream, mscclpp::PacketType::LL16);
+  }
+  MSCCLPP_CUDATHROW(cudaStreamEndCapture(stream, &graph));
+  MSCCLPP_CUDATHROW(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
+  MSCCLPP_CUDATHROW(cudaGraphLaunch(graphExec, stream));
+  MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
+  bootstrap->barrier();
+
+  timer.reset();
+  for (int i = 0; i < ngraphIters; i++) {
+    MSCCLPP_CUDATHROW(cudaGraphLaunch(graphExec, stream));
+  }
+  MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
+  double deltaSec = timer.elapsed() * 1.e-6;
+  deltaSec = deltaSec / (niters) / (ngraphIters);
+  MSCCLPP_CUDATHROW(cudaGraphExecDestroy(graphExec));
+  MSCCLPP_CUDATHROW(cudaGraphDestroy(graph));
+  return deltaSec;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc != 5) {
+    std::cerr << "Usage: " << argv[0] << " <buffer size>"
+              << " <execution plan name>"
+              << " <execution plan path>"
+              << " <nthreads per block>" << std::endl;
+    return 1;
+  }
+
+  int rank;
+  int worldSize;
+  MPI_Init(NULL, NULL);
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
+  MSCCLPP_CUDATHROW(cudaSetDevice(rank));
+
+  const size_t bufferSize = parseSize(argv[1]);
+  const std::string executionPlanName = argv[2];
+  const std::string executionPlanPath = argv[3];
+  const int nthreadsPerBlock = std::stoi(argv[4]);
+
+  std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
+  mscclpp::UniqueId id;
+  bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
+  if (rank == 0) id = bootstrap->createUniqueId();
+  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
+  bootstrap->initialize(id);
+  std::shared_ptr<mscclpp::Communicator> communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
+  std::shared_ptr<mscclpp::Executor> executor = std::make_shared<mscclpp::Executor>(communicator);
+
+  mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath);
+  std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
+  std::vector<int> dataHost(bufferSize / sizeof(int), rank);
+  MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
+  double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, 200, 20);
+  std::cout << "Rank " << rank << ": " << bufferSize << " bytes " << deltaSec * 1.e6 << " us" << std::endl;
+  MPI_Finalize();
+  return 0;
+}
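[Editor's note, not part of the patch] benchTime above captures niters executor->execute() calls into one CUDA graph, replays the graph ngraphIters times, and divides the elapsed time by niters * ngraphIters. A self-contained sketch of that capture-then-replay timing pattern, with cudaMemsetAsync standing in for the executor call and error checking omitted (bare CUDA runtime calls instead of the mscclpp::CudaStreamWithFlags, mscclpp::Timer, and MSCCLPP_CUDATHROW helpers used by the test):

    #include <cuda_runtime.h>

    #include <chrono>
    #include <cstdio>

    int main() {
      const int niters = 200, ngraphIters = 20;
      void* buf = nullptr;
      cudaMalloc(&buf, 1 << 20);
      cudaStream_t stream;
      cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

      // Record niters stream operations into a single graph.
      cudaGraph_t graph;
      cudaGraphExec_t graphExec;
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      for (int i = 0; i < niters; i++) {
        cudaMemsetAsync(buf, 0, 1 << 20, stream);  // stand-in for executor->execute(...)
      }
      cudaStreamEndCapture(stream, &graph);
      cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0);

      // Warm-up replay, then timed replays.
      cudaGraphLaunch(graphExec, stream);
      cudaStreamSynchronize(stream);
      auto start = std::chrono::steady_clock::now();
      for (int i = 0; i < ngraphIters; i++) {
        cudaGraphLaunch(graphExec, stream);
      }
      cudaStreamSynchronize(stream);
      auto end = std::chrono::steady_clock::now();

      // Per-operation time: total elapsed divided by niters captured ops times ngraphIters replays.
      double usPerIter =
          std::chrono::duration<double, std::micro>(end - start).count() / niters / ngraphIters;
      std::printf("%f us per captured operation\n", usPerIter);

      cudaGraphExecDestroy(graphExec);
      cudaGraphDestroy(graph);
      cudaStreamDestroy(stream);
      cudaFree(buf);
      return 0;
    }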