Implement MMIO with put and signal #55

Merged · 12 commits · Apr 19, 2023
2 changes: 1 addition & 1 deletion Makefile
@@ -153,7 +153,7 @@ TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))

MSCLLPPTESTSOBJSDIR:= $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)
-MSCLLPPTESTBINFILESLIST := allgather_test
+MSCLLPPTESTBINFILESLIST := allgather_test ring_send_recv_test
MSCLLPPTESTBINS := $(MSCLLPPTESTBINFILESLIST:%=$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf)

INCLUDE := -Isrc -Isrc/include
42 changes: 40 additions & 2 deletions src/include/mscclpp.h
@@ -135,6 +135,35 @@ struct mscclppDevConn
;
}

// Version that uses the SM directly to do the copy, instead of using the proxy thread like the functions above.
__forceinline__ __device__ void putDirect(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize,
uint32_t threadId, uint32_t numThreads)
{
// Note: the offsets are in units of uint64_t elements, while dataSize is in bytes.
uint64_t* src = (uint64_t*)localBuff + srcDataOffset;
uint64_t* dst = (uint64_t*)remoteBuff + dstDataOffset;
// assume the memory is aligned to 8 bytes; round the byte count up to whole 8-byte words
size_t nElem =
dataSize % sizeof(uint64_t) ? (dataSize + sizeof(uint64_t)) / sizeof(uint64_t) : dataSize / sizeof(uint64_t);
for (size_t i = threadId; i < nElem; i += numThreads) {
dst[i] = src[i];
}
}

__forceinline__ __device__ void putDirect(uint64_t dataOffset, uint64_t dataSize, uint32_t threadId,
uint32_t numThreads)
{
putDirect(dataOffset, dataOffset, dataSize, threadId, numThreads);
}

__forceinline__ __device__ void signalDirect()
{
// This fence ensures that the writes from a preceding putDirect() are visible on the peer GPU before the
// incremented epoch id is visible.
__threadfence_system();
epochIncrement();
*(volatile uint64_t*)remoteEpochId = *sendEpochId;
}

__forceinline__ __device__ void wait()
{
(*recvEpochId) += 1;
@@ -143,6 +172,13 @@ struct mscclppDevConn
;
}

__forceinline__ __device__ void waitDirect()
{
(*recvEpochId) += 1;
while (*(volatile uint64_t*)directRecvEpochId < (*recvEpochId))
;
}

__forceinline__ __device__ void epochIncrement()
{
*(volatile uint64_t*)sendEpochId += 1;
@@ -153,11 +189,13 @@
int tag;

void* localBuff;
-uint64_t* sendEpochId; // this is read and written by the GPU
-uint64_t* recvEpochId; // this is the copy of the remote epoch id.
+uint64_t* sendEpochId;       // this is read and written by the GPU
+uint64_t* recvEpochId;       // this is the expected recv epoch id.
+uint64_t* directRecvEpochId; // this is read and written by the remote GPU.

void* remoteBuff;
uint64_t* remoteFlag;
uint64_t* remoteEpochId;
uint64_t* proxyEpochId; // this is only written by the proxy thread

// this is a concurrent fifo which is multiple threads from the device
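Taken together, putDirect/signalDirect/waitDirect implement a simple epoch-based handshake between one sending and one receiving thread block. A minimal usage sketch (mirroring the test kernel added below; devConn here stands for any connected mscclppDevConn):

// Sender block:
devConn.putDirect(0 /*offset*/, size, threadIdx.x, blockDim.x); // cooperative copy
__syncthreads();          // wait until every thread has finished its share
if (threadIdx.x == 0)
  devConn.signalDirect(); // fence, bump the epoch, publish it to the peer

// Receiver block:
if (threadIdx.x == 0)
  devConn.waitDirect();   // spin until the peer's epoch arrives
__syncthreads();          // make the received data visible to all threads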
8 changes: 7 additions & 1 deletion src/init.cc
@@ -218,6 +218,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm)
if (conn) {
MSCCLPPCHECK(mscclppCudaFree(conn->devConn->sendEpochId));
MSCCLPPCHECK(mscclppCudaFree(conn->devConn->recvEpochId));
MSCCLPPCHECK(mscclppCudaFree(conn->devConn->directRecvEpochId));
}
}

@@ -421,6 +422,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void
conn->devConn->localBuff = localBuff;
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->sendEpochId, 1));
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->recvEpochId, 1));
MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->directRecvEpochId, 1));

conn->devConn->remoteRank = remoteRank;
conn->devConn->tag = tag;
conn->devConn->fifo.connId = comm->nConns;
@@ -433,7 +436,6 @@
conn->devConn->fifo.triggerFifoTail = proxyState->fifoTailDev;

comm->nConns++;

// change the numa binding back to user's
MSCCLPPCHECK(setNumaState(curProcessState));

@@ -445,6 +447,7 @@ struct connInfo
cudaIpcMemHandle_t handleBuff;
cudaIpcMemHandle_t handleFlag;
cudaIpcMemHandle_t handleProxyFlag;
cudaIpcMemHandle_t handleRemoteEpochId;
mscclppIbQpInfo infoQp;
mscclppIbMrInfo infoBuffMr;
mscclppIbMrInfo infoLocalFlagMr;
@@ -462,6 +465,7 @@ mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*output*/
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleProxyFlag, devConn->proxyEpochId));
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleBuff, devConn->localBuff));
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleFlag, devConn->sendEpochId));
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleRemoteEpochId, devConn->directRecvEpochId));
return mscclppSuccess;
}

@@ -475,6 +479,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/
cudaIpcOpenMemHandle((void**)&conn->devConn->remoteBuff, connInfo->handleBuff, cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(
cudaIpcOpenMemHandle((void**)&conn->devConn->remoteFlag, connInfo->handleFlag, cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(cudaIpcOpenMemHandle((void**)&conn->devConn->remoteEpochId, connInfo->handleRemoteEpochId,
cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(
cudaIpcOpenMemHandle((void**)&conn->remoteProxyFlag, connInfo->handleProxyFlag, cudaIpcMemLazyEnablePeerAccess));
return mscclppSuccess;
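Note the crossover in the exchange above: each rank exports an IPC handle for its own directRecvEpochId, and the peer opens that handle as its remoteEpochId. The resulting pointer topology (names from mscclppDevConn; a minimal illustration):

// GPU A                                  GPU B
// devConn.remoteEpochId ────IPC──────►   devConn.directRecvEpochId
// devConn.directRecvEpochId ◄────IPC──   devConn.remoteEpochId
//
// signalDirect() on A writes A's sendEpochId through remoteEpochId into B's
// directRecvEpochId, which B's waitDirect() polls.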
134 changes: 134 additions & 0 deletions tests/ring_send_recv_test.cu
@@ -0,0 +1,134 @@
#include "comm.h"
#include "common.h"

#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unistd.h>

#define BLOCK_THREADS_NUM 128

#define ALIGN 4

__global__ void initKernel(int* dataDst, int dataCount)
{
for (size_t i = threadIdx.x; i < dataCount; i += blockDim.x) {
dataDst[i] = i % 256;
}
}

__constant__ mscclppDevConn_t sendConnConst;
__constant__ mscclppDevConn_t recvConnConst;

__global__ void kernel(bool root, size_t dataSize)
{
mscclppDevConn_t sendConn = sendConnConst;
mscclppDevConn_t recvConn = recvConnConst;
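
// Ring traversal: the root injects its buffer into the ring; every other rank
// waits for the data from its left neighbor, then forwards it to its right
// neighbor. The root's final waitDirect() completes once the data has gone all
// the way around.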

if (root) {
sendConn.putDirect(0, dataSize, threadIdx.x, blockDim.x);
// make sure all the threads have put their data
__syncthreads();
if (threadIdx.x == 0) {
sendConn.signalDirect();
recvConn.waitDirect();
}
} else {
if (threadIdx.x == 0) {
recvConn.waitDirect();
}
// make sure we get the latest data
__syncthreads();
sendConn.putDirect(0, dataSize, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == 0) {
sendConn.signalDirect();
}
}
}

testResult_t resetData(int* dataDst, size_t dataCount, bool isRoot)
{
if (isRoot) {
initKernel<<<1, BLOCK_THREADS_NUM>>>(dataDst, dataCount);
} else {
CUDACHECK(cudaMemset(dataDst, 0, dataCount * sizeof(int)));
}
return testSuccess;
}

void RingSendRecvGetCollByteCount(size_t* sendcount, size_t* recvcount, size_t* paramcount, size_t* sendInplaceOffset,
size_t* recvInplaceOffset, size_t count, int nranks)
{
size_t base = (count / ALIGN) * ALIGN;
*sendcount = base;
*recvcount = base;
*sendInplaceOffset = base;
*recvInplaceOffset = 0;
*paramcount = base;
}

testResult_t RingSendRecvInitData(struct testArgs* args, int in_place)
{
size_t recvcount = args->expectedBytes / sizeof(int);

CUDACHECK(cudaSetDevice(args->gpuNum));
int rank = args->proc;
CUDACHECK(cudaMemset(args->recvbuff, 0, args->expectedBytes));
TESTCHECK(resetData((int*)args->recvbuff, recvcount, rank == 0));

int* dataHost = new int[recvcount];
for (size_t i = 0; i < recvcount; i++) {
dataHost[i] = i % 256;
}
CUDACHECK(cudaMemcpy(args->expected, dataHost, recvcount * sizeof(int), cudaMemcpyHostToDevice));
delete[] dataHost;
CUDACHECK(cudaDeviceSynchronize());
MSCCLPPCHECK(mscclppBootstrapBarrier(args->comm));
return testSuccess;
}

void RingSendRecvGetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks)
{
double baseBw = (double)(count * typesize * nranks) / 1.0E9 / sec;

*algBw = baseBw;
double factor = ((double)(nranks - 1)) / ((double)nranks);
*busBw = baseBw * factor;
}
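
// Sanity check with hypothetical numbers: if each rank moves count * typesize
// = 1e9 bytes and the timed loop takes sec = 0.25, then with nranks = 4:
//   algBw = (1e9 * 4) / 1e9 / 0.25 = 16 GB/s
//   busBw = 16 * (4 - 1) / 4      = 12 GB/s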

testResult_t RingSendRecvRunColl(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
cudaStream_t stream, int kernel_num)
{
kernel<<<1, BLOCK_THREADS_NUM, 0, stream>>>(comm->rank == 0, count);
return testSuccess;
}

struct testColl ringSendRecvTest = {"RingSendRecvTest", RingSendRecvGetCollByteCount, RingSendRecvInitData,
RingSendRecvGetBw, RingSendRecvRunColl};

void RingSendRecvGetBuffSize(size_t* sendcount, size_t* recvcount, size_t count, int nranks)
{
size_t paramcount, sendInplaceOffset, recvInplaceOffset;
RingSendRecvGetCollByteCount(sendcount, recvcount, &paramcount, &sendInplaceOffset, &recvInplaceOffset, count,
nranks);
}

testResult_t RingSendRecvRunTest(struct testArgs* args)
{
args->collTest = &ringSendRecvTest;
int rank = args->proc, worldSize = args->totalProcs;

mscclppDevConn_t* sendDevConn;
mscclppDevConn_t* recvDevConn;
MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank + 1) % worldSize, 0, &sendDevConn));
MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank - 1 + worldSize) % worldSize, 0, &recvDevConn));
CUDACHECK(cudaMemcpyToSymbol(sendConnConst, sendDevConn, sizeof(mscclppDevConn_t)));
CUDACHECK(cudaMemcpyToSymbol(recvConnConst, recvDevConn, sizeof(mscclppDevConn_t)));
TESTCHECK(TimeTest(args));
return testSuccess;
}

struct testEngine ringSendRecvTestEngine = {RingSendRecvGetBuffSize, RingSendRecvRunTest};

#pragma weak mscclppTestEngine = ringSendRecvTestEngine