Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ethernet support #284

Merged
merged 20 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,25 +127,26 @@ class TcpBootstrap : public Bootstrap {

/// Enumerates the available transport types.
enum class Transport {
Unknown, // Unknown transport type.
CudaIpc, // CUDA IPC transport type.
Nvls, // NVLS transport type.
IB0, // InfiniBand device 0 transport type.
IB1, // InfiniBand device 1 transport type.
IB2, // InfiniBand device 2 transport type.
IB3, // InfiniBand device 3 transport type.
IB4, // InfiniBand device 4 transport type.
IB5, // InfiniBand device 5 transport type.
IB6, // InfiniBand device 6 transport type.
IB7, // InfiniBand device 7 transport type.
NumTransports // The number of transports.
Unknown, // Unknown transport type.
CudaIpc, // CUDA IPC transport type.
Nvls, // NVLS transport type.
IB0, // InfiniBand device 0 transport type.
IB1, // InfiniBand device 1 transport type.
IB2, // InfiniBand device 2 transport type.
IB3, // InfiniBand device 3 transport type.
IB4, // InfiniBand device 4 transport type.
IB5, // InfiniBand device 5 transport type.
IB6, // InfiniBand device 6 transport type.
IB7, // InfiniBand device 7 transport type.
Ethernet, // Ethernet transport type.
NumTransports, // The number of transports.
};

const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2",
"IB3", "IB4", "IB5", "IB6", "IB7", "NUM"};
const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3",
"IB4", "IB5", "IB6", "IB7", "ETH", "NUM"};

namespace detail {
const size_t TransportFlagsSize = 11;
const size_t TransportFlagsSize = 12;
static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
"TransportFlagsSize must match the number of transports");
/// Bitset for storing transport flags.
Expand Down Expand Up @@ -333,6 +334,11 @@ class RegisteredMemory {
/// @return A pointer to the memory block.
void* data() const;

/// Get a pointer to the original memory block.
///
/// @return A pointer to the original memory block.
void* originalDataPtr() const;

/// Get the size of the memory block.
///
/// @return The size of the memory block.
Expand Down
32 changes: 32 additions & 0 deletions src/bootstrap/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,38 @@ void Socket::recv(void* ptr, int size) {
socketWait(MSCCLPP_SOCKET_RECV, ptr, size, &offset);
}

void Socket::recvUntilEnd(void* ptr, int size, int* closed) {
int offset = 0;
*closed = 0;
if (state_ != SocketStateReady) {
std::stringstream ss;
ss << "socket state (" << state_ << ") is not ready in recvUntilEnd";
throw Error(ss.str(), ErrorCode::InternalError);
}

int bytes = 0;
char* data = (char*)ptr;

do {
bytes = ::recv(fd_, data + (offset), size - (offset), 0);
if (bytes == 0) {
*closed = 1;
return;
}
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) {
throw SysError("recv until end failed", errno);
} else {
bytes = 0;
}
}
(offset) += bytes;
if (abortFlag_ && *abortFlag_ != 0) {
throw Error("aborted", ErrorCode::Aborted);
}
} while (bytes > 0 && (offset) < size);
}

void Socket::close() {
if (fd_ >= 0) ::close(fd_);
state_ = SocketStateClosed;
Expand Down
126 changes: 126 additions & 0 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>

#include "debug.h"
#include "endpoint.hpp"
Expand Down Expand Up @@ -180,4 +181,129 @@ void IBConnection::flush(int64_t timeoutUsec) {
// npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT);
}

// EthernetConnection

EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint)
: stopRcvMessages_(false), abortFlag_(0) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
}

// Instanciating Buffers
sendBuffer_ = new char[sendBufferSize_];
rcvBuffer_ = new char[rcvBufferSize_];
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved

