Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ethernet support #284

Merged
merged 20 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,26 @@ class TcpBootstrap : public Bootstrap {

/// Enumerates the available transport types.
///
/// Enumerators take consecutive values starting at 0 in declaration order.
/// NumTransports is kept as the last entry so that its value equals the number
/// of transports declared before it.
enum class Transport {
  Unknown,        // Unknown transport type.
  CudaIpc,        // CUDA IPC transport type.
  Nvls,           // NVLS transport type.
  IB0,            // InfiniBand device 0 transport type.
  IB1,            // InfiniBand device 1 transport type.
  IB2,            // InfiniBand device 2 transport type.
  IB3,            // InfiniBand device 3 transport type.
  IB4,            // InfiniBand device 4 transport type.
  IB5,            // InfiniBand device 5 transport type.
  IB6,            // InfiniBand device 6 transport type.
  IB7,            // InfiniBand device 7 transport type.
  Ethernet,       // Ethernet transport type.
  NumTransports,  // The number of transports.
};

/// Short display names for the transports, listed in the same order as the
/// Transport enumerators; the final "NUM" entry lines up with NumTransports.
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
                                      "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};

namespace detail {
/// Number of transport flags. The static_assert below keeps this constant in
/// lock-step with Transport::NumTransports, so adding a transport without
/// updating it fails at compile time.
const size_t TransportFlagsSize = 12;
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
              "TransportFlagsSize must match the number of transports");
/// Bitset for storing transport flags.
Expand Down Expand Up @@ -336,6 +337,11 @@ class RegisteredMemory {
/// @return A pointer to the memory block.
void* data() const;

/// Get a pointer to the original memory block.
///
/// @return A pointer to the original memory block.
void* originalDataPtr() const;

/// Get the size of the memory block.
///
/// @return The size of the memory block.
Expand Down
32 changes: 32 additions & 0 deletions src/bootstrap/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,38 @@ void Socket::recv(void* ptr, int size) {
socketWait(MSCCLPP_SOCKET_RECV, ptr, size, &offset);
}

/// Receive exactly @p size bytes into @p ptr, stopping early only when the connection ends.
///
/// Transient recv() failures (EINTR, EWOULDBLOCK, EAGAIN) are retried instead of being
/// returned as a silent short read, so on return with `*closed == 0` the buffer is
/// completely filled.
///
/// @param ptr    Destination buffer; must have room for @p size bytes.
/// @param size   Number of bytes to receive. A value <= 0 receives nothing.
/// @param closed Output flag. Reset to 0 on entry; set to 1 when recv() reports an orderly
///               shutdown by the peer, or when recv() fails after this socket was closed.
///               In both cases the buffer may be only partially filled.
/// @throws Error    ErrorCode::InternalError if the socket is not in the ready state, or
///                  ErrorCode::Aborted if the abort flag is set.
/// @throws SysError if recv() fails with a non-transient error while the socket is not closed.
void Socket::recvUntilEnd(void* ptr, int size, int* closed) {
  *closed = 0;
  if (state_ != SocketStateReady) {
    std::stringstream ss;
    ss << "socket state (" << state_ << ") is not ready in recvUntilEnd";
    throw Error(ss.str(), ErrorCode::InternalError);
  }

  char* data = static_cast<char*>(ptr);
  int offset = 0;

  while (offset < size) {
    const int bytes = static_cast<int>(::recv(fd_, data + offset, size - offset, 0));
    if (bytes == 0) {
      // recv() returning 0 for a non-zero request means the peer shut the connection down.
      *closed = 1;
      return;
    }
    if (bytes == -1) {
      if (state_ == SocketStateClosed) {
        // The socket was closed on this side; report it rather than returning a short read.
        *closed = 1;
        return;
      }
      if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
        throw SysError("recv until end failed", errno);
      }
      // Transient condition: nothing was read, fall through and retry.
    } else {
      offset += bytes;
    }
    // Checked on every iteration so a retry loop can still be interrupted.
    if (abortFlag_ && *abortFlag_ != 0) {
      throw Error("aborted", ErrorCode::Aborted);
    }
  }
}

void Socket::close() {
if (fd_ >= 0) ::close(fd_);
state_ = SocketStateClosed;
Expand Down
143 changes: 143 additions & 0 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <exception>
#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>

#include "debug.h"
#include "endpoint.hpp"
Expand Down Expand Up @@ -180,4 +181,146 @@ void IBConnection::flush(int64_t timeoutUsec) {
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}

// EthernetConnection

/// Establish a bidirectional Ethernet connection between two endpoints.
///
/// The incoming side is accepted on the local endpoint's listening socket by a helper
/// thread while this thread connects the outgoing socket to the remote endpoint's address.
/// Once both sockets are set up, a background thread running recvMessages() is started.
///
/// @param localEndpoint  Local endpoint; must use Transport::Ethernet.
/// @param remoteEndpoint Remote endpoint; must use Transport::Ethernet.
/// @throws Error ErrorCode::InvalidUsage if either endpoint is not an Ethernet endpoint.
///               Any exception raised while accepting the incoming socket is rethrown here.
EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint) : abortFlag_(nullptr) {
  // Validating Transport Protocol
  if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
    throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
  }

  // Instantiating Buffers
  sendBuffer_.resize(sendBufferSize_);
  recvBuffer_.resize(rcvBufferSize_);

  // Creating Thread to Accept the Connection.
  // An exception escaping a std::thread calls std::terminate, so failures are captured
  // here and rethrown on the constructing thread after the join.
  auto listenSocket = (getImpl(localEndpoint)->socket_).get();
  std::exception_ptr acceptError;
  std::thread acceptThread([this, listenSocket, &acceptError]() {
    try {
      recvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
      recvSocket_->accept(listenSocket);
    } catch (...) {
      acceptError = std::current_exception();
    }
  });

  // Starting Connection
  // NOTE(review): if connect() throws, acceptThread is still joinable when it is destroyed;
  // handling that safely needs a way to cancel the pending accept — confirm desired behavior.
  sendSocket_ = std::make_unique<Socket>(&(getImpl(remoteEndpoint)->socketAddress_), MSCCLPP_SOCKET_MAGIC,
                                         SocketTypeBootstrap, abortFlag_);
  sendSocket_->connect();

  // Ensure the Connection was Established
  acceptThread.join();
  if (acceptError) {
    std::rethrow_exception(acceptError);
  }

  // Starting Thread to Receive Messages
  threadRecvMessages_ = std::thread(&EthernetConnection::recvMessages, this);

  INFO(MSCCLPP_NET, "Ethernet connection created");
}

