
Add C++ executor test #304

Merged · 25 commits · May 29, 2024
2 changes: 1 addition & 1 deletion include/mscclpp/nvls.hpp
@@ -17,7 +17,7 @@ class NvlsConnection {
std::vector<char> serialize();

// the recommended buffer size for NVLS, returned by cuMulticastGetGranularity
static const int DefaultNvlsBufferSize = (1 << 29);
static const int DefaultNvlsBufferSize;

// Everyone needs to synchronize after creating a NVLS connection before adding devices
void addDevice();
133 changes: 107 additions & 26 deletions python/test/executor_test.py
@@ -1,19 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from os import path
import argparse
from mscclpp import (
DataType,
Executor,
ExecutionPlan,
PacketType,
)
import mscclpp.comm as mscclpp_comm

import cupy as cp
from mpi4py import MPI

MSCCLPP_ROOT_PATH = "/root/mscclpp"


def bench_time(niters: int, ngraphIters: int, func):
# capture cuda graph for niters of the kernel launch
@@ -40,36 +39,118 @@ def bench_time(niters: int, ngraphIters: int, func):
return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters


if __name__ == "__main__":
def parse_size(size_str):
"""Convert a human-readable buffer size string to an integer."""
size_str = size_str.strip()
if not size_str:
raise ValueError("Size string can not be empty")
units = {"K": 1024, "M": 1024**2, "G": 1024**3}
if size_str[-1].upper() in units:
return int(size_str[:-1]) * units[size_str[-1].upper()]
else:
return int(size_str)

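For reference, a minimal sketch of what parse_size is expected to return for a few sample inputs (values are illustrative, not part of the change):

# Illustrative only: expected conversions performed by parse_size above.
assert parse_size("1K") == 1024          # kilobytes to bytes
assert parse_size("2M") == 2 * 1024**2   # megabytes to bytes
assert parse_size("1G") == 1024**3       # gigabytes to bytes
assert parse_size("4096") == 4096        # plain byte counts pass through unchanged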

def parse_dtype(dtype_str):
"""Convert a human-readable data type string to a numpy data type."""
dtype_str = dtype_str.strip().lower()
if dtype_str == "float16":
return cp.float16
elif dtype_str == "float32":
return cp.float32
elif dtype_str == "int32":
return cp.int32
else:
raise ValueError(f"Unknown data type: {dtype_str}")


def dtype_to_mscclpp_dtype(dtype):
if dtype == cp.float16:
return DataType.float16
elif dtype == cp.float32:
return DataType.float32
elif dtype == cp.int32:
return DataType.int32
else:
raise ValueError(f"Unknown data type: {dtype}")

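A quick sketch of how the two helpers above chain together when main runs (illustrative values):

# Illustrative only: a CLI string becomes a cupy dtype, then an mscclpp DataType.
dtype = parse_dtype("float32")                 # cp.float32
mscclpp_dtype = dtype_to_mscclpp_dtype(dtype)  # DataType.float32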

def main(
execution_plan_name: str,
execution_plan_path: str,
size: int,
nthreads_per_block: int,
dtype: cp.dtype = cp.float16,
packet_type: PacketType = PacketType.LL16,
seed: int = 42,
):
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
executor = Executor(mscclpp_group.communicator)
execution_plan = ExecutionPlan(
"allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)

nelems = 1024 * 1024
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
expected = cp.zeros_like(sendbuf)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
mscclpp_group.barrier()

execution_time = bench_time(
100,
10,
lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
DataType.float16,
512,
execution_plan,
stream.ptr,
),
executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
dtype_to_mscclpp_dtype(dtype),
nthreads_per_block,
execution_plan,
stream.ptr,
packet_type,
)
# check correctness
stream = cp.cuda.Stream(non_blocking=True)
executor_func(stream)
stream.synchronize()
assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks)

mscclpp_group.barrier()
execution_time = bench_time(100, 10, executor_func)
print(
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
)
print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
executor = None
mscclpp_group = None


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
parser.add_argument("--size", type=str, required=True)
parser.add_argument("--nthreads_per_block", type=int, required=True)
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

packet_type = PacketType.LL16
if args.packet_type == "LL8":
packet_type = PacketType.LL8

buffer_size = parse_size(args.size)
dtype = parse_dtype(args.dtype)
main(
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.nthreads_per_block,
dtype,
packet_type,
args.seed,
)
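For context, the rewritten script might be launched roughly as follows; the launcher, rank count, and paths are illustrative, with the plan name and file mirroring the values previously hard-coded above:

mpirun -np 8 python3 executor_test.py -n allreduce_pairs -path /root/mscclpp/test/execution-files/allreduce.json --size 1M --nthreads_per_block 512 --dtype float16 --packet_type LL16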
2 changes: 2 additions & 0 deletions src/executor/execution_plan.cc
@@ -295,6 +295,8 @@ size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunk
const int nGroups = this->chunkGroups.at(rank);
uint32_t nInputChunks = this->inputChunks.at(rank);
uint32_t nelems = inputSize / (alignment * sizeof(uint8_t));

assert(nelems % nGroups == 0 && "inputSize must be a multiple of nGroups");
int nelemsPerGroup = nelems / nGroups;
int nChunksPerGroup = nInputChunks / nGroups;
uint32_t minNelems = nelemsPerGroup / nChunksPerGroup;
12 changes: 6 additions & 6 deletions src/include/execution_kernel.hpp
@@ -234,13 +234,13 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
}

template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, size_t scratchSize,
DeviceHandle<SmChannel>* smChannels, uint8_t* dstChannelIndexes,
uint32_t* dstOffsets, int nDstChannels, uint32_t size, uint32_t flag) {
MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
uint8_t* dstChannelIndexes, uint32_t* dstOffsets, uint32_t* srcOffsets,
int nDstChannels, uint32_t size, uint32_t flag) {
const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
scratchBaseOffset + dstOffsets[index] * 2, inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag);
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(scratchBaseOffset + dstOffsets[index] * 2,
srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
}
}

@@ -376,7 +376,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
op.size, false);
} else if (op.type == OperationType::PUT_PACKET) {
handlePutPacket<PacketType>(op.srcOffset, scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets,
handlePutPacket<PacketType>(scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::REDUCE_SEND_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
2 changes: 2 additions & 0 deletions src/nvls.cc
@@ -236,6 +236,8 @@ class NvlsConnection::Impl {
};
#endif // !(USE_NVLS)

const int NvlsConnection::DefaultNvlsBufferSize = (1 << 29);

NvlsConnection::NvlsConnection(size_t bufferSize, int numDevices)
: pimpl_(std::make_shared<Impl>(bufferSize, numDevices)) {}

1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -24,6 +24,7 @@ endfunction()
add_test_executable(allgather_test_cpp allgather_test_cpp.cu)
add_test_executable(allgather_test_host_offloading allgather_test_host_offloading.cu)
add_test_executable(nvls_test nvls_test.cu)
add_test_executable(executor_test executor_test.cc)

configure_file(run_mpi_test.sh.in run_mpi_test.sh)

126 changes: 126 additions & 0 deletions test/executor_test.cc
@@ -0,0 +1,126 @@
#include <mpi.h>
#include <unistd.h>

#include <iostream>
#include <mscclpp/executor.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>

namespace {
std::string getExecutablePath() {
char result[PATH_MAX];
ssize_t count = readlink("/proc/self/exe", result, PATH_MAX);
if (count == -1) {
throw std::runtime_error("Failed to get executable path");
}
return std::string(result, count);
}
} // namespace

double parseSize(const char* value) {
std::string valueStr(value);
std::istringstream iss(valueStr);
long long int units;
double size;
char size_lit = 0;

if (iss >> size) {
iss >> std::ws; // eat whitespace
iss >> size_lit;
} else {
return -1.0;
}

if (size_lit != 0 && !std::isspace(size_lit)) {
switch (size_lit) {
case 'G':
case 'g':
units = 1024 * 1024 * 1024;
break;
case 'M':
case 'm':
units = 1024 * 1024;
break;
case 'K':
case 'k':
units = 1024;
break;
default:
return -1.0;
};
} else {
units = 1;
}
return size * units;
}

double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::shared_ptr<mscclpp::Executor> executor,
const mscclpp::ExecutionPlan& plan, std::shared_ptr<char> sendbuff, size_t bufferSize,
int nthreadsPerBlock, int niters, int ngraphIters) {
mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);
cudaGraph_t graph;
cudaGraphExec_t graphExec;
mscclpp::Timer timer;
MSCCLPP_CUDATHROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
for (int i = 0; i < niters; i++) {
executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16,
nthreadsPerBlock, plan, stream, mscclpp::PacketType::LL16);
}
MSCCLPP_CUDATHROW(cudaStreamEndCapture(stream, &graph));
MSCCLPP_CUDATHROW(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
MSCCLPP_CUDATHROW(cudaGraphLaunch(graphExec, stream));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
bootstrap->barrier();

timer.reset();
for (int i = 0; i < ngraphIters; i++) {
MSCCLPP_CUDATHROW(cudaGraphLaunch(graphExec, stream));
}
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
double deltaSec = timer.elapsed() * 1.e-6;
deltaSec = deltaSec / (niters) / (ngraphIters);
MSCCLPP_CUDATHROW(cudaGraphExecDestroy(graphExec));
MSCCLPP_CUDATHROW(cudaGraphDestroy(graph));
return deltaSec;
}

int main(int argc, char* argv[]) {
if (argc != 5) {
std::cerr << "Usage: " << argv[0] << " <buffer size>"
<< " <execution plan name>"
<< " <execution plan path>"
<< " <nthreads per block>" << std::endl;
return 1;
}

int rank;
int worldSize;
MPI_Init(NULL, NULL);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
MSCCLPP_CUDATHROW(cudaSetDevice(rank));

const size_t bufferSize = parseSize(argv[1]);
const std::string executionPlanName = argv[2];
const std::string executionPlanPath = argv[3];
const int nthreadsPerBlock = std::stoi(argv[4]);

std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
mscclpp::UniqueId id;
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
if (rank == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
std::shared_ptr<mscclpp::Communicator> communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
std::shared_ptr<mscclpp::Executor> executor = std::make_shared<mscclpp::Executor>(communicator);

std::string executablePath = getExecutablePath();
mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath);
std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
std::vector<int> dataHost(bufferSize / sizeof(int), rank);
MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, 200, 20);
std::cout << "Rank " << rank << ": " << bufferSize << " bytes " << deltaSec * 1.e6 << " us" << std::endl;
MPI_Finalize();
return 0;
}
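Similarly, once built, the new C++ test could be invoked along these lines, matching the usage string printed above (launcher, rank count, and paths are illustrative):

mpirun -np 8 ./executor_test 1M allreduce_pairs /root/mscclpp/test/execution-files/allreduce.json 512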