// Creating Thread to Accept the Connection
auto parameter = (getImpl(localEndpoint)->socket_).get();
std::thread t([this, parameter]() {
rcvSocket_ = std::make_unique<Socket>(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
rcvSocket_->accept(parameter);
});

// Starting Connection
sendSocket_ =
std::make_unique<Socket>(&(getImpl(remoteEndpoint)->socketAddress_), 0xdeadbeef, SocketTypeBootstrap, abortFlag_);
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
sendSocket_->connect();

// Ensure the Connection was Established
t.join();

// Starting Thread to Receive Messages
threadRcvMessages_ = std::thread(&EthernetConnection::rcvMessages, this);

INFO(MSCCLPP_NET, "Ethernet connection created");
}

EthernetConnection::~EthernetConnection() {
sendSocket_->close();
stopRcvMessages_ = true;
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
rcvSocket_->close();
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
threadRcvMessages_.join();
}

Transport EthernetConnection::transport() { return Transport::Ethernet; }

Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; }

void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) {
// Validating Transport Protocol
validateTransport(dst, remoteTransport());
validateTransport(src, transport());

// Initializing Variables
char* srcPtr = reinterpret_cast<char*>(src.data()) + srcOffset / sizeof(char);
char* dstPtr = reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset / sizeof(char);
uint64_t sendSize = 0;

// Sending Info Data
sendSocket_->send(&dstPtr, sizeof(char*));
sendSocket_->send(&size, sizeof(uint64_t));
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved

// Getting Data From GPU and Sending Data
while (sendSize < size) {
uint64_t messageSize = std::min(sendBufferSize_, (size - sendSize) / sizeof(char)) * sizeof(char);
mscclpp::memcpyCuda<char>(sendBuffer_, (char*)srcPtr + (sendSize / sizeof(char)), messageSize,
cudaMemcpyDeviceToHost);
sendSocket_->send(sendBuffer_, messageSize);
sendSize += messageSize;
}

INFO(MSCCLPP_NET, "EthernetConnection write: from %p to %p, size %lu", srcPtr, dstPtr, size);
}

void EthernetConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
// Validating Transport Protocol
validateTransport(dst, remoteTransport());

// Initializing Variables
uint64_t oldValue = *src;
uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.originalDataPtr()) + dstOffset);
uint64_t size = sizeof(uint64_t);
*src = newValue;

// Sending Data
sendSocket_->send(&dstPtr, sizeof(char*));
sendSocket_->send(&size, sizeof(uint64_t));
sendSocket_->send(src, size);

caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
INFO(MSCCLPP_NET, "EthernetConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
newValue);
}

void EthernetConnection::flush(int64_t timeoutUsec) { INFO(MSCCLPP_NET, "EthernetConnection flushing connection"); }

void EthernetConnection::rcvMessages() {
// Receiving Messages Until Connection is Closed
while (!stopRcvMessages_) {
// Declarating Variables
char* ptr;
uint64_t size;
uint64_t rcvSize = 0;
int closed = 0;
bool received = true;

// Receiving Data Address
if (closed == 0) rcvSocket_->recvUntilEnd(&ptr, sizeof(char*), &closed);
received &= !closed;

// Receiving data size
if (closed == 0) rcvSocket_->recvUntilEnd(&size, sizeof(uint64_t), &closed);
received &= !closed;

// Receiving Data and Copying Data yo GPU
while (rcvSize < size && closed == 0) {
uint64_t messageSize = std::min(rcvBufferSize_, (size - rcvSize) / sizeof(char)) * sizeof(char);
rcvSocket_->recvUntilEnd(rcvBuffer_, messageSize, &closed);
received &= !closed;

if (received)
mscclpp::memcpyCuda<char>((char*)ptr + (rcvSize / sizeof(char)), rcvBuffer_, messageSize,
cudaMemcpyHostToDevice);
rcvSize += messageSize;
}
}
}

} // namespace mscclpp
6 changes: 6 additions & 0 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<IBConnection>(localEndpoint, remoteEndpoint, *this);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
conn = std::make_shared<EthernetConnection>(localEndpoint, remoteEndpoint);
} else {
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
}

pimpl_->connections_.push_back(conn);
return conn;
}
Expand Down
2 changes: 1 addition & 1 deletion src/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const TransportFlags NoTransports = TransportFlags();
const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 |
Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7;

const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc;
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet;

void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {}

