diff --git a/CMakeLists.txt b/CMakeLists.txt index d946d2d5c..0a04be9ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,8 +17,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra") list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) -# clang-format targets -include(${PROJECT_SOURCE_DIR}/cmake/AddClangFormatTargets.cmake) +# Format targets +include(${PROJECT_SOURCE_DIR}/cmake/AddFormatTargets.cmake) # Options option(ENABLE_TRACE "Enable tracing" OFF) diff --git a/README.md b/README.md index e5f3bf459..56a2fcf1e 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ mscclpp::Communicator comm(bootstrap); // Setup connections here using `comm` ... // Construct the default proxy -mscclpp::ProxyService proxyService(comm); +mscclpp::ProxyService proxyService; // Start the proxy proxyService.startProxy(); // Run the user application, i.e., launch GPU kernels here @@ -80,7 +80,7 @@ While the default implementation already enables any kinds of communication, MSC ```cpp // Proxy FIFO is obtained from mscclpp::Proxy on the host and copied to the device. -__device__ mscclpp::DeviceProxyFifo fifo; +__device__ mscclpp::FifoDeviceHandle fifo; __global__ void gpuKernel() { ... // Only one thread is needed for the followings diff --git a/cmake/AddClangFormatTargets.cmake b/cmake/AddClangFormatTargets.cmake deleted file mode 100644 index 07304ce3c..000000000 --- a/cmake/AddClangFormatTargets.cmake +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -# Add targets to run clang-format - -find_program(CLANG_FORMAT clang-format) -if(CLANG_FORMAT) - message(STATUS "Found clang-format: ${CLANG_FORMAT}") - set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test) - add_custom_target(check-format ALL - COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu` - ) - add_custom_target(format - COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu` - ) -else() - message(STATUS "clang-format not found.") -endif() diff --git a/cmake/AddFormatTargets.cmake b/cmake/AddFormatTargets.cmake new file mode 100644 index 000000000..71c3ef4ab --- /dev/null +++ b/cmake/AddFormatTargets.cmake @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license.
+ +# Add targets to run clang-format and black + +add_custom_target(check-format) +add_custom_target(format) + +find_program(CLANG_FORMAT clang-format) +if(CLANG_FORMAT) + message(STATUS "Found clang-format: ${CLANG_FORMAT}") + set(FIND_DIRS ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test) + add_custom_target(check-format-cpp ALL + COMMAND ${CLANG_FORMAT} -style=file --dry-run `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu` + ) + add_dependencies(check-format check-format-cpp) + add_custom_target(format-cpp + COMMAND ${CLANG_FORMAT} -style=file -i `find ${FIND_DIRS} -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu` + ) + add_dependencies(format format-cpp) +else() + message(STATUS "clang-format not found.") +endif() + +find_program(BLACK black) +if (BLACK) + message(STATUS "Found black: ${BLACK}") + add_custom_target(check-format-py + COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml --check ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test + ) + add_dependencies(check-format check-format-py) + add_custom_target(format-py + COMMAND ${BLACK} --config ${PROJECT_SOURCE_DIR}/pyproject.toml ${PROJECT_SOURCE_DIR}/python ${PROJECT_SOURCE_DIR}/test + ) + add_dependencies(format format-py) +else() + message(STATUS "black not found.") +endif() diff --git a/include/mscclpp/config.hpp b/include/mscclpp/config.hpp deleted file mode 100644 index 10fce659c..000000000 --- a/include/mscclpp/config.hpp +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef MSCCLPP_CONFIG_H_ -#define MSCCLPP_CONFIG_H_ - -namespace mscclpp { - -class Config { - public: - int bootstrapConnectionTimeout = 30; - - static Config* getInstance(); - int getBootstrapConnectionTimeoutConfig(); - void setBootstrapConnectionTimeoutConfig(int timeout); - - private: - Config() = default; - Config(const Config&) = delete; - Config& operator=(const Config&) = delete; - - static Config instance_; -}; - -} // namespace mscclpp - -#endif // end include guard diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 2b1b39a19..d3af3103a 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -61,11 +61,13 @@ class TcpBootstrap : public Bootstrap { /// Initialize the @ref TcpBootstrap with a given unique ID. /// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with. - void initialize(UniqueId uniqueId); + /// @param timeoutSec The connection timeout in seconds. + void initialize(UniqueId uniqueId, int64_t timeoutSec = 30); /// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port". /// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port". - void initialize(const std::string& ifIpPortTrio); + /// @param timeoutSec The connection timeout in seconds. + void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30); /// Return the rank of the process. int getRank() override; @@ -384,7 +386,7 @@ class Connection { virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0; /// Flush any pending writes to the remote process. - virtual void flush() = 0; + virtual void flush(int64_t timeoutUsec = 3e7) = 0; /// Get the rank of the remote process. /// @@ -533,8 +535,14 @@ class Communicator { /// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it. /// @param transport The type of transport to be used. + /// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB. + /// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not + /// IB. + /// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB. + /// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB. /// @return std::shared_ptr A shared pointer to the connection. - std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport); + std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024, + int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64); /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. /// diff --git a/include/mscclpp/fifo.hpp b/include/mscclpp/fifo.hpp index 2bed5702b..be31b80a0 100644 --- a/include/mscclpp/fifo.hpp +++ b/include/mscclpp/fifo.hpp @@ -7,90 +7,25 @@ #include #include #include +#include #include -#define MSCCLPP_PROXY_FIFO_SIZE 128 - namespace mscclpp { -/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy. -/// -/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push -/// ProxyTrigger elements and a single host proxy thread consumes these work elements. -/// -struct alignas(16) ProxyTrigger { - uint64_t fst, snd; -}; - -/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them. -/// -/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost -/// infinity. There are two copies of the tail, one on the device, @ref DeviceProxyFifo::tailReplica, and another on the -/// host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the device. -/// Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <= head. -/// The @ref push() function increments head, hostTail is updated in @ref HostProxyFifo::pop(), and it occasionally -/// flushes it to tailReplica via @ref HostProxyFifo::flushTail(). -/// -/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the -/// tail as there is usually enough space for device threads to push their work into. -/// -struct DeviceProxyFifo { -#ifdef __CUDACC__ - /// Push a trigger to the FIFO. - /// - /// @param trigger The trigger to push. - /// @return The new head of the FIFO. - __forceinline__ __device__ uint64_t push(ProxyTrigger trigger) { - uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1); - - // Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to - // write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but - // for the second condition we need to read CPU memory. - // As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the - // condition is not met. 
- if (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *(this->tailReplica)) { - OR_POLL_MAYBE_JAILBREAK(curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->tailReplica), - *(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0, - 1000000); - } - - ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]); - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); - return curFifoHead; - } - - /// Wait until there is a place in the FIFO to push a trigger. - /// - /// @param curFifoHead The current head of the FIFO. - __forceinline__ __device__ void sync(uint64_t curFifoHead) { - // Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need - // to wait for cudaMemcpy to be done. - OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]) != 0, - *(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000); - } -#endif // __CUDACC__ - - /// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`. - ProxyTrigger* triggers; - /// Replica of the FIFO tail that is allocated on device. - uint64_t* tailReplica; - /// The FIFO head. Allocated on the device and only accessed by the device. - uint64_t* head; -}; - /// A class representing a host proxy FIFO that can consume work elements pushed by device threads. -class HostProxyFifo { +class Fifo { public: - /// Constructs a new @ref HostProxyFifo object. - HostProxyFifo(); + /// Constructs a new @ref Fifo object. + /// @param size The number of entires in the FIFO. + Fifo(int size = 128); - /// Destroys the @ref HostProxyFifo object. - ~HostProxyFifo(); + /// Destroys the @ref Fifo object. + ~Fifo(); /// Polls the FIFO for a trigger. /// - /// @param trigger A pointer to the trigger to be filled. - void poll(ProxyTrigger* trigger); + /// Returns @ref ProxyTrigger which is the trigger at the head of fifo. + ProxyTrigger poll(); /// Pops a trigger from the FIFO. void pop(); @@ -100,10 +35,14 @@ class HostProxyFifo { /// @param sync If true, waits for the flush to complete before returning. void flushTail(bool sync = false); - /// Returns a @ref DeviceProxyFifo object representing the device FIFO. + /// Return the FIFO size. + /// @return The FIFO size. + int size() const; + + /// Returns a @ref FifoDeviceHandle object representing the device FIFO. /// - /// @return A @ref DeviceProxyFifo object representing the device FIFO. - DeviceProxyFifo deviceFifo(); + /// @return A @ref FifoDeviceHandle object representing the device FIFO. + FifoDeviceHandle deviceHandle(); private: struct Impl; diff --git a/include/mscclpp/fifo_device.hpp b/include/mscclpp/fifo_device.hpp new file mode 100644 index 000000000..f993258fa --- /dev/null +++ b/include/mscclpp/fifo_device.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCLPP_FIFO_DEVICE_HPP_ +#define MSCCLPP_FIFO_DEVICE_HPP_ + +#include "poll.hpp" + +namespace mscclpp { + +/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy. +/// +/// This struct is used as a work element in the concurrent FIFO where multiple device threads can push +/// ProxyTrigger elements and a single host proxy thread consumes these work elements. 
+/// +/// Do not use the most significant bit of @ref snd as it is reserved for memory consistency purposes +struct alignas(16) ProxyTrigger { + uint64_t fst, snd; +}; + +/// A concurrent FIFO where multiple device threads can push work elements and a single host proxy thread consumes them. +/// +/// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost +/// infinity. There are two copies of the tail, one on the device, @ref FifoDeviceHandle::tailReplica, and another on +/// the host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the +/// device. Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <= +/// head. The @ref push() function increments head, hostTail is updated in @ref Fifo::pop(), and it occasionally flushes +/// it to tailReplica via @ref Fifo::flushTail(). +/// +/// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the +/// tail as there is usually enough space for device threads to push their work into. +/// +struct FifoDeviceHandle { +#ifdef __CUDACC__ + /// Push a trigger to the FIFO. + /// + /// @param trigger The trigger to push. + /// @return The new head of the FIFO. + __forceinline__ __device__ uint64_t push(ProxyTrigger trigger) { + uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->head, 1); + // make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host + // side + trigger.snd ^= ((uint64_t)1 << (uint64_t)63); + + // Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to + // write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but + // for the second condition we need to read CPU memory. + // As volatile access is slow, we first check using the bare pointer and then use the volatile pointer if the + // condition is not met. + if (curFifoHead >= size + *(this->tailReplica)) { + OR_POLL_MAYBE_JAILBREAK(curFifoHead >= size + *((volatile uint64_t*)this->tailReplica), + *(volatile uint64_t*)&this->triggers[curFifoHead % size] != 0, 1000000); + } + + ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % size]); + asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); + return curFifoHead; + } + + /// Wait until there is a place in the FIFO to push a trigger. + /// + /// @param curFifoHead The current head of the FIFO. + __forceinline__ __device__ void sync(uint64_t curFifoHead) { + // Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need + // to wait for cudaMemcpy to be done. + OR_POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)&(this->triggers[curFifoHead % size]) != 0, + *(volatile uint64_t*)(this->tailReplica) <= curFifoHead, 1000000); + } +#endif // __CUDACC__ + + /// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`. + ProxyTrigger* triggers; + /// Replica of the FIFO tail that is allocated on device. + uint64_t* tailReplica; + /// The FIFO head. Allocated on the device and only accessed by the device. + uint64_t* head; + /// The FIFO size. 
+ int size; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_FIFO_DEVICE_HPP_ diff --git a/src/include/numa.hpp b/include/mscclpp/numa.hpp similarity index 100% rename from src/include/numa.hpp rename to include/mscclpp/numa.hpp diff --git a/include/mscclpp/packet.hpp b/include/mscclpp/packet.hpp index f9e126bdf..1742a202c 100644 --- a/include/mscclpp/packet.hpp +++ b/include/mscclpp/packet.hpp @@ -4,6 +4,8 @@ #ifndef MSCCLPP_PACKET_HPP_ #define MSCCLPP_PACKET_HPP_ +#include "poll.hpp" + namespace mscclpp { /// LL (low latency) protocol packet. @@ -42,17 +44,24 @@ union LLPacket { "r"((uint32_t)(val >> 32)), "r"(flag)); } + /// Helper of @ref read(). + /// @param flag The flag to read. + /// @param data The 8-byte data read. + /// @return True if the flag is not equal to the given flag. + __forceinline__ __device__ bool readOnce(uint32_t flag, uint2& data) { + uint32_t flag1, flag2; + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2) + : "l"(v)); + return (flag1 != flag) || (flag2 != flag); + } + /// Read 8 bytes of data from the packet. /// @param flag The flag to read. /// @return The 8-byte data read. __forceinline__ __device__ uint2 read(uint32_t flag) { uint2 data; - uint32_t flag1, flag2; - do { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" - : "=r"(data.x), "=r"(flag1), "=r"(data.y), "=r"(flag2) - : "l"(v)); - } while ((flag1 != flag) || (flag2 != flag)); + POLL_MAYBE_JAILBREAK(readOnce(flag, data), 100000000); return data; } @@ -80,6 +89,7 @@ __forceinline__ __device__ void putPackets(void* dst, uint64_t dstOffset, void* __forceinline__ __device__ void getPackets(void* dst, uint64_t dstOffset, void* src, uint64_t srcOffset, uint64_t dstBytes, uint32_t threadId, uint32_t numThreads, uint32_t flag) { // Offsets should be aligned to 8 bytes & size should be a multiple of 8 bytes + // TODO(saemal): this is not matching sm_channel get method. LLPacket* srcBase = (LLPacket*)((char*)src + srcOffset); uint2* dstBase = (uint2*)((char*)dst + dstOffset); size_t nElem = dstBytes / sizeof(uint2); diff --git a/include/mscclpp/poll.hpp b/include/mscclpp/poll.hpp index c1ed909f2..cb32a9743 100644 --- a/include/mscclpp/poll.hpp +++ b/include/mscclpp/poll.hpp @@ -6,14 +6,8 @@ #ifdef __CUDACC__ -#ifndef NDEBUG -// TODO(chhwang): https://github.com/microsoft/mscclpp/issues/99 -#define POLL_PRINT_ON_STUCK(__cond) -// #include -// #define POLL_PRINT_ON_STUCK(__cond) do { printf("mscclpp: spin is stuck. condition: " #__cond "\n"); } while (0); -#else // NDEBUG -#define POLL_PRINT_ON_STUCK(__cond) -#endif // NDEBUG +extern __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line, + const char *__function) __THROW; // If a spin is stuck, escape from it and set status to 1. #define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \ @@ -22,7 +16,6 @@ __status = 0; \ while (__cond) { \ if (__spin_cnt++ == __max_spin_cnt) { \ - POLL_PRINT_ON_STUCK(__cond); \ __status = 1; \ break; \ } \ @@ -30,31 +23,31 @@ } while (0); // If a spin is stuck, print a warning and keep spinning. 
-#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \ - do { \ - uint64_t __spin_cnt = 0; \ - while (__cond) { \ - if (__spin_cnt++ == __max_spin_cnt) { \ - POLL_PRINT_ON_STUCK(__cond); \ - } \ - } \ +#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \ + do { \ + uint64_t __spin_cnt = 0; \ + while (__cond) { \ + if (__spin_cnt++ == __max_spin_cnt) { \ + __assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } \ } while (0); // the as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2 // this is specially useful when __cond1 is faster to check -#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \ - do { \ - uint64_t __spin_cnt = 0; \ - while (true) { \ - if (!(__cond1)) { \ - break; \ - } else if (!(__cond2)) { \ - break; \ - } \ - if (__spin_cnt++ == __max_spin_cnt) { \ - POLL_PRINT_ON_STUCK(__cond); \ - } \ - } \ +#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \ + do { \ + uint64_t __spin_cnt = 0; \ + while (true) { \ + if (!(__cond1)) { \ + break; \ + } else if (!(__cond2)) { \ + break; \ + } \ + if (__spin_cnt++ == __max_spin_cnt) { \ + __assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } \ } while (0); #endif // __CUDACC__ diff --git a/include/mscclpp/proxy.hpp b/include/mscclpp/proxy.hpp index a69ace8a4..359c9bca2 100644 --- a/include/mscclpp/proxy.hpp +++ b/include/mscclpp/proxy.hpp @@ -28,7 +28,7 @@ class Proxy { void start(); void stop(); - HostProxyFifo& fifo(); + Fifo& fifo(); private: struct Impl; diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 99737d4ce..2c6446480 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -5,18 +5,12 @@ #define MSCCLPP_PROXY_CHANNEL_HPP_ #include -#include #include +#include #include namespace mscclpp { -using SemaphoreId = uint32_t; - -/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the -/// actual. -using MemoryId = uint32_t; - struct ProxyChannel; /// Base class for proxy services. Proxy services are used to proxy data between devices. @@ -32,13 +26,17 @@ class BaseProxyService { class ProxyService : public BaseProxyService { public: /// Constructor. - /// @param communicator The communicator to use. - ProxyService(Communicator& communicator); + ProxyService(); - /// Add a semaphore to the proxy service. + /// Build and add a semaphore to the proxy service. /// @param connection The connection associated with the semaphore. /// @return The ID of the semaphore. - SemaphoreId addSemaphore(std::shared_ptr connection); + SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr connection); + + /// Add a semaphore to the proxy service. + /// @param semaphore The semaphore to be added + /// @return The ID of the semaphore. + SemaphoreId addSemaphore(std::shared_ptr semaphore); /// Register a memory region with the proxy service. /// @param memory The memory region to register. @@ -53,7 +51,7 @@ class ProxyService : public BaseProxyService { /// Get a proxy channel by semaphore ID. /// @param id The ID of the semaphore. /// @return The proxy channel. - ProxyChannel deviceChannel(SemaphoreId id); + ProxyChannel proxyChannel(SemaphoreId id); /// Start the proxy service. 
void startProxy(); @@ -62,7 +60,6 @@ class ProxyService : public BaseProxyService { void stopProxy(); private: - Communicator& communicator_; std::vector> semaphores_; std::vector memories_; Proxy proxy_; @@ -73,170 +70,44 @@ class ProxyService : public BaseProxyService { ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw); }; -using TriggerType = uint64_t; -const TriggerType TriggerData = 0x1; // Trigger a data transfer. -const TriggerType TriggerFlag = 0x2; // Trigger a signaling. -const TriggerType TriggerSync = 0x4; // Trigger a flush. - -#define MSCCLPP_BITS_SIZE 32 -#define MSCCLPP_BITS_OFFSET 32 -#define MSCCLPP_BITS_REGMEM_HANDLE 8 -#define MSCCLPP_BITS_TYPE 3 -#define MSCCLPP_BITS_CONNID 10 - -/// Basic structure of each work element in the FIFO. -union ChannelTrigger { - ProxyTrigger value; - // The summation of number of bits must be 128 or less. - struct { - // First 64 bits: value[0] - uint64_t size : MSCCLPP_BITS_SIZE; - uint64_t srcOffset : MSCCLPP_BITS_OFFSET; - uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment - // Second 64 bits: value[1] - uint64_t dstOffset : MSCCLPP_BITS_OFFSET; - uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; - uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; - uint64_t type : MSCCLPP_BITS_TYPE; - uint64_t chanId : MSCCLPP_BITS_CONNID; - uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment - } fields; - -#ifdef __CUDACC__ - /// Default constructor. - __device__ ChannelTrigger() {} - - /// Copy constructor. - __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} - - /// Constructor. - /// @param type The type of the trigger. - /// @param dst The destination memory region. - /// @param dstOffset The offset into the destination memory region. - /// @param src The source memory region. - /// @param srcOffset The offset into the source memory region. - /// @param bytes The bytes of the transfer. - /// @param semaphoreId The ID of the semaphore. - __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, - uint64_t bytes, int semaphoreId) { - value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes); - value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) - << MSCCLPP_BITS_REGMEM_HANDLE) + - src) - << MSCCLPP_BITS_OFFSET) + - dstOffset); - } -#endif // __CUDACC__ -}; - /// Proxy channel. struct ProxyChannel { - // Use DeviceHandle in device code. - typedef ProxyChannel DeviceHandle; + private: + SemaphoreId semaphoreId_; + Host2DeviceSemaphore::DeviceHandle semaphore_; + + // this is a concurrent fifo which is multiple threads from the device + // can produce for and the sole proxy thread consumes it. + FifoDeviceHandle fifo_; + + public: ProxyChannel() = default; - ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, DeviceProxyFifo fifo); + ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, FifoDeviceHandle fifo); ProxyChannel(const ProxyChannel& other) = default; ProxyChannel& operator=(ProxyChannel& other) = default; -#ifdef __CUDACC__ - /// Push a @ref TriggerData to the FIFO. - /// @param dst The destination memory region. - /// @param dstOffset The offset into the destination memory region. - /// @param src The source memory region. - /// @param srcOffset The offset into the source memory region. 
- /// @param size The size of the transfer. - __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, - uint64_t size) { - fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); - } - - /// Push a @ref TriggerData to the FIFO. - /// @param dst The destination memory region. - /// @param src The source memory region. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { - put(dst, offset, src, offset, size); - } - - /// Push a @ref TriggerFlag to the FIFO. - __forceinline__ __device__ void signal() { - fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); - } - - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. - /// @param dst The destination memory region. - /// @param dstOffset The offset into the destination memory region. - /// @param src The source memory region. - /// @param srcOffset The offset into the source memory region. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, - uint64_t size) { - fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); - } - - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. - /// @param dst The destination memory region. - /// @param src The source memory region. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { - putWithSignal(dst, offset, src, offset, size); - } - - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. - /// @param dst The destination memory region. - /// @param dstOffset The offset into the destination memory region. - /// @param src The source memory region. - /// @param srcOffset The offset into the source memory region. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, - uint64_t srcOffset, uint64_t size) { - uint64_t curFifoHead = fifo_.push( - ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_) - .value); - fifo_.sync(curFifoHead); - } - - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. - /// @param dst The destination memory region. - /// @param src The source memory region. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { - putWithSignalAndFlush(dst, offset, src, offset, size); - } - - /// Push a @ref TriggerSync to the FIFO. - __forceinline__ __device__ void flush() { - uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value); - fifo_.sync(curFifoHead); - } - - /// Wait for the proxy channel to be signaled. 
- __forceinline__ __device__ void wait() { semaphore_.wait(); } - -#endif // __CUDACC__ - - SemaphoreId semaphoreId_; + /// Device-side handle for @ref ProxyChannel. + using DeviceHandle = ProxyChannelDeviceHandle; - Host2DeviceSemaphore::DeviceHandle semaphore_; - - // this is a concurrent fifo which is multiple threads from the device - // can produce for and the sole proxy thread consumes it. - DeviceProxyFifo fifo_; + /// Returns the device-side handle. + /// + /// User should make sure the ProxyChannel is not released when using the returned handle. + /// + DeviceHandle deviceHandle() const; }; /// Simple proxy channel with a single destination and source memory region. struct SimpleProxyChannel { - // Use DeviceHandle in device code. - typedef SimpleProxyChannel DeviceHandle; + private: + ProxyChannel proxyChan_; + MemoryId dst_; + MemoryId src_; + public: /// Default constructor. SimpleProxyChannel() = default; @@ -256,69 +127,16 @@ struct SimpleProxyChannel { /// Assignment operator. SimpleProxyChannel& operator=(SimpleProxyChannel& other) = default; -#ifdef __CUDACC__ - /// Push a @ref TriggerData to the FIFO. - /// @param dstOffset The offset into the destination memory region. - /// @param srcOffset The offset into the source memory region. - /// @param size The size of the transfer. - __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - proxyChan_.put(dst_, dstOffset, src_, srcOffset, size); - } - - /// Push a @ref TriggerData to the FIFO. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); } - - /// Push a @ref TriggerFlag to the FIFO. - __forceinline__ __device__ void signal() { proxyChan_.signal(); } - - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. - /// @param dstOffset The offset into the destination memory region. - /// @param srcOffset The offset into the source memory region. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); - } - - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } - - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. - /// @param dstOffset The offset into the destination memory region. - /// @param srcOffset The offset into the source memory region. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size); - } - - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. - /// @param offset The common offset into the destination and source memory regions. - /// @param size The size of the transfer. - __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) { - putWithSignalAndFlush(offset, offset, size); - } - - /// Push a @ref TriggerSync to the FIFO. 
- __forceinline__ __device__ void flush() { proxyChan_.flush(); } - - /// Wait for the proxy channel to be signaled. - __forceinline__ __device__ void wait() { proxyChan_.wait(); } - -#endif // __CUDACC__ + /// Device-side handle for @ref SimpleProxyChannel. + using DeviceHandle = SimpleProxyChannelDeviceHandle; - ProxyChannel proxyChan_; - MemoryId dst_; - MemoryId src_; + /// Returns the device-side handle. + /// + /// User should make sure the SimpleProxyChannel is not released when using the returned handle. + /// + DeviceHandle deviceHandle() const; }; -template <> -DeviceHandle deviceHandle(ProxyChannel&& proxyChannel); - -template <> -DeviceHandle deviceHandle(SimpleProxyChannel&& simpleProxyChannel); } // namespace mscclpp #endif // MSCCLPP_PROXY_CHANNEL_HPP_ diff --git a/include/mscclpp/proxy_channel_device.hpp b/include/mscclpp/proxy_channel_device.hpp new file mode 100644 index 000000000..db90eac72 --- /dev/null +++ b/include/mscclpp/proxy_channel_device.hpp @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_ +#define MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_ + +#include "fifo_device.hpp" +#include "semaphore_device.hpp" + +namespace mscclpp { + +using SemaphoreId = uint32_t; + +/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the +/// actual. +using MemoryId = uint32_t; + +using TriggerType = uint64_t; +const TriggerType TriggerData = 0x1; // Trigger a data transfer. +const TriggerType TriggerFlag = 0x2; // Trigger a signaling. +const TriggerType TriggerSync = 0x4; // Trigger a flush. + +#define MSCCLPP_BITS_SIZE 32 +#define MSCCLPP_BITS_OFFSET 32 +#define MSCCLPP_BITS_REGMEM_HANDLE 8 +#define MSCCLPP_BITS_TYPE 3 +#define MSCCLPP_BITS_CONNID 10 +#define MSCCLPP_BITS_FIFO_RESERVED 1 + +/// Basic structure of each work element in the FIFO. +union ChannelTrigger { + ProxyTrigger value; + // The summation of number of bits must be 128 or less. + struct { + // First 64 bits: value[0] + uint64_t size : MSCCLPP_BITS_SIZE; + uint64_t srcOffset : MSCCLPP_BITS_OFFSET; + uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + // Second 64 bits: value[1] + uint64_t dstOffset : MSCCLPP_BITS_OFFSET; + uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t chanId : MSCCLPP_BITS_CONNID; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE - + MSCCLPP_BITS_CONNID - MSCCLPP_BITS_FIFO_RESERVED); // ensure 64-bit alignment + uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED; + } fields; + +#ifdef __CUDACC__ + /// Default constructor. + __forceinline__ __device__ ChannelTrigger() {} + + /// Copy constructor. + __forceinline__ __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} + + /// Constructor. + /// @param type The type of the trigger. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param bytes The bytes of the transfer. + /// @param semaphoreId The ID of the semaphore. 
+ __forceinline__ __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, + uint64_t srcOffset, uint64_t bytes, int semaphoreId) { + value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes); + value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) + << MSCCLPP_BITS_REGMEM_HANDLE) + + src) + << MSCCLPP_BITS_OFFSET) + + dstOffset); + } +#endif // __CUDACC__ +}; + +struct ProxyChannelDeviceHandle { + SemaphoreId semaphoreId_; + + Host2DeviceSemaphoreDeviceHandle semaphore_; + + // this is a concurrent fifo which is multiple threads from the device + // can produce for and the sole proxy thread consumes it. + FifoDeviceHandle fifo_; + +#ifdef __CUDACC__ + /// Push a @ref TriggerData to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. + __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size) { + fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); + } + + /// Push a @ref TriggerData to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param size The size of the transfer. + __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + put(dst, offset, src, offset, size); + } + + /// Push a @ref TriggerFlag to the FIFO. + __forceinline__ __device__ void signal() { + fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); + } + + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t size) { + fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); + } + + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + putWithSignal(dst, offset, src, offset, size); + } + + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. 
+ __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, + uint64_t srcOffset, uint64_t size) { + uint64_t curFifoHead = fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, dstOffset, src, srcOffset, size, semaphoreId_) + .value); + fifo_.sync(curFifoHead); + } + + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { + putWithSignalAndFlush(dst, offset, src, offset, size); + } + + /// Push a @ref TriggerSync to the FIFO. + __forceinline__ __device__ void flush() { + uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value); + fifo_.sync(curFifoHead); + } + + /// Wait for the proxy channel to be signaled. + __forceinline__ __device__ void wait() { semaphore_.wait(); } + +#endif // __CUDACC__ +}; + +struct SimpleProxyChannelDeviceHandle { + ProxyChannelDeviceHandle proxyChan_; + MemoryId dst_; + MemoryId src_; + +#ifdef __CUDACC__ + /// Push a @ref TriggerData to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. + __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + proxyChan_.put(dst_, dstOffset, src_, srcOffset, size); + } + + /// Push a @ref TriggerData to the FIFO. + /// @param offset The common offset into the destination and source memory regions. + /// @param size The size of the transfer. + __forceinline__ __device__ void put(uint64_t offset, uint64_t size) { put(offset, offset, size); } + + /// Push a @ref TriggerFlag to the FIFO. + __forceinline__ __device__ void signal() { proxyChan_.signal(); } + + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); + } + + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param offset The common offset into the destination and source memory regions. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } + + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { + proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size); + } + + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// @param offset The common offset into the destination and source memory regions. 
+ /// @param size The size of the transfer. + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) { + putWithSignalAndFlush(offset, offset, size); + } + + /// Push a @ref TriggerSync to the FIFO. + __forceinline__ __device__ void flush() { proxyChan_.flush(); } + + /// Wait for the proxy channel to be signaled. + __forceinline__ __device__ void wait() { proxyChan_.wait(); } +#endif // __CUDACC__ +}; + +} // namespace mscclpp + +#endif // MSCCLPP_PROXY_CHANNEL_DEVICE_HPP_ diff --git a/include/mscclpp/semaphore.hpp b/include/mscclpp/semaphore.hpp index a96619a29..9f73082ed 100644 --- a/include/mscclpp/semaphore.hpp +++ b/include/mscclpp/semaphore.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace mscclpp { @@ -81,18 +82,7 @@ class Host2DeviceSemaphore : public BaseSemaphore SmDevice2DeviceSemaphore() = default; /// Device-side handle for @ref SmDevice2DeviceSemaphore. - struct DeviceHandle { -#ifdef __CUDACC__ - /// Wait for the remote device to signal. - __forceinline__ __device__ void wait() { - (*expectedInboundSemaphoreId) += 1; - POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 1000000); - } - - /// Signal the remote device. - /// - /// This function guarantees that all the memory operation before this function is completed before the remote - /// semaphore is signaled. - /// - __forceinline__ __device__ void signal() { - // This fence ensures that preceding writes are visible on the peer GPU before the incremented - // `outboundSemaphoreId` is visible. - __threadfence_system(); - semaphoreIncrement(); - *remoteInboundSemaphoreId = semaphoreGetLocal(); - } - - /// Signal the remote device for copied packets. - /// - /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is - /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the - /// completion of copies. - /// - __forceinline__ __device__ void signalPacket() { - semaphoreIncrement(); - *remoteInboundSemaphoreId = semaphoreGetLocal(); - } - - /// Increase the counter of the local semaphore. - __forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; } - - /// Get the value of the local semaphore. - __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; } -#endif // __CUDACC__ - - volatile uint64_t* inboundSemaphoreId; - uint64_t* outboundSemaphoreId; - volatile uint64_t* remoteInboundSemaphoreId; - uint64_t* expectedInboundSemaphoreId; - }; + using DeviceHandle = SmDevice2DeviceSemaphoreDeviceHandle; /// Returns the device-side handle. DeviceHandle deviceHandle() const; diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp new file mode 100644 index 000000000..292a8b495 --- /dev/null +++ b/include/mscclpp/semaphore_device.hpp @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCLPP_SEMAPHORE_DEVICE_HPP_ +#define MSCCLPP_SEMAPHORE_DEVICE_HPP_ + +#include "poll.hpp" + +namespace mscclpp { + +/// Device-side handle for @ref Host2DeviceSemaphore. +struct Host2DeviceSemaphoreDeviceHandle { +#ifdef __CUDACC__ + /// Wait for the host to signal. 
+ __forceinline__ __device__ void wait() { + (*expectedInboundSemaphoreId) += 1; + POLL_MAYBE_JAILBREAK(*(volatile uint64_t*)(inboundSemaphoreId) < (*expectedInboundSemaphoreId), 100000000); + } +#endif // __CUDACC__ + + uint64_t* inboundSemaphoreId; + uint64_t* expectedInboundSemaphoreId; +}; + +/// Device-side handle for @ref SmDevice2DeviceSemaphore. +struct SmDevice2DeviceSemaphoreDeviceHandle { +#ifdef __CUDACC__ + /// Wait for the remote device to signal. + __forceinline__ __device__ void wait() { + (*expectedInboundSemaphoreId) += 1; + POLL_MAYBE_JAILBREAK(*inboundSemaphoreId < (*expectedInboundSemaphoreId), 100000000); + } + + /// Signal the remote device. + /// + /// This function guarantees that all the memory operation before this function is completed before the remote + /// semaphore is signaled. + /// + __forceinline__ __device__ void signal() { + // This fence ensures that preceding writes are visible on the peer GPU before the incremented + // `outboundSemaphoreId` is visible. + __threadfence_system(); + semaphoreIncrement(); + *remoteInboundSemaphoreId = semaphoreGetLocal(); + } + + /// Signal the remote device for copied packets. + /// + /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is + /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the + /// completion of copies. + /// + __forceinline__ __device__ void signalPacket() { + semaphoreIncrement(); + *remoteInboundSemaphoreId = semaphoreGetLocal(); + } + + /// Increase the counter of the local semaphore. + __forceinline__ __device__ void semaphoreIncrement() { *outboundSemaphoreId += 1; } + + /// Get the value of the local semaphore. + __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; } +#endif // __CUDACC__ + + volatile uint64_t* inboundSemaphoreId; + uint64_t* outboundSemaphoreId; + volatile uint64_t* remoteInboundSemaphoreId; + uint64_t* expectedInboundSemaphoreId; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_SEMAPHORE_DEVICE_HPP_ diff --git a/include/mscclpp/sm_channel.hpp b/include/mscclpp/sm_channel.hpp index 2033a872a..947eea21d 100644 --- a/include/mscclpp/sm_channel.hpp +++ b/include/mscclpp/sm_channel.hpp @@ -5,8 +5,8 @@ #define MSCCLPP_SM_CHANNEL_HPP_ #include -#include #include +#include #include namespace mscclpp { @@ -31,305 +31,8 @@ struct SmChannel { SmChannel(std::shared_ptr semaphore, RegisteredMemory dst, void* src, void* getPacketBuffer = nullptr); - struct DeviceHandle { - SmDevice2DeviceSemaphore::DeviceHandle semaphore_; - void* src_; - void* dst_; - void* getPacketBuffer_; - - private: -#ifdef __CUDACC__ - /// Helper for aligned data type access. - /// @tparam T The data type. - template - struct Element { - static constexpr bool is4B = (sizeof(T) == 4); - static constexpr bool is8B = (sizeof(T) == 8); - static constexpr bool is4Bx2 = - (std::is_same::value || std::is_same::value || std::is_same::value); - static constexpr bool is4Bx4 = - (std::is_same::value || std::is_same::value || std::is_same::value); - static constexpr bool is8Bx2 = - (std::is_same::value || std::is_same::value || std::is_same::value); - // Note: we do not support long2 and ulong2 as their size may differ on different platforms. - static constexpr bool isValid = (is4B || is8B || is4Bx2 || is4Bx4 || is8Bx2); - - /// Load an element from DRAM. - /// - /// This is a warpper of ld.volatile.global.* PTX instruction. 
Address alignment is not this function's - /// responsibility. - /// - /// @param v The value to be loaded. - /// @param p The address of the value to be loaded. - /// - static __forceinline__ __device__ void load(T& v, const T* p) { - if constexpr (is4B) { - asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory"); - } else if constexpr (is8B) { - asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory"); - } else if constexpr (is4Bx2) { - asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(v.x), "=r"(v.y) : "l"(p) : "memory"); - } else if constexpr (is4Bx4) { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" - : "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z) - : "l"(p) - : "memory"); - } else if constexpr (is8Bx2) { - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); - } - static_assert(isValid, "Unsupported type T"); - } - - /// Write an element on DRAM. - /// - /// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's - /// responsibility. - /// - /// @param p The address of the value to be written. - /// @param v The value to be written. - /// - static __forceinline__ __device__ void store(T* p, const T& v) { - if constexpr (is4B) { - asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory"); - } else if constexpr (is8B) { - asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory"); - } else if constexpr (is4Bx2) { - asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" : : "l"(p), "r"(v.x), "r"(v.y) : "memory"); - } else if constexpr (is4Bx4) { - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" - : - : "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z) - : "memory"); - } else if constexpr (is8Bx2) { - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory"); - } - static_assert(isValid, "Unsupported type T"); - } - - /// Copy aligned elements from the source memory to the destination memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of - /// elements. - /// - /// @param dst The destination address. - /// @param src The source address. - /// @param numElems The number of elements to be copied. - /// @param threadId The index of the current thread among all threads running this function. This is different - /// from the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - static __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, - uint32_t numThreads) { - T reg; - for (size_t i = threadId; i < numElems; i += numThreads) { - // Load to register first. - load(reg, src + i); - store(dst + i, reg); - } - } - }; -#endif // __CUDACC__ - public: -#ifdef __CUDACC__ - /// Load a value from the remote memory. - /// @tparam T The type of the value to be loaded. - /// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T). - /// @return The value loaded. - template - __forceinline__ __device__ T read(uint64_t index) { - T v; - Element::load(v, (T*)dst_ + index); - return v; - } - - /// Write a value to the remote memory. - /// @tparam T The type of the value to be written. - /// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T). - /// @param v The value to be written. 
- template - __forceinline__ __device__ void write(uint64_t index, const T& v) { - Element::store((T*)dst_ + index, v); - } - - /// Copy aligned data from the source memory to the destination memory. - /// - /// This function is a warpper of Element::copy(). Unlike Element::copy(), this function can copy remainder - /// bytes when @p CopyRemainder is true. Still, the copying bytes must be a multiple of 4. - /// - /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. - /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p - /// Alignment. - /// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src. - /// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst. - /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - template - __forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) { - static_assert(Alignment == 4 || Alignment == 8 || Alignment % 16 == 0, "Unsupported alignment"); - using Type = - typename std::conditional::type>::type; - int* dstInt = reinterpret_cast(dst); - int* srcInt = reinterpret_cast(src); - const uintptr_t dstPtr = reinterpret_cast(dst); - const uintptr_t srcPtr = reinterpret_cast(src); - const uint64_t numInt = bytes / sizeof(int); - Type* dstElem = reinterpret_cast((dstPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type)); - Type* srcElem = reinterpret_cast((srcPtr + sizeof(Type) - 1) / sizeof(Type) * sizeof(Type)); - uint64_t nFirstInt = (reinterpret_cast(dstElem) - dstPtr) / sizeof(int); - if (CopyRemainder) { - // Copy the remainder integers at the beginning. - Element::copy(dstInt, srcInt, nFirstInt, threadId, numThreads); - } - // Copy elements. - constexpr uint64_t nIntPerElem = sizeof(Type) / sizeof(int); - uint64_t nElem = (numInt - nFirstInt) / nIntPerElem; - Element::copy(dstElem, srcElem, nElem, threadId, numThreads); - if (CopyRemainder && nIntPerElem > 1) { - // Copy the remainder integers at the end. - uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem; - Element::copy(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt, - threadId, numThreads); - } - } - - /// Copy data from the local memory to the remote memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. - /// - /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. - /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p - /// Alignment. - /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment. - /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment. - /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. 
- /// - template - __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, - uint32_t numThreads) { - copy((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads); - } - - /// Copy data from the remote memory to the local memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. - /// - /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. - /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p - /// Alignment. - /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment. - /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment. - /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - template - __forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, - uint32_t numThreads) { - // Note that `dst` and `src` are swapped for `get()`. - copy((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads); - } - - /// Copy data from the local memory to the remote memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. - /// - /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. - /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p - /// Alignment. - /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment. - /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - template - __forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) { - put(offset, offset, size, threadId, numThreads); - } - - /// Copy data from the remote memory to the local memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. - /// - /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. - /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p - /// Alignment. - /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment. - /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. 
- /// - template - __forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) { - get(offset, offset, size, threadId, numThreads); - } - - /// Construct @ref LLPacket from the data in the local memory and write it on the remote memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets. - /// - /// @param dstOffset The offset in bytes of the remote address. - /// @param srcOffset The offset in bytes of the local address. - /// @param bytes Bytes of the data to be copied. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - __forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, - uint32_t threadId, uint32_t numThreads, uint32_t flag) { - mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag); - } - - /// Retrieve data from @ref LLPacket in the local packet buffer and write it on the local memory. - /// - /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. - /// - /// @param dstOffset The offset in bytes of the local memory. - /// @param srcOffset The offset in bytes of the local packet buffer. - /// @param bytes Bytes of the data to be copied. - /// @param threadId The index of the current thread among all threads running this function. This is different from - /// the `threadIdx` in CUDA. - /// @param numThreads The total number of threads that run this function. - /// - __forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, - uint32_t threadId, uint32_t numThreads, uint32_t flag) { - mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag); - } - - /// Signal the remote semaphore. - /// - /// This function guarantees that all the memory operation before this function is completed before the remote - /// semaphore is signaled. - /// - __forceinline__ __device__ void signal() { semaphore_.signal(); } - - /// Signal the remote semaphore for copied packets. - /// - /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is - /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the - /// completion of copies. - /// - __forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); } - - /// Increase the counter of the local semaphore. - __forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); } - - /// Read the counter of the local semaphore. - __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); } - - /// Wait for the remote semaphore to send a signal. - __forceinline__ __device__ void wait() { semaphore_.wait(); } -#endif // __CUDACC__ - }; + /// Device-side handle for @ref SmChannel. + using DeviceHandle = SmChannelDeviceHandle; /// Returns the device-side handle. /// diff --git a/include/mscclpp/sm_channel_device.hpp b/include/mscclpp/sm_channel_device.hpp new file mode 100644 index 000000000..353183835 --- /dev/null +++ b/include/mscclpp/sm_channel_device.hpp @@ -0,0 +1,336 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#ifndef MSCCLPP_SM_CHANNEL_DEVICE_HPP_ +#define MSCCLPP_SM_CHANNEL_DEVICE_HPP_ + +#include "packet.hpp" +#include "poll.hpp" +#include "semaphore_device.hpp" + +namespace mscclpp { + +#ifdef __CUDACC__ + +namespace Element { + +/// Load an element from DRAM. +/// +/// This is a wrapper of ld.volatile.global.* PTX instruction. Address alignment is not this function's +/// responsibility. +/// +/// @param v The value to be loaded. +/// @param p The address of the value to be loaded. +/// +template <typename T> +__forceinline__ __device__ void load(T& v, const T* p) { + // We should only use the specialized functions. + __assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__); +} + +/// Write an element on DRAM. +/// +/// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's +/// responsibility. +/// +/// @param p The address of the value to be written. +/// @param v The value to be written. +/// +template <typename T> +__forceinline__ __device__ void store(T* p, const T& v) { + // We should only use the specialized functions. + __assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__); +} + +/// Copy aligned elements from the source memory to the destination memory. +/// +/// This function is intended to be collectively called by multiple threads. Each thread copies a part of +/// elements. +/// +/// @param dst The destination address. +/// @param src The source address. +/// @param numElems The number of elements to be copied. +/// @param threadId The index of the current thread among all threads running this function. This is different +/// from the `threadIdx` in CUDA. +/// @param numThreads The total number of threads that run this function. +/// +template <typename T> +__forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId, uint32_t numThreads) { + T reg; + for (size_t i = threadId; i < numElems; i += numThreads) { + // Load to register first.
+ load(reg, src + i); + store(dst + i, reg); + } +} + +template <> +__forceinline__ __device__ void load(long long& v, const long long* p) { + asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory"); +} + +template <> +__forceinline__ __device__ void store(long long* p, const long long& v) { + asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory"); +} + +template <> +__forceinline__ __device__ void load(int& v, const int* p) { + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory"); +} + +template <> +__forceinline__ __device__ void store(int* p, const int& v) { + asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory"); +} + +template <> +__forceinline__ __device__ void load(longlong2& v, const longlong2* p) { + asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); +} + +template <> +__forceinline__ __device__ void store(longlong2* p, const longlong2& v) { + asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory"); +} + +template <> +__forceinline__ __device__ void load(int4& v, const int4* p) { + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z) + : "l"(p) + : "memory"); +} + +template <> +__forceinline__ __device__ void store(int4* p, const int4& v) { + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" + : + : "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z) + : "memory"); +} + +} // namespace Element + +#endif // __CUDACC__ + +/// Channel for accessing peer memory directly from SM. +struct SmChannelDeviceHandle { + SmDevice2DeviceSemaphoreDeviceHandle semaphore_; + void* src_; + void* dst_; + void* getPacketBuffer_; + +#ifdef __CUDACC__ + /// Load a value from the remote memory. + /// @tparam T The type of the value to be loaded. + /// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T). + /// @return The value loaded. + template <typename T> + __forceinline__ __device__ T read(uint64_t index) { + T v; + Element::load(v, (T*)dst_ + index); + return v; + } + + /// Write a value to the remote memory. + /// @tparam T The type of the value to be written. + /// @param index The index of the value to be written. The offset in bytes is calculated as index * sizeof(T). + /// @param v The value to be written. + template <typename T> + __forceinline__ __device__ void write(uint64_t index, const T& v) { + Element::store((T*)dst_ + index, v); + } + + /// This is a helper for the copy() function. + template <typename T, bool CopyRemainder> + __forceinline__ __device__ void copy_helper(void* dst, void* src, uint64_t bytes, uint32_t threadId, + uint32_t numThreads) { + int* dstInt = reinterpret_cast<int*>(dst); + int* srcInt = reinterpret_cast<int*>(src); + const uintptr_t dstPtr = reinterpret_cast<uintptr_t>(dst); + const uintptr_t srcPtr = reinterpret_cast<uintptr_t>(src); + const uint64_t numInt = bytes / sizeof(int); + T* dstElem = reinterpret_cast<T*>((dstPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T)); + T* srcElem = reinterpret_cast<T*>((srcPtr + sizeof(T) - 1) / sizeof(T) * sizeof(T)); + uint64_t nFirstInt = (reinterpret_cast<uintptr_t>(dstElem) - dstPtr) / sizeof(int); + if (CopyRemainder) { + // Copy the remainder integers at the beginning. + Element::copy(dstInt, srcInt, nFirstInt, threadId, numThreads); + } + // Copy elements.
+ constexpr uint64_t nIntPerElem = sizeof(T) / sizeof(int); + uint64_t nElem = (numInt - nFirstInt) / nIntPerElem; + Element::copy(dstElem, srcElem, nElem, threadId, numThreads); + if (CopyRemainder && nIntPerElem > 1) { + // Copy the remainder integers at the end. + uint64_t nLastInt = (numInt - nFirstInt) % nIntPerElem; + Element::copy(dstInt + nFirstInt + nElem * nIntPerElem, srcInt + nFirstInt + nElem * nIntPerElem, nLastInt, + threadId, numThreads); + } + } + + /// Copy aligned data from the source memory to the destination memory. + /// + /// This function is a wrapper of Element::copy(). Unlike Element::copy(), this function can copy remainder + /// bytes when @p CopyRemainder is true. Still, the copying bytes must be a multiple of 4. + /// + /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or 16. + /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p + /// Alignment. + /// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src. + /// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst. + /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + template <int Alignment = 16, bool CopyRemainder = true> + __forceinline__ __device__ void copy(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) { + if (Alignment == 4) { + copy_helper<int, CopyRemainder>(dst, src, bytes, threadId, numThreads); + } else if (Alignment == 8) { + copy_helper<long long, CopyRemainder>(dst, src, bytes, threadId, numThreads); + } else if (Alignment == 16) { + copy_helper<longlong2, CopyRemainder>(dst, src, bytes, threadId, numThreads); + } else { + static_assert(Alignment == 4 || Alignment == 8 || Alignment == 16, "Unsupported alignment"); + } + } + + /// Copy data from the local memory to the remote memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. + /// + /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. + /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p + /// Alignment. + /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment. + /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment. + /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + template <int Alignment = 16, bool CopyRemainder = true> + __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, + uint32_t numThreads) { + copy<Alignment, CopyRemainder>((char*)dst_ + dstOffset, (char*)src_ + srcOffset, bytes, threadId, numThreads); + } + + /// Copy data from the remote memory to the local memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. + /// + /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. + /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p + /// Alignment. + /// @param dstOffset The offset in bytes of the remote address. Should be a multiple of @p Alignment.
+ /// @param srcOffset The offset in bytes of the local address. Should be a multiple of @p Alignment. + /// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + template <int Alignment = 16, bool CopyRemainder = true> + __forceinline__ __device__ void get(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, + uint32_t numThreads) { + // Note that `dst` and `src` are swapped for `get()`. + copy<Alignment, CopyRemainder>((char*)src_ + srcOffset, (char*)dst_ + dstOffset, bytes, threadId, numThreads); + } + + /// Copy data from the local memory to the remote memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. + /// + /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. + /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p + /// Alignment. + /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment. + /// @param size Bytes of the data to be copied. Should be a multiple of @p Alignment. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + template <int Alignment = 16, bool CopyRemainder = true> + __forceinline__ __device__ void put(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) { + put<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads); + } + + /// Copy data from the remote memory to the local memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. + /// + /// @tparam Alignment The alignment of the source and destination addresses. Should be 4, 8, or a multiple of 16. + /// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p + /// Alignment. + /// @param offset The offset in bytes of the local and remote addresses. Should be a multiple of @p Alignment. + /// @param size Bytes of the data to be copied. Should be a multiple of @p Alignment. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + template <int Alignment = 16, bool CopyRemainder = true> + __forceinline__ __device__ void get(uint64_t offset, uint64_t size, uint32_t threadId, uint32_t numThreads) { + get<Alignment, CopyRemainder>(offset, offset, size, threadId, numThreads); + } + + /// Construct @ref LLPacket from the data in the local memory and write it to the remote memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of packets. + /// + /// @param dstOffset The offset in bytes of the remote address. + /// @param srcOffset The offset in bytes of the local address. + /// @param bytes Bytes of the data to be copied. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function.
+ /// + __forceinline__ __device__ void putPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + mscclpp::putPackets(dst_, dstOffset, src_, srcOffset, bytes, threadId, numThreads, flag); + } + + /// Retrieve data from @ref LLPacket in the local packet buffer and write it to the local memory. + /// + /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. + /// + /// @param dstOffset The offset in bytes of the local memory. + /// @param srcOffset The offset in bytes of the local packet buffer. + /// @param bytes Bytes of the data to be copied. + /// @param threadId The index of the current thread among all threads running this function. This is different from + /// the `threadIdx` in CUDA. + /// @param numThreads The total number of threads that run this function. + /// + __forceinline__ __device__ void getPackets(uint64_t dstOffset, uint64_t srcOffset, uint64_t bytes, uint32_t threadId, + uint32_t numThreads, uint32_t flag) { + mscclpp::getPackets(src_, dstOffset, getPacketBuffer_, srcOffset, bytes, threadId, numThreads, flag); + } + + /// Signal the remote semaphore. + /// + /// This function guarantees that all the memory operations before this function are completed before the remote + /// semaphore is signaled. + /// + __forceinline__ __device__ void signal() { semaphore_.signal(); } + + /// Signal the remote semaphore for copied packets. + /// + /// Unlike @ref signal(), this function provides no guarantee on the completion of memory operations. This is + /// intended to be used with @ref putPackets() and @ref getPackets() that use flags inside packets to indicate the + /// completion of copies. + /// + __forceinline__ __device__ void signalPacket() { semaphore_.signalPacket(); } + + /// Increase the counter of the local semaphore. + __forceinline__ __device__ void semaphoreIncrement() { semaphore_.semaphoreIncrement(); } + + /// Read the counter of the local semaphore. + __forceinline__ __device__ uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); } + + /// Wait for the remote semaphore to send a signal. + __forceinline__ __device__ void wait() { semaphore_.wait(); } +#endif // __CUDACC__ +}; + +} // namespace mscclpp + +#endif // MSCCLPP_SM_CHANNEL_DEVICE_HPP_ diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index 081df4e58..7faec55f9 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -17,6 +17,7 @@ struct Timer { ~Timer(); + /// Returns the elapsed time in milliseconds. int64_t elapsed() const; void set(int timeout); diff --git a/python/config_py.cpp b/python/config_py.cpp deleted file mode 100644 index 1a3606226..000000000 --- a/python/config_py.cpp +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license.
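(Editor's aside: a minimal sketch of how the new `SmChannelDeviceHandle` above is typically driven from a kernel. The `gChannel` symbol, the kernel name, and the launch-time copy of the handle are illustrative assumptions, not part of this diff; the template defaults on `put()` are the ones used in the header.)

```cpp
#include <mscclpp/sm_channel_device.hpp>

// Assumed to be filled before launch, e.g. via cudaMemcpyToSymbol() with a
// host-side SmChannel::deviceHandle(); not part of this change.
__device__ mscclpp::SmChannelDeviceHandle gChannel;

__global__ void putAndSignal(uint64_t offsetBytes, uint64_t sizeBytes) {
  // All threads of the block cooperate on the copy; threadIdx.x / blockDim.x
  // play the roles of the threadId / numThreads arguments of the handle API.
  gChannel.put(offsetBytes, sizeBytes, threadIdx.x, blockDim.x);
  __syncthreads();
  if (threadIdx.x == 0) {
    gChannel.signal();  // order the writes above before signaling the peer
    gChannel.wait();    // block until the peer signals back
  }
}
```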
- -#include - -#include - -namespace nb = nanobind; -using namespace mscclpp; - -void register_config(nb::module_& m) { - nb::class_(m, "Config") - .def_static("get_instance", &Config::getInstance, nb::rv_policy::reference) - .def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig) - .def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig); -} diff --git a/python/core_py.cpp b/python/core_py.cpp index fb5052887..a65a443a6 100644 --- a/python/core_py.cpp +++ b/python/core_py.cpp @@ -17,8 +17,8 @@ extern void register_proxy_channel(nb::module_& m); extern void register_sm_channel(nb::module_& m); extern void register_fifo(nb::module_& m); extern void register_semaphore(nb::module_& m); -extern void register_config(nb::module_& m); extern void register_utils(nb::module_& m); +extern void register_numa(nb::module_& m); template void def_nonblocking_future(nb::handle& m, const std::string& typestr) { @@ -62,9 +62,10 @@ void register_core(nb::module_& m) { nb::arg("nRanks")) .def("create_unique_id", &TcpBootstrap::createUniqueId) .def("get_unique_id", &TcpBootstrap::getUniqueId) - .def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId")) - .def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize, - nb::arg("ifIpPortTrio")); + .def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"), + nb::arg("timeoutSec") = 30) + .def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize, + nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30); nb::enum_(m, "Transport") .value("Unknown", Transport::Unknown) @@ -118,7 +119,7 @@ void register_core(nb::module_& m) { self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue); }, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue")) - .def("flush", &Connection::flush) + .def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7) .def("remote_rank", &Connection::remoteRank) .def("tag", &Connection::tag) .def("transport", &Connection::transport) @@ -139,7 +140,8 @@ void register_core(nb::module_& m) { nb::arg("tag")) .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag")) .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"), - nb::arg("transport")) + nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1, + nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64) .def("setup", &Communicator::setup); } @@ -149,7 +151,7 @@ NB_MODULE(_mscclpp, m) { register_sm_channel(m); register_fifo(m); register_semaphore(m); - register_config(m); register_utils(m); register_core(m); + register_numa(m); } diff --git a/python/examples/bootstrap.py b/python/examples/bootstrap.py index b383222f7..ca0a521cf 100644 --- a/python/examples/bootstrap.py +++ b/python/examples/bootstrap.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-import mscclpp import argparse -import multiprocessing as mp import logging -import torch +import multiprocessing as mp import sys +import mscclpp +import torch + IB_TRANSPORTS = [ mscclpp.Transport.IB0, mscclpp.Transport.IB1, @@ -19,15 +20,19 @@ mscclpp.Transport.IB7, ] +# Use to hold the sm channels so they don't get garbage collected +sm_channels = [] + def setup_connections(comm, rank, world_size, element_size, proxy_service): simple_proxy_channels = [] + sm_semaphores = [] connections = [] remote_memories = [] memory = torch.zeros(element_size, dtype=torch.int32) memory = memory.to("cuda") - transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc + transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc ptr = memory.data_ptr() size = memory.numel() * memory.element_size() reg_mem = comm.register_memory(ptr, size, transport_flag) @@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service): remote_memories.append(remote_mem) comm.setup() + # Create simple proxy channels for i, conn in enumerate(connections): proxy_channel = mscclpp.SimpleProxyChannel( - proxy_service.device_channel(proxy_service.add_semaphore(conn)), + proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)), proxy_service.add_memory(remote_memories[i].get()), proxy_service.add_memory(reg_mem), ) simple_proxy_channels.append(mscclpp.device_handle(proxy_channel)) comm.setup() - return simple_proxy_channels + + # Create sm channels + for i, conn in enumerate(connections): + sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn) + sm_semaphores.append(sm_chan) + comm.setup() + + for i, conn in enumerate(sm_semaphores): + sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr) + sm_channels.append(sm_chan) + return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels] def run(rank, args): @@ -60,7 +76,7 @@ def run(rank, args): boot = mscclpp.TcpBootstrap.create(rank, world_size) boot.initialize(args.if_ip_port_trio) comm = mscclpp.Communicator(boot) - proxy_service = mscclpp.ProxyService(comm) + proxy_service = mscclpp.ProxyService() logging.info("Rank: %d, setting up connections", rank) setup_connections(comm, rank, world_size, args.num_elements, proxy_service) diff --git a/python/examples/config.py b/python/examples/config.py deleted file mode 100644 index c489aa023..000000000 --- a/python/examples/config.py +++ /dev/null @@ -1,12 +0,0 @@ -import mscclpp - - -def main(): - config = mscclpp.Config.get_instance() - config.set_bootstrap_connection_timeout_config(15) - timeout = config.get_bootstrap_connection_timeout_config() - assert timeout == 15 - - -if __name__ == "__main__": - main() diff --git a/python/examples/send_recv.py b/python/examples/send_recv.py index bf7daf75d..d19a7be2d 100644 --- a/python/examples/send_recv.py +++ b/python/examples/send_recv.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import mscclpp import argparse import time +import mscclpp + def main(args): if args.root: diff --git a/python/examples/utils.py b/python/examples/utils.py index ccb0410a2..7f2b4c989 100644 --- a/python/examples/utils.py +++ b/python/examples/utils.py @@ -1,7 +1,10 @@ -import mscclpp +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
import time +import mscclpp + def main(): timer = mscclpp.Timer() diff --git a/python/fifo_py.cpp b/python/fifo_py.cpp index 531dc011e..eb23118bf 100644 --- a/python/fifo_py.cpp +++ b/python/fifo_py.cpp @@ -11,15 +11,20 @@ using namespace mscclpp; void register_fifo(nb::module_& m) { nb::class_(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd); - nb::class_(m, "DeviceProxyFifo") - .def_rw("triggers", &DeviceProxyFifo::triggers) - .def_rw("tail_replica", &DeviceProxyFifo::tailReplica) - .def_rw("head", &DeviceProxyFifo::head); + nb::class_(m, "FifoDeviceHandle") + .def_rw("triggers", &FifoDeviceHandle::triggers) + .def_rw("tail_replica", &FifoDeviceHandle::tailReplica) + .def_rw("head", &FifoDeviceHandle::head) + .def_rw("size", &FifoDeviceHandle::size) + .def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); - nb::class_(m, "HostProxyFifo") - .def(nb::init<>()) - .def("poll", &HostProxyFifo::poll, nb::arg("trigger")) - .def("pop", &HostProxyFifo::pop) - .def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false) - .def("device_fifo", &HostProxyFifo::deviceFifo); + nb::class_(m, "Fifo") + .def(nb::init(), nb::arg("size") = 128) + .def("poll", &Fifo::poll) + .def("pop", &Fifo::pop) + .def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false) + .def("size", &Fifo::size) + .def("device_handle", &Fifo::deviceHandle); } diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 23be4eb6a..89e889a22 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -1,10 +1,31 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from ._mscclpp import * import os as _os +from ._mscclpp import ( + Communicator, + Connection, + Fifo, + Host2DeviceSemaphore, + Host2HostSemaphore, + numa, + ProxyService, + RegisteredMemory, + SimpleProxyChannel, + SmChannel, + SmDevice2DeviceSemaphore, + TcpBootstrap, + Transport, + TransportFlags, +) + def get_include(): """Return the directory that contains the MSCCL++ headers.""" return _os.path.join(_os.path.dirname(__file__), "include") + + +def get_lib(): + """Return the directory that contains the MSCCL++ headers.""" + return _os.path.join(_os.path.dirname(__file__), "lib") diff --git a/python/numa_py.cpp b/python/numa_py.cpp new file mode 100644 index 000000000..2489a4793 --- /dev/null +++ b/python/numa_py.cpp @@ -0,0 +1,13 @@ +#include +namespace nb = nanobind; + +namespace mscclpp { +int getDeviceNumaNode(int cudaDev); +void numaBind(int node); +}; // namespace mscclpp + +void register_numa(nb::module_ &m) { + nb::module_ sub_m = m.def_submodule("numa", "numa functions"); + sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode); + sub_m.def("numa_bind", &mscclpp::numaBind); +} diff --git a/python/proxy_channel_py.cpp b/python/proxy_channel_py.cpp index 572811640..a483f99d2 100644 --- a/python/proxy_channel_py.cpp +++ b/python/proxy_channel_py.cpp @@ -16,22 +16,40 @@ void register_proxy_channel(nb::module_& m) { .def("stop_proxy", &BaseProxyService::stopProxy); nb::class_(m, "ProxyService") - .def(nb::init(), nb::arg("comm")) + .def(nb::init<>()) .def("start_proxy", &ProxyService::startProxy) .def("stop_proxy", &ProxyService::stopProxy) - .def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection")) + .def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection")) + .def("add_semaphore", &ProxyService::addSemaphore, 
nb::arg("semaphore")) .def("add_memory", &ProxyService::addMemory, nb::arg("memory")) .def("semaphore", &ProxyService::semaphore, nb::arg("id")) - .def("device_channel", &ProxyService::deviceChannel, nb::arg("id")); + .def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id")); nb::class_(m, "ProxyChannel") - .def(nb::init(), nb::arg("semaphoreId"), - nb::arg("semaphore"), nb::arg("fifo")); + .def(nb::init(), nb::arg("semaphoreId"), + nb::arg("semaphore"), nb::arg("fifo")) + .def("device_handle", &ProxyChannel::deviceHandle); + + nb::class_(m, "ProxyChannelDeviceHandle") + .def(nb::init<>()) + .def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_) + .def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_) + .def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_) + .def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); nb::class_(m, "SimpleProxyChannel") .def(nb::init(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src")) - .def(nb::init(), nb::arg("proxyChan")); - - m.def("device_handle", &deviceHandle, nb::arg("proxyChannel")); - m.def("device_handle", &deviceHandle, nb::arg("simpleProxyChannel")); + .def(nb::init(), nb::arg("proxyChan")) + .def("device_handle", &SimpleProxyChannel::deviceHandle); + + nb::class_(m, "SimpleProxyChannelDeviceHandle") + .def(nb::init<>()) + .def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_) + .def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_) + .def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_) + .def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); }; diff --git a/python/semaphore_py.cpp b/python/semaphore_py.cpp index 379c4e4cd..015f28dee 100644 --- a/python/semaphore_py.cpp +++ b/python/semaphore_py.cpp @@ -20,7 +20,10 @@ void register_semaphore(nb::module_& m) { nb::class_(host2DeviceSemaphore, "DeviceHandle") .def(nb::init<>()) .def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId) - .def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId); + .def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId) + .def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); nb::class_(m, "Host2HostSemaphore") .def(nb::init>(), nb::arg("communicator"), nb::arg("connection")) @@ -38,5 +41,8 @@ void register_semaphore(nb::module_& m) { .def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId) .def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId) .def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId) - .def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId); + .def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId) + .def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); } diff --git a/python/sm_channel_py.cpp b/python/sm_channel_py.cpp index d02ac30e1..04a51eb8b 100644 --- a/python/sm_channel_py.cpp +++ b/python/sm_channel_py.cpp @@ -13,11 +13,23 @@ using namespace mscclpp; void 
register_sm_channel(nb::module_& m) { nb::class_ smChannel(m, "SmChannel"); smChannel - .def(nb::init, RegisteredMemory, void*, void*>(), nb::arg("semaphore"), - nb::arg("dst"), nb::arg("src"), nb::arg("getPacketBuffer")) + .def("__init__", + [](SmChannel* smChannel, std::shared_ptr semaphore, RegisteredMemory dst, + uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); }) + .def("__init__", + [](SmChannel* smChannel, std::shared_ptr semaphore, RegisteredMemory dst, + uintptr_t src, uintptr_t get_packet_buffer) { + new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer); + }) .def("device_handle", &SmChannel::deviceHandle); - nb::class_(smChannel, "DeviceHandle"); - - m.def("device_handle", &deviceHandle, nb::arg("smChannel")); + nb::class_(m, "SmChannelDeviceHandle") + .def(nb::init<>()) + .def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_) + .def_rw("src_", &SmChannel::DeviceHandle::src_) + .def_rw("dst_", &SmChannel::DeviceHandle::dst_) + .def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_) + .def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes { + return nb::bytes(reinterpret_cast(&self), sizeof(self)); + }); }; diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 3cea2120a..649a1f62e 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -59,9 +58,9 @@ class TcpBootstrap::Impl { public: Impl(int rank, int nRanks); ~Impl(); - void initialize(const UniqueId& uniqueId); - void initialize(const std::string& ifIpPortTrio); - void establishConnections(); + void initialize(const UniqueId& uniqueId, int64_t timeoutSec); + void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec); + void establishConnections(int64_t timeoutSec); UniqueId createUniqueId(); UniqueId getUniqueId() const; int getRank(); @@ -133,15 +132,15 @@ int TcpBootstrap::Impl::getRank() { return rank_; } int TcpBootstrap::Impl::getNranks() { return nRanks_; } -void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId) { +void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) { netInit("", ""); std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); - establishConnections(); + establishConnections(timeoutSec); } -void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) { +void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) { // first check if it is a trio int nColons = 0; for (auto c : ifIpPortTrio) { @@ -167,7 +166,7 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio) { bootstrapCreateRoot(); } - establishConnections(); + establishConnections(timeoutSec); } TcpBootstrap::Impl::~Impl() { @@ -308,8 +307,8 @@ void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) } \ } while (0); -void TcpBootstrap::Impl::establishConnections() { - const int64_t connectionTimeoutUs = (int64_t)Config::getInstance()->getBootstrapConnectionTimeoutConfig() * 1000000; +void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { + const int64_t connectionTimeoutUs = timeoutSec * 1000000; Timer timer; SocketAddress nextAddr; ExtInfo info; @@ -317,6 +316,10 @@ void TcpBootstrap::Impl::establishConnections() { TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_); auto getLeftTime = [&]() { + if (connectionTimeoutUs < 0) { + // no timeout: always return a large number + return int64_t(1e9); + } int64_t timeout = 
connectionTimeoutUs - timer.elapsed(); if (timeout <= 0) throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); return timeout; @@ -489,9 +492,13 @@ MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag) MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); } -MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId) { pimpl_->initialize(uniqueId); } +MSCCLPP_API_CPP void TcpBootstrap::initialize(UniqueId uniqueId, int64_t timeoutSec) { + pimpl_->initialize(uniqueId, timeoutSec); +} -MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair) { pimpl_->initialize(ipPortPair); } +MSCCLPP_API_CPP void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) { + pimpl_->initialize(ipPortPair, timeoutSec); +} MSCCLPP_API_CPP void TcpBootstrap::barrier() { pimpl_->barrier(); } diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index c7db6030d..7a0ba4f1b 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -13,7 +13,6 @@ #include #include -#include #include #include #include diff --git a/src/communicator.cc b/src/communicator.cc index 0480f0231..cc0323556 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -94,7 +94,11 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport, + int ibMaxCqSize /*=1024*/, + int ibMaxCqPollNum /*=1*/, + int ibMaxSendWr /*=8192*/, + int ibMaxWrPerSend /*=64*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node @@ -111,7 +115,8 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank, pimpl->rankToHash_[remoteRank]); } else if (AllIBTransports.has(transport)) { - auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); + auto ibConn = std::make_shared(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, + ibMaxWrPerSend, *pimpl); conn = ibConn; INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], diff --git a/src/config.cc b/src/config.cc deleted file mode 100644 index 220700467..000000000 --- a/src/config.cc +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
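(Editor's aside: with the `Config` singleton removed, the bootstrap timeout now travels as a plain argument, per the core.hpp and bootstrap.cc hunks above. A hedged host-side sketch; the interface string, ranks, and timeout values are placeholders.)

```cpp
#include <memory>

#include <mscclpp/core.hpp>

void setupWithTimeouts(int rank, int worldSize, int remoteRank) {
  auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
  // The connection timeout (in seconds) is now passed per call instead of being
  // read from the removed Config singleton.
  bootstrap->initialize("eth0:10.0.0.1:50000", /*timeoutSec=*/15);
  mscclpp::Communicator comm(bootstrap);
  auto conn = comm.connectOnSetup(remoteRank, /*tag=*/0, mscclpp::Transport::CudaIpc);
  comm.setup();
  conn->flush(/*timeoutUsec=*/static_cast<int64_t>(3e7));  // flush() likewise takes an optional timeout now
}
```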
- -#include - -namespace mscclpp { -Config Config::instance_; - -Config* Config::getInstance() { return &instance_; } - -int Config::getBootstrapConnectionTimeoutConfig() { return bootstrapConnectionTimeout; } - -void Config::setBootstrapConnectionTimeoutConfig(int timeout) { bootstrapConnectionTimeout = timeout; } -} // namespace mscclpp diff --git a/src/connection.cc b/src/connection.cc index 931ae7e5e..112e11783 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -70,21 +70,26 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } -void CudaIpcConnection::flush() { +void CudaIpcConnection::flush(int64_t timeoutUsec) { + if (timeoutUsec >= 0) { + INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored"); + } AvoidCudaGraphCaptureGuard guard; MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_)); // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); + INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection to remote rank %d", remoteRank()); } // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) +IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, + int maxWrPerSend, Communicator::Impl& commImpl) : ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0), dummyAtomicSource_(std::make_unique(0)) { - qp = commImpl.getIbContext(transport)->createQp(); + qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend); dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared( dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl)); validateTransport(dummyAtomicSourceMem_, transport); @@ -144,7 +149,7 @@ void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint6 oldValue, newValue); } -void IBConnection::flush() { +void IBConnection::flush(int64_t timeoutUsec) { Timer timer; while (numSignaledSends) { int wcNum = qp->pollCq(); @@ -153,8 +158,8 @@ void IBConnection::flush() { } auto elapsed = timer.elapsed(); - if (elapsed > MSCCLPP_POLLING_WAIT) { - throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " + + if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) { + throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " + std::to_string(numSignaledSends) + " signals", ErrorCode::InternalError); } @@ -168,6 +173,7 @@ void IBConnection::flush() { } } } + INFO(MSCCLPP_NET, "IBConnection flushing connection to remote rank %d", remoteRank()); // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } diff --git a/src/fifo.cc b/src/fifo.cc index 0d0516b4d..5e21c9afa 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -1,20 +1,18 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include - #include #include -#include #include "api.h" namespace mscclpp { -struct HostProxyFifo::Impl { +struct Fifo::Impl { UniqueCudaHostPtr triggers; UniqueCudaPtr head; UniqueCudaPtr tailReplica; + const int size; // allocated on the host. Only accessed by the host. 
This is a copy of the // value pointed to by fifoTailDev and the invariant is that @@ -28,28 +26,33 @@ struct HostProxyFifo::Impl { // for transferring fifo tail CudaStreamWithFlags stream; - Impl() - : triggers(makeUniqueCudaHost(MSCCLPP_PROXY_FIFO_SIZE)), + Impl(int size) + : triggers(makeUniqueCudaHost(size)), head(allocUniqueCuda()), tailReplica(allocUniqueCuda()), + size(size), hostTail(0), stream(cudaStreamNonBlocking) {} }; -MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() : pimpl(std::make_unique()) {} -MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() = default; +MSCCLPP_API_CPP Fifo::Fifo(int size) : pimpl(std::make_unique(size)) {} +MSCCLPP_API_CPP Fifo::~Fifo() = default; -MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) { - __m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]); - _mm_store_si128((__m128i*)trigger, xmm0); +MSCCLPP_API_CPP ProxyTrigger Fifo::poll() { + ProxyTrigger trigger; + volatile ProxyTrigger* ptr = + reinterpret_cast(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]); + trigger.fst = ptr->fst; + trigger.snd = ptr->snd; + return trigger; } -MSCCLPP_API_CPP void HostProxyFifo::pop() { - *(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0; +MSCCLPP_API_CPP void Fifo::pop() { + *(volatile uint64_t*)(&pimpl->triggers.get()[pimpl->hostTail % pimpl->size]) = 0; (pimpl->hostTail)++; } -MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) { +MSCCLPP_API_CPP void Fifo::flushTail(bool sync) { // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can // make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request. MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t), @@ -59,12 +62,15 @@ MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) { } } -MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() { - DeviceProxyFifo deviceFifo; - deviceFifo.triggers = pimpl->triggers.get(); - deviceFifo.head = pimpl->head.get(); - deviceFifo.tailReplica = pimpl->tailReplica.get(); - return deviceFifo; +MSCCLPP_API_CPP int Fifo::size() const { return pimpl->size; } + +MSCCLPP_API_CPP FifoDeviceHandle Fifo::deviceHandle() { + FifoDeviceHandle deviceHandle; + deviceHandle.triggers = pimpl->triggers.get(); + deviceHandle.head = pimpl->head.get(); + deviceHandle.tailReplica = pimpl->tailReplica.get(); + deviceHandle.size = pimpl->size; + return deviceHandle; } } // namespace mscclpp diff --git a/src/ib.cc b/src/ib.cc index 34df977fa..7a93a650e 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,8 +16,6 @@ #include "api.h" #include "debug.h" -#define MAXCONNECTIONS 64 - namespace mscclpp { IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { @@ -54,8 +52,10 @@ const void* IbMr::getBuff() const { return this->buff; } uint32_t IbMr::getLkey() const { return this->mr->lkey; } -IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) { - this->cq = ibv_create_cq(ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); +IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, + int maxWrPerSend) + : maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) { + this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0); if (this->cq == nullptr) { std::stringstream err; err << "ibv_create_cq failed (errno " << errno << ")"; @@ -68,8 +68,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int 
port) { qpInitAttr.send_cq = this->cq; qpInitAttr.recv_cq = this->cq; qpInitAttr.qp_type = IBV_QPT_RC; - qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_send_wr = maxSendWr; + qpInitAttr.cap.max_recv_wr = maxRecvWr; qpInitAttr.cap.max_send_sge = 1; qpInitAttr.cap.max_recv_sge = 1; qpInitAttr.cap.max_inline_data = 0; @@ -118,9 +118,9 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port) { } this->qp = _qp; this->wrn = 0; - this->wrs = std::make_unique(MSCCLPP_IB_MAX_SENDS); - this->sges = std::make_unique(MSCCLPP_IB_MAX_SENDS); - this->wcs = std::make_unique(MSCCLPP_IB_CQ_POLL_NUM); + this->wrs = std::make_unique(maxWrPerSend); + this->sges = std::make_unique(maxWrPerSend); + this->wcs = std::make_unique(maxCqPollNum); } IbQp::~IbQp() { @@ -182,9 +182,9 @@ void IbQp::rts() { } IbQp::WrInfo IbQp::getNewWrInfo() { - if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { + if (this->wrn >= this->maxWrPerSend) { std::stringstream err; - err << "too many outstanding work requests. limit is " << MSCCLPP_IB_MAX_SENDS; + err << "too many outstanding work requests. limit is " << this->maxWrPerSend; throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); } int wrn = this->wrn; @@ -269,7 +269,7 @@ void IbQp::postRecv(uint64_t wrId) { } } -int IbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs.get()); } +int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); } IbQpInfo& IbQp::getInfo() { return this->info; } @@ -335,7 +335,8 @@ int IbCtx::getAnyActivePort() const { return -1; } -IbQp* IbCtx::createQp(int port /*=-1*/) { +IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int port /*=-1*/) { if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { @@ -344,7 +345,7 @@ IbQp* IbCtx::createQp(int port /*=-1*/) { } else if (!this->isPortUsable(port)) { throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError); } - qps.emplace_back(new IbQp(this->ctx, this->pd, port)); + qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend)); return qps.back().get(); } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index f41161682..0475691c9 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -4,9 +4,6 @@ #ifndef MSCCLPP_CONNECTION_HPP_ #define MSCCLPP_CONNECTION_HPP_ -// TODO(saemal): make this configurable -#define MSCCLPP_POLLING_WAIT 3e7 // in microseconds - #include #include @@ -46,7 +43,7 @@ class CudaIpcConnection : public ConnectionBase { uint64_t size) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; - void flush() override; + void flush(int64_t timeoutUsec) override; }; class IBConnection : public ConnectionBase { @@ -59,7 +56,8 @@ class IBConnection : public ConnectionBase { mscclpp::TransportInfo dstTransportInfo_; public: - IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); + IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, + int maxWrPerSend, Communicator::Impl& commImpl); Transport transport() override; @@ -69,7 +67,7 @@ class IBConnection : public ConnectionBase { uint64_t size) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) 
override; - void flush() override; + void flush(int64_t timeoutUsec) override; void beginSetup(std::shared_ptr bootstrap) override; diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 14da20474..1bec30b85 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -8,11 +8,6 @@ #include #include -#define MSCCLPP_IB_CQ_SIZE 1024 -#define MSCCLPP_IB_CQ_POLL_NUM 1 -#define MSCCLPP_IB_MAX_SENDS 64 -#define MSCCLPP_IB_MAX_DEVS 8 - // Forward declarations of IB structures struct ibv_context; struct ibv_pd; @@ -84,7 +79,8 @@ class IbQp { ibv_sge* sge; }; - IbQp(ibv_context* ctx, ibv_pd* pd, int port); + IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, + int maxWrPerSend); WrInfo getNewWrInfo(); IbQpInfo info; @@ -96,6 +92,9 @@ class IbQp { std::unique_ptr sges; int wrn; + const int maxCqPollNum; + const int maxWrPerSend; + friend class IbCtx; }; @@ -104,7 +103,7 @@ class IbCtx { IbCtx(const std::string& devName); ~IbCtx(); - IbQp* createQp(int port = -1); + IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1); const IbMr* registerMr(void* buff, std::size_t size); const std::string& getDevName() const; @@ -122,4 +121,4 @@ class IbCtx { } // namespace mscclpp -#endif // MSCCLPP_IB_HPP_ \ No newline at end of file +#endif // MSCCLPP_IB_HPP_ diff --git a/src/numa.cc b/src/numa.cc index c66256e38..a1d4129d1 100644 --- a/src/numa.cc +++ b/src/numa.cc @@ -6,6 +6,8 @@ #include #include +#include "api.h" + // Convert a logical cudaDev index to the NVML device minor number static const std::string getBusId(int cudaDev) { // On most systems, the PCI bus ID comes back as in the 0000:00:00.0 @@ -22,7 +24,7 @@ static const std::string getBusId(int cudaDev) { namespace mscclpp { -int getDeviceNumaNode(int cudaDev) { +MSCCLPP_API_CPP int getDeviceNumaNode(int cudaDev) { std::string busId = getBusId(cudaDev); std::string file_str = "/sys/bus/pci/devices/" + busId + "/numa_node"; std::ifstream file(file_str); @@ -37,7 +39,7 @@ int getDeviceNumaNode(int cudaDev) { return numaNode; } -void numaBind(int node) { +MSCCLPP_API_CPP void numaBind(int node) { int totalNumNumaNodes = numa_num_configured_nodes(); if (node < 0 || node >= totalNumNumaNodes) { throw Error( diff --git a/src/proxy.cc b/src/proxy.cc index e2aa85694..3fe3b1645 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -15,14 +15,13 @@ namespace mscclpp { const int ProxyStopCheckPeriod = 1000; // Unless explicitly requested, a flush of the tail to device memory is triggered for every ProxyFlushPeriod. -// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. +// As long as the FIFO size is large enough, having a stale tail is not a problem. 
const int ProxyFlushPeriod = 4; -static_assert(MSCCLPP_PROXY_FIFO_SIZE >= ProxyFlushPeriod, "MSCCLPP_PROXY_FIFO_SIZE is too small"); struct Proxy::Impl { ProxyHandler handler; std::function threadInit; - HostProxyFifo fifo; + Fifo fifo; std::thread service; std::atomic_bool running; @@ -53,10 +52,12 @@ MSCCLPP_API_CPP void Proxy::start() { pimpl->threadInit(); ProxyHandler handler = this->pimpl->handler; - HostProxyFifo& fifo = this->pimpl->fifo; + Fifo& fifo = this->pimpl->fifo; std::atomic_bool& running = this->pimpl->running; ProxyTrigger trigger; + int flushPeriod = std::min(fifo.size(), ProxyFlushPeriod); + int runCnt = ProxyStopCheckPeriod; uint64_t flushCnt = 0; for (;;) { @@ -67,19 +68,19 @@ MSCCLPP_API_CPP void Proxy::start() { } } // Poll to see if we are ready to send anything - fifo.poll(&trigger); - if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers - continue; // there is one in progress + trigger = fifo.poll(); + if (trigger.fst == 0 || trigger.snd == 0) { // TODO: this check is a potential pitfall for custom triggers + continue; // there is one in progress } + trigger.snd ^= ((uint64_t)1 << (uint64_t)63); // this is where the last bit of snd is reverted. ProxyHandlerResult result = handler(trigger); // Send completion: reset only the high 64 bits fifo.pop(); - // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure - // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush - // request. - if ((++flushCnt % ProxyFlushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) { + // Flush the tail to device memory. This is either triggered every flushPeriod to make sure that the fifo can make + // progress even if there is no request mscclppSync. However, mscclppSync type is for flush request. + if ((++flushCnt % flushPeriod) == 0 || result == ProxyHandlerResult::FlushFifoTailAndContinue) { // TODO: relocate this check: || (trigger.fields.type & mscclppSync) fifo.flushTail(); } @@ -107,6 +108,6 @@ MSCCLPP_API_CPP void Proxy::stop() { } } -MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { return pimpl->fifo; } +MSCCLPP_API_CPP Fifo& Proxy::fifo() { return pimpl->fifo; } } // namespace mscclpp diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index 90615a119..c6a4e243d 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -1,31 +1,36 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
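(Editor's aside: a rough sketch of the reworked host-side `Fifo` API used by the proxy loop above; the header path, the handler step, and the stop flag are assumptions for illustration.)

```cpp
#include <mscclpp/fifo.hpp>

void pollLoop(mscclpp::Fifo& fifo, volatile bool& stop) {
  while (!stop) {
    mscclpp::ProxyTrigger trigger = fifo.poll();          // poll() now returns the trigger by value
    if (trigger.fst == 0 || trigger.snd == 0) continue;   // nothing posted yet
    // ... hand `trigger` to a handler here ...
    fifo.pop();        // free the FIFO slot
    fifo.flushTail();  // push the host tail back to the device-side replica
  }
}
```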
+#include #include #include "api.h" #include "debug.h" -#include "numa.hpp" namespace mscclpp { MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, - DeviceProxyFifo fifo) + FifoDeviceHandle fifo) : semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {} MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src) : proxyChan_(proxyChan), dst_(dst), src_(src) {} -MSCCLPP_API_CPP ProxyService::ProxyService(Communicator& communicator) - : communicator_(communicator), - proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { +MSCCLPP_API_CPP ProxyService::ProxyService() + : proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { int cudaDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); deviceNumaNode = getDeviceNumaNode(cudaDevice); } -MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr connection) { - semaphores_.push_back(std::make_shared(communicator_, connection)); +MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, + std::shared_ptr connection) { + semaphores_.push_back(std::make_shared(communicator, connection)); + return semaphores_.size() - 1; +} + +MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr semaphore) { + semaphores_.push_back(semaphore); return semaphores_.size() - 1; } @@ -38,8 +43,8 @@ MSCCLPP_API_CPP std::shared_ptr ProxyService::semaphore(Se return semaphores_[id]; } -MSCCLPP_API_CPP ProxyChannel ProxyService::deviceChannel(SemaphoreId id) { - return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceFifo()); +MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) { + return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceHandle()); } MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_.start(); } @@ -78,14 +83,12 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { return result; } -template <> -DeviceHandle deviceHandle(ProxyChannel&& proxyChannel) { - return proxyChannel; +MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const { + return ProxyChannel::DeviceHandle{.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_, .fifo_ = fifo_}; } -template <> -DeviceHandle deviceHandle(SimpleProxyChannel&& simpleProxyChannel) { - return simpleProxyChannel; +MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const { + return SimpleProxyChannel::DeviceHandle{.proxyChan_ = proxyChan_.deviceHandle(), .dst_ = dst_, .src_ = src_}; } } // namespace mscclpp diff --git a/test/allgather_test_cpp.cu b/test/allgather_test_cpp.cu index a45114ab4..9bf81b2b1 100644 --- a/test/allgather_test_cpp.cu +++ b/test/allgather_test_cpp.cu @@ -207,8 +207,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice)); } -void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, - mscclpp::ProxyService& channelService, int* data_d, size_t dataSize) { +void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::ProxyService& proxyService, + int* data_d, size_t dataSize) { int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); @@ -226,7 +226,7 @@ void 
setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co transport = ibTransport; } // Connect with all other ranks - semaphoreIds.push_back(channelService.addSemaphore(comm.connectOnSetup(r, 0, transport))); + semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, comm.connectOnSetup(r, 0, transport))); auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); localMemories.push_back(memory); comm.sendMemoryOnSetup(memory, r, 0); @@ -238,8 +238,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co std::vector> proxyChannels; for (size_t i = 0; i < semaphoreIds.size(); ++i) { proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel( - channelService.deviceChannel(semaphoreIds[i]), channelService.addMemory(remoteMemories[i].get()), - channelService.addMemory(localMemories[i])))); + proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()), + proxyService.addMemory(localMemories[i])))); } assert(proxyChannels.size() < sizeof(constProxyChans) / sizeof(DeviceHandle)); @@ -396,16 +396,16 @@ int main(int argc, const char* argv[]) { auto bootstrap = std::make_shared(rank, world_size); bootstrap->initialize(ip_port); mscclpp::Communicator comm(bootstrap); - mscclpp::ProxyService channelService(comm); + mscclpp::ProxyService proxyService; if (rank == 0) printf("Initializing data for allgather test\n"); initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d); if (rank == 0) printf("Setting up the connection in MSCCL++\n"); - setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize); + setupMscclppConnections(rank, world_size, comm, proxyService, data_d, dataSize); if (rank == 0) printf("Launching MSCCL++ proxy threads\n"); - channelService.startProxy(); + proxyService.startProxy(); if (rank == 0) printf("Testing the correctness of AllGather implementation\n"); cudaStream_t stream; @@ -480,7 +480,7 @@ int main(int argc, const char* argv[]) { bootstrap->allGather(tmp, sizeof(int)); if (rank == 0) printf("Stopping MSCCL++ proxy threads\n"); - channelService.stopProxy(); + proxyService.stopProxy(); } catch (std::exception& e) { // todo: throw exceptions in the implementation and process them here diff --git a/test/allgather_test_host_offloading.cu b/test/allgather_test_host_offloading.cu index ff407ba0b..d3e725f4c 100644 --- a/test/allgather_test_host_offloading.cu +++ b/test/allgather_test_host_offloading.cu @@ -4,9 +4,9 @@ #include #include #include +#include #include #include -#include #ifdef MSCCLPP_USE_MPI_FOR_TESTS #include "mpi.h" @@ -45,7 +45,7 @@ static double getTime(void) { return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -__global__ void kernel(int r, int nranks, mscclpp::DeviceProxyFifo fifo, +__global__ void kernel(int r, int nranks, mscclpp::FifoDeviceHandle fifo, mscclpp::Host2DeviceSemaphore::DeviceHandle* handles, int handleIndex) { int tid = threadIdx.x; __syncthreads(); @@ -188,7 +188,7 @@ class MyProxyService { void stop() { proxy_.stop(); } - mscclpp::HostProxyFifo& fifo() { return proxy_.fifo(); } + mscclpp::Fifo& fifo() { return proxy_.fifo(); } mscclpp::Host2DeviceSemaphore::DeviceHandle getDeviceHandle1(int r) { return deviceSemaphores1_[r]->deviceHandle(); } @@ -261,7 +261,7 @@ int main(int argc, char* argv[]) { if (rank == 0) printf("Launching MSCCL++ proxy threads\n"); proxyService.start(); - mscclpp::DeviceProxyFifo fifo = proxyService.fifo().deviceFifo(); + 
mscclpp::FifoDeviceHandle fifo = proxyService.fifo().deviceHandle(); if (rank == 0) printf("Testing the correctness of AllGather implementation\n"); cudaStream_t stream; CUCHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); diff --git a/test/mp_unit/bootstrap_tests.cc b/test/mp_unit/bootstrap_tests.cc index 095ddd8a6..f33834e2e 100644 --- a/test/mp_unit/bootstrap_tests.cc +++ b/test/mp_unit/bootstrap_tests.cc @@ -3,8 +3,6 @@ #include -#include - #include "mp_unit_tests.hpp" void BootstrapTest::bootstrapTestAllGather(std::shared_ptr bootstrap) { @@ -88,10 +86,6 @@ TEST_F(BootstrapTest, ExitBeforeConnect) { } TEST_F(BootstrapTest, TimeoutWithId) { - // Set bootstrap timeout to 1 second - mscclpp::Config* cfg = mscclpp::Config::getInstance(); - cfg->setBootstrapConnectionTimeoutConfig(1); - mscclpp::Timer timer; // All ranks initialize a bootstrap with their own id (will hang) @@ -99,7 +93,8 @@ TEST_F(BootstrapTest, TimeoutWithId) { mscclpp::UniqueId id = bootstrap->createUniqueId(); try { - bootstrap->initialize(id); + // Set bootstrap timeout to 1 second + bootstrap->initialize(id, 1); } catch (const mscclpp::Error& e) { ASSERT_EQ(e.getErrorCode(), mscclpp::ErrorCode::Timeout); } diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu index 5a7526654..7ab892b51 100644 --- a/test/mp_unit/ib_tests.cu +++ b/test/mp_unit/ib_tests.cu @@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() { bootstrap->initialize(id); ibCtx = std::make_shared(ibDevName); - qp = ibCtx->createQp(); + qp = ibCtx->createQp(1024, 1, 8192, 0, 64); qpInfo[gEnv->rank] = qp->getInfo(); bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index c3a09c841..393255638 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -124,6 +124,9 @@ class CommunicatorTest : public CommunicatorTestBase { std::vector> remoteMemory; }; +template +using DeviceHandle = mscclpp::DeviceHandle; + class ProxyChannelOneToOneTest : public CommunicatorTestBase { protected: void SetUp() override; @@ -134,7 +137,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase { void testPacketPingPong(bool useIbOnly); void testPacketPingPongPerf(bool useIbOnly); - std::shared_ptr channelService; + std::shared_ptr proxyService; }; class SmChannelOneToOneTest : public CommunicatorTestBase { diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index 71c683b14..5537fe017 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -5,21 +5,18 @@ #include "mp_unit_tests.hpp" -template -using DeviceHandle = mscclpp::DeviceHandle; - void ProxyChannelOneToOneTest::SetUp() { // Use only two ranks setNumRanksToUse(2); CommunicatorTestBase::SetUp(); - channelService = std::make_shared(*communicator.get()); + proxyService = std::make_shared(); } void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); } -void ProxyChannelOneToOneTest::setupMeshConnections( - std::vector>& proxyChannels, bool useIbOnly, void* sendBuff, - size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) { +void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, + bool useIbOnly, void* sendBuff, size_t sendBuffBytes, + void* recvBuff, size_t recvBuffBytes) { const int rank = communicator->bootstrap()->getRank(); const int worldSize = communicator->bootstrap()->getNranks(); const bool isInPlace = (recvBuff == nullptr); @@ -52,12 +49,11 @@ void 
ProxyChannelOneToOneTest::setupMeshConnections( communicator->setup(); - mscclpp::SemaphoreId cid = channelService->addSemaphore(conn); + mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn); communicator->setup(); - proxyChannels.emplace_back(mscclpp::deviceHandle( - mscclpp::SimpleProxyChannel(channelService->deviceChannel(cid), channelService->addMemory(remoteMemory.get()), - channelService->addMemory(sendBufRegMem)))); + proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()), + proxyService->addMemory(sendBufRegMem)); } } @@ -121,15 +117,18 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) { const int nElem = 4 * 1024 * 1024; - std::vector> proxyChannels; + std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocSharedCuda(nElem); setupMeshConnections(proxyChannels, true, buff.get(), nElem * sizeof(int)); + std::vector> proxyChannelHandles; + for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle()); + ASSERT_EQ(proxyChannels.size(), 1); - MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(), + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(), sizeof(DeviceHandle))); - channelService->startProxy(); + proxyService->startProxy(); std::shared_ptr ret = mscclpp::makeSharedCudaHost(0); @@ -153,7 +152,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) { EXPECT_EQ(*ret, 0); - channelService->stopProxy(); + proxyService->stopProxy(); } __device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer; @@ -227,7 +226,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) { const int nElem = 4 * 1024 * 1024; - std::vector> proxyChannels; + std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocSharedCuda(nElem); const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t); @@ -238,13 +237,19 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) { getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); ASSERT_EQ(proxyChannels.size(), 1); - MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(), + + std::vector> proxyChannelHandles; + for (auto& proxyChannel : proxyChannels) { + proxyChannelHandles.push_back(proxyChannel.deviceHandle()); + } + + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(), sizeof(DeviceHandle))); mscclpp::DeviceSyncer syncer = {}; MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer))); - channelService->startProxy(); + proxyService->startProxy(); std::shared_ptr ret = mscclpp::makeSharedCudaHost(0); @@ -280,7 +285,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) { communicator->bootstrap()->barrier(); - channelService->stopProxy(); + proxyService->stopProxy(); } void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) { @@ -288,7 +293,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) { const int nElem = 4 * 1024 * 1024; - std::vector> proxyChannels; + std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocSharedCuda(nElem); const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t); @@ -299,13 +304,19 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) { getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); 
ASSERT_EQ(proxyChannels.size(), 1); - MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(), + + std::vector> proxyChannelHandles; + for (auto& proxyChannel : proxyChannels) { + proxyChannelHandles.push_back(proxyChannel.deviceHandle()); + } + + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(), sizeof(DeviceHandle))); mscclpp::DeviceSyncer syncer = {}; MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer))); - channelService->startProxy(); + proxyService->startProxy(); auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info(); const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name()); @@ -330,7 +341,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) { std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n"; } - channelService->stopProxy(); + proxyService->stopProxy(); } TEST_F(ProxyChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false); } diff --git a/test/mp_unit/sm_channel_tests.cu b/test/mp_unit/sm_channel_tests.cu index 671cf8ba0..21d9571aa 100644 --- a/test/mp_unit/sm_channel_tests.cu +++ b/test/mp_unit/sm_channel_tests.cu @@ -5,8 +5,6 @@ #include "mp_unit_tests.hpp" -template -using DeviceHandle = mscclpp::DeviceHandle; void SmChannelOneToOneTest::SetUp() { // Need at least two ranks within a node if (gEnv->nRanksPerNode < 2) { diff --git a/test/mscclpp-test/allgather_test.cu b/test/mscclpp-test/allgather_test.cu index fff17cf9e..387c1f956 100644 --- a/test/mscclpp-test/allgather_test.cu +++ b/test/mscclpp-test/allgather_test.cu @@ -210,6 +210,7 @@ __global__ void allgather3(int rank, int worldSize) { if (tid == 0) { mscclpp::ProxyTrigger trigger; trigger.fst = MAGIC; + trigger.snd = 0; // offload all the work to the proxy uint64_t currentFifoHead = proxyChan.fifo_.push(trigger); // wait for the work to be done in cpu side @@ -278,23 +279,24 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne nBlocksForLocalAllGather); } -class AllGatherChannelService : public mscclpp::BaseProxyService { +class AllGatherProxyService : public mscclpp::BaseProxyService { public: - AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice); + AllGatherProxyService(int worldSize, int rank, int cudaDevice); void startProxy() override { proxy_.start(); } void stopProxy() override { proxy_.stop(); } void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; } void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); } void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; } - mscclpp::SemaphoreId addSemaphore(std::shared_ptr connection) { - semaphores_.push_back(std::make_shared(communicator_, connection)); + mscclpp::SemaphoreId buildAndAddSemaphore(mscclpp::Communicator& communicator, + std::shared_ptr connection) { + semaphores_.push_back(std::make_shared(communicator, connection)); return semaphores_.size() - 1; } - std::vector> deviceChannels() { + std::vector> proxyChannels() { std::vector> result; for (auto& semaphore : semaphores_) { result.push_back( - mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceFifo()))); + mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), 
proxy_.fifo().deviceHandle()))); } return result; } @@ -306,7 +308,6 @@ class AllGatherChannelService : public mscclpp::BaseProxyService { size_t sendBytes_; mscclpp::Proxy proxy_; - mscclpp::Communicator& communicator_; std::vector> semaphores_; std::vector remoteMemories_; mscclpp::RegisteredMemory localMemory_; @@ -314,10 +315,8 @@ class AllGatherChannelService : public mscclpp::BaseProxyService { mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw); }; -AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, - int cudaDevice) - : communicator_(communicator), - worldSize_(worldSize), +AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDevice) + : worldSize_(worldSize), sendBytes_(0), rank_(rank), cudaDevice_(cudaDevice), @@ -327,7 +326,7 @@ AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communic numaBind(deviceNumaNode); }) {} -mscclpp::ProxyHandlerResult AllGatherChannelService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) { +mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) { size_t offset = rank_ * sendBytes_; if (triggerRaw.fst != MAGIC) { // this is not a valid trigger @@ -432,7 +431,7 @@ void AllGatherTestColl::setupCollTest(size_t size) { paramCount_ = base; expectedCount_ = recvCount_; if (isUsingHostOffload(kernelNum_)) { - auto service = std::dynamic_pointer_cast(chanService_); + auto service = std::dynamic_pointer_cast(chanService_); service->setSendBytes(sendCount_ * typeSize_); } mscclpp::DeviceSyncer syncer = {}; @@ -459,7 +458,7 @@ class AllGatherTestEngine : public BaseTestEngine { std::vector getSendBuff() override; void* getRecvBuff() override; void* getScratchBuff() override; - std::shared_ptr createChannelService() override; + std::shared_ptr createProxyService() override; private: void* getExpectedBuff() override; @@ -492,31 +491,31 @@ void AllGatherTestEngine::setupConnections() { CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(), sizeof(DeviceHandle) * smChannelHandles.size())); } else { - auto service = std::dynamic_pointer_cast(chanService_); + auto service = std::dynamic_pointer_cast(chanService_); setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0, [&](std::vector> conns, std::vector>& remoteMemories, const mscclpp::RegisteredMemory& localMemory) { std::vector semaphoreIds; for (size_t i = 0; i < conns.size(); ++i) { - service->addSemaphore(conns[i]); + service->buildAndAddSemaphore(*comm_, conns[i]); service->addRemoteMemory(remoteMemories[i].get()); } service->setLocalMemory(localMemory); comm_->setup(); }); - auto proxyChannels = service->deviceChannels(); + auto proxyChannels = service->proxyChannels(); assert(proxyChannels.size() < sizeof(constRawProxyChan) / sizeof(DeviceHandle)); CUDATHROW(cudaMemcpyToSymbol(constRawProxyChan, proxyChannels.data(), sizeof(DeviceHandle) * proxyChannels.size())); } } -std::shared_ptr AllGatherTestEngine::createChannelService() { +std::shared_ptr AllGatherTestEngine::createProxyService() { if (isUsingHostOffload(args_.kernelNum)) { - return std::make_shared(*comm_, args_.totalRanks, args_.rank, args_.gpuNum); + return std::make_shared(args_.totalRanks, args_.rank, args_.gpuNum); } else { - return std::make_shared(*comm_); + return std::make_shared(); } } diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu index 13a0f877a..12fee9285 100644 --- 
a/test/mscclpp-test/allreduce_test.cu +++ b/test/mscclpp-test/allreduce_test.cu @@ -109,7 +109,7 @@ __device__ void localReduceScatter(int* buff, int* scratch, int rank, int nRanks int prePeerRecvId = (preRemoteRecvFromRank < rank) ? preRemoteRecvFromRank : preRemoteRecvFromRank - 1; // overlap communication and computation - mscclpp::SimpleProxyChannel& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId]; + DeviceHandle& preDevFstRecvChan = constDevFstRoundChans[prePeerRecvId]; if (isComm) { preDevFstRecvChan.wait(); devFstSendChan.putWithSignal(dstOffset, srcOffset, nelems * sizeof(int)); @@ -563,7 +563,8 @@ __global__ void allreduce0(int* buff, int* scratch, int rank, int worldSize, siz } } -__global__ void allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) { +__global__ void __launch_bounds__(1024) + allreduce1(int* buff, int* scratch, int rank, int worldSize, size_t nelems, size_t scratchDataCount) { int isComm = (threadIdx.x == 0) && (blockIdx.x == 0); int remoteSendRank = (rank + 1) % worldSize; int remoteRecvRank = (rank + worldSize - 1) % worldSize; @@ -686,7 +687,7 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP // Channel to a remote peer that has the same local rank as me int localRank = rank % nRanksPerNode; - mscclpp::SimpleProxyChannel proxyChan = constDevFstRoundChans[localRank]; + DeviceHandle proxyChan = constDevFstRoundChans[localRank]; // Flag for packets. Initially 1 uint32_t flag = (uint32_t)globalFlag; @@ -779,8 +780,8 @@ __global__ void allreduce2(int* buff, void* scratch, void* putPktBuf, void* getP } } -__global__ void allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize, - size_t nelems) { +__global__ void __launch_bounds__(1024) + allreduce3(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize, size_t nelems) { reduceScatter(buff, scratch, rank, nRanksPerNode, worldSize, nelems); if (threadIdx.x == 0 && blockIdx.x == 0) { allGather(rank, worldSize, nRanksPerNode, nelems / worldSize); diff --git a/test/mscclpp-test/alltoall_test.cu b/test/mscclpp-test/alltoall_test.cu index 064828338..2ee147bdb 100644 --- a/test/mscclpp-test/alltoall_test.cu +++ b/test/mscclpp-test/alltoall_test.cu @@ -16,7 +16,7 @@ void* localSendBuff; __device__ void localAlltoall(int rank, int nRanksPerNode, size_t nElements) { int remoteRank = (blockIdx.x < rank) ? 
blockIdx.x : blockIdx.x + 1; for (int i = 1; i < nRanksPerNode; i++) { - mscclpp::SimpleProxyChannel proxyChan = constProxyChans[blockIdx.x]; + DeviceHandle proxyChan = constProxyChans[blockIdx.x]; if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank + i) % nRanksPerNode) { proxyChan.putWithSignalAndFlush(rank * nElements * sizeof(int), remoteRank * nElements * sizeof(int), nElements * sizeof(int)); diff --git a/test/mscclpp-test/check_perf_result.py b/test/mscclpp-test/check_perf_result.py index 1430526ec..d5c5469a4 100644 --- a/test/mscclpp-test/check_perf_result.py +++ b/test/mscclpp-test/check_perf_result.py @@ -16,9 +16,17 @@ def load_perf_file(perf_fine: str) -> dict: "time": data["time"], } if "target" in data: - res[(data["name"], data["kernel"], data["ranks"], data["ranksPerNode"], data["size"])]["target"] = data[ + res[ + ( + data["name"], + data["kernel"], + data["ranks"], + data["ranksPerNode"], + data["size"], + ) + ][ "target" - ] + ] = data["target"] return res diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index 6cd5932c5..e80531048 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -335,7 +335,7 @@ void BaseTestEngine::bootstrap() { } void BaseTestEngine::setupTest() { - this->chanService_ = this->createChannelService(); + this->chanService_ = this->createProxyService(); this->setupConnections(); this->chanService_->startProxy(); this->coll_->setChanService(this->chanService_); @@ -357,8 +357,8 @@ size_t BaseTestEngine::checkData() { return nErrors; } -std::shared_ptr BaseTestEngine::createChannelService() { - return std::make_shared(*comm_); +std::shared_ptr BaseTestEngine::createProxyService() { + return std::make_shared(); } void BaseTestEngine::setupMeshConnectionsInternal( @@ -416,8 +416,8 @@ void BaseTestEngine::setupMeshConnections(std::vector(chanService_); for (size_t i = 0; i < connections.size(); ++i) { proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel( - service->deviceChannel(service->addSemaphore(connections[i])), service->addMemory(remoteRegMemories[i].get()), - service->addMemory(inputBufRegMem)))); + service->proxyChannel(service->buildAndAddSemaphore(*comm_, connections[i])), + service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBufRegMem)))); } } @@ -498,7 +498,7 @@ void BaseTestEngine::setupMeshConnections(std::vector& smCha if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) { smSemaphores.emplace(cid, std::make_shared(*comm_, connections[cid])); } else { - connIdToSemId[cid] = service->addSemaphore(connections[cid]); + connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]); } } comm_->setup(); @@ -513,7 +513,7 @@ void BaseTestEngine::setupMeshConnections(std::vector& smCha throw std::runtime_error("IB transport requires putPacketBuff and getPacketBuff"); } proxyChannels.emplace_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel( - service->deviceChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()), + service->proxyChannel(connIdToSemId[cid]), service->addMemory(remoteRegMemories[cid].get()), service->addMemory(putPacketBufRegMem)))); } } diff --git a/test/mscclpp-test/common.hpp b/test/mscclpp-test/common.hpp index f5e43da74..665ff9119 100644 --- a/test/mscclpp-test/common.hpp +++ b/test/mscclpp-test/common.hpp @@ -97,7 +97,7 @@ class BaseTestEngine { private: virtual void setupConnections() = 0; - virtual std::shared_ptr createChannelService(); + virtual std::shared_ptr 
createProxyService(); virtual void* getExpectedBuff() = 0; double benchTime(); diff --git a/test/unit/fifo_tests.cu b/test/unit/fifo_tests.cu index e6c6a9a14..567592117 100644 --- a/test/unit/fifo_tests.cu +++ b/test/unit/fifo_tests.cu @@ -5,40 +5,40 @@ #include #include +#include #include -#include "numa.hpp" +#define ITER 10000 // should be larger than the FIFO size for proper testing -#define FLUSH_PERIOD (MSCCLPP_PROXY_FIFO_SIZE) // should not exceed MSCCLPP_PROXY_FIFO_SIZE -#define ITER 10000 // should be larger than MSCCLPP_PROXY_FIFO_SIZE for proper testing - -__constant__ mscclpp::DeviceProxyFifo gFifoTestDeviceProxyFifo; +__constant__ mscclpp::FifoDeviceHandle gFifoTestFifoDeviceHandle; __global__ void kernelFifoTest() { if (threadIdx.x + blockIdx.x * blockDim.x != 0) return; - mscclpp::DeviceProxyFifo& fifo = gFifoTestDeviceProxyFifo; + mscclpp::FifoDeviceHandle& fifo = gFifoTestFifoDeviceHandle; mscclpp::ProxyTrigger trigger; for (uint64_t i = 1; i < ITER + 1; ++i) { trigger.fst = i; trigger.snd = i; uint64_t curFifoHead = fifo.push(trigger); - if (i % FLUSH_PERIOD == 0) { + if (i % fifo.size == 0) { fifo.sync(curFifoHead); } } } -TEST(FifoTest, HostProxyFifo) { - ASSERT_LE(FLUSH_PERIOD, MSCCLPP_PROXY_FIFO_SIZE); - +TEST(FifoTest, Fifo) { int cudaNum; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaNum)); int numaNode = mscclpp::getDeviceNumaNode(cudaNum); mscclpp::numaBind(numaNode); - mscclpp::HostProxyFifo hostFifo; - mscclpp::DeviceProxyFifo devFifo = hostFifo.deviceFifo(); - MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestDeviceProxyFifo, &devFifo, sizeof(devFifo))); + mscclpp::Fifo hostFifo; + if (hostFifo.size() >= ITER) { + FAIL() << "ITER is too small for proper testing."; + } + + mscclpp::FifoDeviceHandle devFifo = hostFifo.deviceHandle(); + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gFifoTestFifoDeviceHandle, &devFifo, sizeof(devFifo))); kernelFifoTest<<<1, 1>>>(); MSCCLPP_CUDATHROW(cudaGetLastError()); @@ -51,17 +51,19 @@ TEST(FifoTest, HostProxyFifo) { uint64_t flushCnt = 0; mscclpp::Timer timer(3); for (uint64_t i = 0; i < ITER; ++i) { - while (trigger.fst == 0) { - hostFifo.poll(&trigger); + while (trigger.fst == 0 || trigger.snd == 0) { + trigger = hostFifo.poll(); if (spin++ > 1000000) { FAIL() << "Polling is stuck."; } } + // see `src/proxy.cc` for the reason of this line + trigger.snd ^= ((uint64_t)1 << (uint64_t)63); ASSERT_TRUE(trigger.fst == (i + 1)); ASSERT_TRUE(trigger.snd == (i + 1)); hostFifo.pop(); - if ((++flushCnt % FLUSH_PERIOD) == 0) { + if ((++flushCnt % hostFifo.size()) == 0) { hostFifo.flushTail(); } trigger.fst = 0; @@ -70,7 +72,7 @@ TEST(FifoTest, HostProxyFifo) { hostFifo.flushTail(true); std::stringstream ss; - ss << "FifoTest.HostProxyFifo: " << (float)timer.elapsed() / ITER << " us/iter\n"; + ss << "FifoTest.Fifo: " << (float)timer.elapsed() / ITER << " us/iter\n"; std::cout << ss.str(); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); diff --git a/test/unit/numa_tests.cc b/test/unit/numa_tests.cc index 16abfcb1b..af09261a2 100644 --- a/test/unit/numa_tests.cc +++ b/test/unit/numa_tests.cc @@ -4,8 +4,7 @@ #include #include - -#include "numa.hpp" +#include TEST(NumaTest, Basic) { int num; diff --git a/tools/npkit/npkit_trace_generator.py b/tools/npkit/npkit_trace_generator.py index ef7197cd5..4f2bc1b5f 100644 --- a/tools/npkit/npkit_trace_generator.py +++ b/tools/npkit/npkit_trace_generator.py @@ -2,10 +2,8 @@ # Licensed under the MIT License. 
import argparse -import os import json - -from queue import Queue +import os def parse_npkit_event_header(npkit_event_header_path): @@ -118,7 +116,10 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo ) event_type_to_seq[event_type] += 1 else: - gpu_events[-1]["args"] = {"size": parsed_gpu_event["size"], "rsvd": parsed_gpu_event["rsvd"]} + gpu_events[-1]["args"] = { + "size": parsed_gpu_event["size"], + "rsvd": parsed_gpu_event["rsvd"], + } delta_time = gpu_events[-1]["ts"] - gpu_events[-2]["ts"] gpu_events[-1]["args"]["bw (GB/s)"] = gpu_events[-1]["args"]["size"] / delta_time / 1e3 raw_content_idx += raw_event_size @@ -238,7 +239,12 @@ def convert_npkit_dump_to_trace(npkit_dump_dir, output_dir, npkit_event_def): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--npkit_dump_dir", type=str, required=True, help="NPKit dump directory.") - parser.add_argument("--npkit_event_header_path", type=str, required=True, help="Path to npkit_event.h.") + parser.add_argument( + "--npkit_event_header_path", + type=str, + required=True, + help="Path to npkit_event.h.", + ) parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory.") args = parser.parse_args()
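
Usage note: the hunks above rename the channel-service API — `ProxyService` is now default-constructed, `addSemaphore(conn)` becomes `buildAndAddSemaphore(comm, conn)`, `deviceChannel(id)` becomes `proxyChannel(id)`, and device handles come from an explicit `deviceHandle()` call. Below is a minimal host-side sketch of the new call sequence, modeled on the updated `test/allgather_test_cpp.cu`; the header paths and the surrounding setup (`comm`, `conn`, the registered memories) are assumptions, not part of this patch:

```cpp
#include <memory>

#include <mscclpp/core.hpp>           // assumed header location
#include <mscclpp/proxy_channel.hpp>  // assumed header location

// Builds one SimpleProxyChannel over an already-established connection.
mscclpp::SimpleProxyChannel::DeviceHandle buildOneChannel(mscclpp::Communicator& comm,
                                                          std::shared_ptr<mscclpp::Connection> conn,
                                                          mscclpp::RegisteredMemory localMem,
                                                          mscclpp::RegisteredMemory remoteMem,
                                                          mscclpp::ProxyService& proxyService) {
  // The Communicator is now passed per semaphore rather than stored in the service.
  mscclpp::SemaphoreId sid = proxyService.buildAndAddSemaphore(comm, conn);
  comm.setup();

  mscclpp::SimpleProxyChannel chan(proxyService.proxyChannel(sid),     // was deviceChannel(sid)
                                   proxyService.addMemory(remoteMem),  // dst
                                   proxyService.addMemory(localMem));  // src
  return chan.deviceHandle();  // copy this handle to the device, e.g. with cudaMemcpyToSymbol
}

void runExample(mscclpp::Communicator& comm, std::shared_ptr<mscclpp::Connection> conn,
                mscclpp::RegisteredMemory localMem, mscclpp::RegisteredMemory remoteMem) {
  mscclpp::ProxyService proxyService;  // no longer takes a Communicator
  auto handle = buildOneChannel(comm, conn, localMem, remoteMem, proxyService);
  (void)handle;
  proxyService.startProxy();
  // ... launch kernels that use the copied device handle here ...
  proxyService.stopProxy();
}
```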
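
A companion sketch of the host-side FIFO loop after the `HostProxyFifo`/`DeviceProxyFifo` to `Fifo`/`FifoDeviceHandle` rename, following the updated loops in `src/proxy.cc` and `test/unit/fifo_tests.cu`. The header path is an assumption; the polling idiom mirrors the patch:

```cpp
#include <cstdint>

#include <mscclpp/fifo.hpp>  // assumed header location

// Waits for one trigger pushed from the device, handles it, and pops it.
void drainOneTrigger(mscclpp::Fifo& fifo) {
  mscclpp::ProxyTrigger trigger{};

  // poll() now returns the trigger by value. Both halves must be nonzero:
  // snd == 0 indicates the 128-bit trigger has not fully landed in host memory yet.
  do {
    trigger = fifo.poll();
  } while (trigger.fst == 0 || trigger.snd == 0);

  // Flip the top bit of `snd` back before interpreting the trigger,
  // exactly as src/proxy.cc and fifo_tests.cu do after this patch.
  trigger.snd ^= ((uint64_t)1 << (uint64_t)63);

  // ... dispatch `trigger` to a handler here ...

  fifo.pop();
  // The flush period is now derived from the runtime fifo.size() instead of the
  // removed MSCCLPP_PROXY_FIFO_SIZE macro; flushing after every pop is simply the
  // most conservative choice for a sketch.
  fifo.flushTail();
}
```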
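
Finally, a sketch of the internal `IbCtx::createQp` change: the removed `MSCCLPP_IB_*` macros become explicit per-QP limits. The numbers are the ones the updated `test/mp_unit/ib_tests.cu` passes, not library defaults, and the device name is a placeholder; `ib.hpp` lives under `src/include/`, so this only compiles inside the library or its tests:

```cpp
#include <memory>
#include <string>

#include "ib.hpp"  // internal header (src/include/ib.hpp)

void createTestQp(const std::string& ibDevName /* e.g. "mlx5_0", placeholder */) {
  auto ctx = std::make_shared<mscclpp::IbCtx>(ibDevName);

  // Formerly createQp(port = -1) with the limits baked in as macros.
  mscclpp::IbQp* qp = ctx->createQp(/*maxCqSize=*/1024, /*maxCqPollNum=*/1,
                                    /*maxSendWr=*/8192, /*maxRecvWr=*/0,
                                    /*maxWrPerSend=*/64);  // port still defaults to -1

  // ... exchange qp->getInfo() with the peer as ib_tests.cu does ...
  (void)qp;  // the QP is owned by ctx and is released with it
}
```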