/// Shut the connection down: close both sockets, then wait for the receiver thread.
EthernetConnection::~EthernetConnection() {
  // Sockets are closed before joining; recvMessages() loops until the receive
  // socket's state is closed.
  sendSocket_->close();
  recvSocket_->close();

  threadRecvMessages_.join();
}

/// Transport used by the local side of this connection.
///
/// @return Always Transport::Ethernet.
Transport EthernetConnection::transport() {
  return Transport::Ethernet;
}

/// Transport used by the remote side of this connection.
///
/// @return Always Transport::Ethernet.
Transport EthernetConnection::remoteTransport() {
  return Transport::Ethernet;
}

/// Send @p size bytes from local registered memory to the peer over the send socket.
///
/// Stream layout: the destination address (sizeof(char*) bytes), the payload size
/// (uint64_t), then the payload. The header is sent only once, in front of the first
/// chunk; payloads larger than the send buffer are split across several send() calls.
/// Nothing is sent when @p size is 0.
///
/// @param dst       Destination memory; validated against remoteTransport().
/// @param dstOffset Byte offset added to dst.originalDataPtr().
/// @param src       Source memory; validated against transport().
/// @param srcOffset Byte offset added to src.data().
/// @param size      Number of payload bytes to transfer.
void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                               uint64_t size) {
  // Validating Transport Protocol
  validateTransport(dst, remoteTransport());
  validateTransport(src, transport());

  // Initializing Variables
  char* srcPtr = reinterpret_cast<char*>(src.data()) + srcOffset;
  char* dstPtr = reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset;

  // Copying Meta Data to Send Buffer
  uint64_t headerSize = 0;
  const char* dstPtrBytes = reinterpret_cast<const char*>(&dstPtr);
  std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + headerSize);
  headerSize += sizeof(dstPtr);
  const char* sizeBytes = reinterpret_cast<const char*>(&size);
  std::copy(sizeBytes, sizeBytes + sizeof(size), sendBuffer_.data() + headerSize);
  headerSize += sizeof(size);

  // Getting Data From GPU and Sending Message
  uint64_t sentDataSize = 0;
  while (sentDataSize < size) {
    const uint64_t dataSize = std::min(sendBufferSize_ - headerSize, size - sentDataSize);
    const uint64_t messageSize = dataSize + headerSize;
    mscclpp::memcpyCuda<char>(sendBuffer_.data() + headerSize, srcPtr + sentDataSize, dataSize,
                              cudaMemcpyDeviceToHost);
    sendSocket_->send(sendBuffer_.data(), messageSize);
    // Advance by the payload only: counting the header bytes here would skip source
    // bytes and end the transfer before `size` payload bytes were sent.
    sentDataSize += dataSize;
    headerSize = 0;  // Only the first chunk carries the header.
  }

  INFO(MSCCLPP_NET, "EthernetConnection write: from %p to %p, size %lu", srcPtr, dstPtr, size);
}