Expand Down
19 changes: 19 additions & 0 deletions src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "api.h"
#include "context.hpp"
#include "socket.h"
#include "utils_internal.hpp"

namespace mscclpp {
Expand All @@ -15,6 +16,16 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
ibQp_ = contextImpl.getIbContext(transport_)
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
ibQpInfo_ = ibQp_->getInfo();
} else if (transport_ == Transport::Ethernet) {
// Configuring Ethernet Interfaces
abortFlag_ = 0;
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1, "");
if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError);

// Starting Server Socket
socket_ = std::make_unique<Socket>(&socketAddress_, 0xdeadbeef, SocketTypeBootstrap, abortFlag_);
socket_->bindAndListen();
socketAddress_ = socket_->getAddr();
}
}

Expand All @@ -27,6 +38,10 @@ MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
if (AllIBTransports.has(pimpl_->transport_)) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data));
}
if ((pimpl_->transport_) == Transport::Ethernet) {
std::copy_n(reinterpret_cast<char*>(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_),
std::back_inserter(data));
}
return data;
}

Expand All @@ -45,6 +60,10 @@ Endpoint::Impl::Impl(const std::vector<char>& serialization) {
std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast<char*>(&ibQpInfo_));
it += sizeof(ibQpInfo_);
}
if (transport_ == Transport::Ethernet) {
std::copy_n(it, sizeof(socketAddress_), reinterpret_cast<char*>(&socketAddress_));
it += sizeof(socketAddress_);
}
}

MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<mscclpp::Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
Expand Down
33 changes: 33 additions & 0 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "context.hpp"
#include "ib.hpp"
#include "registered_memory.hpp"
#include "socket.h"

namespace mscclpp {

Expand Down Expand Up @@ -53,6 +54,38 @@ class IBConnection : public Connection {
void flush(int64_t timeoutUsec) override;
};

class EthernetConnection : public Connection {
std::unique_ptr<Socket> sendSocket_;
std::unique_ptr<Socket> rcvSocket_;
std::thread threadRcvMessages_;
bool stopRcvMessages_;
volatile uint32_t* abortFlag_;
const uint64_t sendBufferSize_ = 256000000;
const uint64_t rcvBufferSize_ = 256000000;
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
char* sendBuffer_;
char* rcvBuffer_;

public:
EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint);

~EthernetConnection();

Transport transport() override;

Transport remoteTransport() override;

void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;

void flush(int64_t timeoutUsec) override;

private:
void rcvMessages();
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved

void sendMessage();
};

} // namespace mscclpp

#endif // MSCCLPP_CONNECTION_HPP_
9 changes: 9 additions & 0 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <vector>

#include "ib.hpp"
#include "socket.h"

#define MAX_IF_NAME_SIZE 16

namespace mscclpp {

Expand All @@ -22,6 +25,12 @@ struct Endpoint::Impl {
bool ibLocal_;
IbQp* ibQp_;
IbQpInfo ibQpInfo_;

// The following are only used for Ethernet and are undefined for other transports.
std::unique_ptr<Socket> socket_;
SocketAddress socketAddress_;
volatile uint32_t* abortFlag_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
};

} // namespace mscclpp
Expand Down
1 change: 1 addition & 0 deletions src/include/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Socket {
void accept(const Socket* listenSocket, int64_t timeout = -1);
void send(void* ptr, int size);
void recv(void* ptr, int size);
void recvUntilEnd(void* ptr, int size, int* closed);
void close();

int getFd() const { return fd_; }
Expand Down
4 changes: 3 additions & 1 deletion src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;

MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl_->data; }

MSCCLPP_API_CPP void* RegisteredMemory::originalDataPtr() const { return pimpl_->originalDataPtr; }

MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl_->size; }

MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl_->transports; }
Expand Down Expand Up @@ -139,7 +141,7 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
}

// Next decide how to set this->data
if (getHostHash() == this->hostHash && getPidHash() == this->pidHash) {
if ((getHostHash() == this->hostHash && getPidHash() == this->pidHash)) {
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
// The memory is local to the process, so originalDataPtr is valid as is
this->data = this->originalDataPtr;
} else if (transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) {
Expand Down
Loading
Loading