/// Store @p newValue into @p *src locally and send it to the peer's destination address.
///
/// The message is the destination address (sizeof(uint64_t*) bytes), the payload size
/// (uint64_t, always sizeof(uint64_t)), and the 8-byte value read back from @p src.
///
/// @param dst       Destination memory; validated against remoteTransport().
/// @param dstOffset Byte offset added to dst.originalDataPtr().
/// @param src       Local value to overwrite; its new content is what gets sent.
/// @param newValue  Value written to @p *src before sending.
void EthernetConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
  // Validating Transport Protocol
  validateTransport(dst, remoteTransport());

  // Initializing Variables
  const uint64_t oldValue = *src;
  uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset);
  uint64_t dataSize = sizeof(uint64_t);
  *src = newValue;

  // Copying Data to Send Buffer
  uint64_t messageSize = 0;
  const char* dstPtrBytes = reinterpret_cast<const char*>(&dstPtr);
  std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + messageSize);
  messageSize += sizeof(dstPtr);
  const char* sizeBytes = reinterpret_cast<const char*>(&dataSize);
  std::copy(sizeBytes, sizeBytes + sizeof(dataSize), sendBuffer_.data() + messageSize);
  messageSize += sizeof(dataSize);
  const char* dataBytes = reinterpret_cast<const char*>(src);
  std::copy(dataBytes, dataBytes + dataSize, sendBuffer_.data() + messageSize);
  messageSize += dataSize;

  // Sending Message
  sendSocket_->send(sendBuffer_.data(), messageSize);

  // dstPtr already includes dstOffset; adding it again (scaled by sizeof(uint64_t))
  // would log an address that was never written to.
  INFO(MSCCLPP_NET, "EthernetConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr, oldValue, newValue);
}

/// Flush the connection. This implementation only emits a log message;
/// @p timeoutUsec is not used.
void EthernetConnection::flush(int64_t timeoutUsec) {
  INFO(MSCCLPP_NET, "EthernetConnection flushing connection");
}

void EthernetConnection::recvMessages() {
// Declarating Variables
char* ptr;
uint64_t size;
uint64_t recvSize;
int closed = 0;
bool received = true;

// Receiving Messages Until Connection is Closed
while (recvSocket_->getState() != SocketStateClosed) {
// Receiving Data Address
if (closed == 0) recvSocket_->recvUntilEnd(&ptr, sizeof(char*), &closed);
received &= !closed;

// Receiving data size
if (closed == 0) recvSocket_->recvUntilEnd(&size, sizeof(uint64_t), &closed);
received &= !closed;

// Receiving Data and Copying Data yo GPU
recvSize = 0;
while (recvSize < size && closed == 0) {
uint64_t messageSize = std::min(rcvBufferSize_, (size - recvSize) / sizeof(char)) * sizeof(char);
recvSocket_->recvUntilEnd(recvBuffer_.data(), messageSize, &closed);
received &= !closed;

if (received)
mscclpp::memcpyCuda<char>((char*)ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize,
cudaMemcpyHostToDevice);
recvSize += messageSize;
}
}
}

} // namespace mscclpp
6 changes: 6 additions & 0 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<IBConnection>(localEndpoint, remoteEndpoint, *this);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<EthernetConnection>(localEndpoint, remoteEndpoint);
} else {
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
}

pimpl_->connections_.push_back(conn);
return conn;
}
Expand Down
2 changes: 1 addition & 1 deletion src/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const TransportFlags NoTransports = TransportFlags();
/// Flags for every InfiniBand transport (IB0 through IB7).
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
                                       Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;

/// Flags for all InfiniBand transports plus CUDA IPC and Ethernet.
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet;

/// Empty implementation: the bootstrap argument is ignored and no work is performed.
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {
}

Expand Down
19 changes: 19 additions & 0 deletions src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "api.h"
#include "context.hpp"
#include "socket.h"
#include "utils_internal.hpp"

namespace mscclpp {
Expand All @@ -15,6 +16,16 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
ibQp_ = contextImpl.getIbContext(transport_)
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
ibQpInfo_ = ibQp_->getInfo();
} else if (transport_ == Transport::Ethernet) {
// Configuring Ethernet Interfaces
abortFlag_ = 0;
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1, "");
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);

// Starting Server Socket
socket_ = std::make_unique<Socket>(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_);
socket_->bindAndListen();
socketAddress_ = socket_->getAddr();
}
}

Expand All @@ -27,6 +38,10 @@ MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
if (AllIBTransports.has(pimpl_->transport_)) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
}
if ((pimpl_->transport_) == Transport::Ethernet) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
std::back_inserter(data));
}
return data;
}

Expand All @@ -45,6 +60,10 @@ Endpoint::Impl::Impl(const std::vector<char>& serialization) {
std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
it += sizeof(ibQpInfo_);
}
if (transport_ == Transport::Ethernet) {
std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
it += sizeof(socketAddress_);
}
}

/// Construct an Endpoint that holds the given shared implementation object.
///
/// @param pimpl Implementation object stored in pimpl_.
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl)
    : pimpl_(pimpl) {
}
Expand Down
32 changes: 32 additions & 0 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "context.hpp"
#include "ib.hpp"
#include "registered_memory.hpp"
#include "socket.h"

namespace mscclpp {

Expand Down Expand Up @@ -53,6 +54,37 @@ class IBConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};

/// Connection implementation that moves data between two Ethernet endpoints using a pair
/// of sockets: one for outgoing messages and one, serviced by a background thread, for
/// incoming messages.
class EthernetConnection : public Connection {
  std::unique_ptr<Socket> sendSocket_;         // Socket used by write()/updateAndSync() to send messages.
  std::unique_ptr<Socket> recvSocket_;         // Socket read by recvMessages().
  std::thread threadRecvMessages_;             // Background thread running recvMessages().
  volatile uint32_t* abortFlag_;               // Abort flag handed to both sockets on construction.
  const uint64_t sendBufferSize_ = 256000000;  // Size in bytes of sendBuffer_.
  const uint64_t rcvBufferSize_ = 256000000;   // Size in bytes of recvBuffer_.
  std::vector<char> sendBuffer_;               // Host staging buffer for outgoing messages.
  std::vector<char> recvBuffer_;               // Host staging buffer for incoming payload chunks.

 public:
  /// Create the connection between two Ethernet endpoints.
  EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint);

  /// Close both sockets and join the receiver thread.
  ~EthernetConnection();

  /// @return Transport::Ethernet.
  Transport transport() override;

  /// @return Transport::Ethernet.
  Transport remoteTransport() override;

  /// Send @p size bytes from @p src (at @p srcOffset) to @p dst (at @p dstOffset) on the peer.
  void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
             uint64_t size) override;

  /// Set @p *src to @p newValue and send that value to @p dst (at @p dstOffset) on the peer.
  void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;

  /// Flush the connection.
  void flush(int64_t timeoutUsec) override;

 private:
  /// Receiver loop executed by threadRecvMessages_.
  void recvMessages();

  void sendMessage();
};

} // namespace mscclpp

#endif // MSCCLPP_CONNECTION_HPP_
9 changes: 9 additions & 0 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <vector>

#include "ib.hpp"
#include "socket.h"

#define MAX_IF_NAME_SIZE 16

namespace mscclpp {

Expand All @@ -22,6 +25,12 @@ struct Endpoint::Impl {
bool ibLocal_;
IbQp* ibQp_;
IbQpInfo ibQpInfo_;

// The following are only used for Ethernet and are undefined for other transports.
std::unique_ptr<Socket> socket_;
SocketAddress socketAddress_;
volatile uint32_t* abortFlag_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
};

} // namespace mscclpp
Expand Down
1 change: 1 addition & 0 deletions src/include/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Socket {
void accept(const Socket* listenSocket, int64_t timeout = -1);
void send(void* ptr, int size);
void recv(void* ptr, int size);
void recvUntilEnd(void* ptr, int size, int* closed);
void close();

/// @return The value of fd_, the file descriptor held by this socket.
int getFd() const {
  return fd_;
}
Expand Down
Loading
Loading