Skip to content

Commit

Permalink
Implementation of API changes compiles
Browse files Browse the repository at this point in the history
Tests are broken still
  • Loading branch information
olsaarik committed Aug 22, 2023
1 parent 1d4ec5c commit 491dc72
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 143 deletions.
80 changes: 43 additions & 37 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,11 @@ std::string getIBDeviceName(Transport ibTransport);
/// @return The InfiniBand transport associated with the specified device name.
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);

class Communicator;
class Connection;

/// Represents a block of memory that has been registered to a @ref Communicator.
class RegisteredMemory {
protected:
public:
struct Impl;

public:
/// Default constructor.
RegisteredMemory() = default;

Expand Down Expand Up @@ -347,14 +343,7 @@ class RegisteredMemory {
/// @return A deserialized RegisteredMemory object.
static RegisteredMemory deserialize(const std::vector<char>& data);

friend class Connection;
friend class Context;
friend class IBConnection;
friend class Communicator;

private:
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
// lazily.
/// Pointer to the internal implementation of the RegisteredMemory class. A shared_ptr is used since RegisteredMemory is immutable.
std::shared_ptr<Impl> pimpl;
};

Expand Down Expand Up @@ -391,21 +380,17 @@ class Connection {
///
/// @return The transport used by the remote process.
virtual Transport remoteTransport() = 0;

protected:
/// Get the implementation object associated with a @ref RegisteredMemory object.
///
/// @param memory The @ref RegisteredMemory object.
/// @return A shared pointer to the implementation object.
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory& memory);
};

/// Represents one end of a connection.
class Endpoint {
protected:
struct Impl;

public:
/// Default constructor.
Endpoint() = default;

/// Destructor.
~Endpoint();

/// Get the transport used.
///
/// @return The transport used.
Expand All @@ -421,16 +406,20 @@ class Endpoint {
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
static Endpoint deserialize(const std::vector<char>& data);

/// The internal implementation of the Endpoint class.
struct Impl;

private:
// A shared_ptr is used since Endpoint is immutable.
/// Constructor that takes a shared pointer to an implementation object.
///
/// @param pimpl A shared pointer to an implementation object.
Endpoint(std::shared_ptr<Impl> pimpl);

/// Pointer to the internal implementation of the Endpoint class. A shared_ptr is used since Endpoint is immutable.
std::shared_ptr<Impl> pimpl;
};

class Context {
protected:
struct Impl;

public:
/// Create a context.
///
Expand Down Expand Up @@ -461,10 +450,10 @@ class Context {
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);

friend class RegisteredMemory::Impl;
/// The internal implementation of the Context class.
struct Impl;

private:
/// Unique pointer to the implementation of the Communicator class.
/// Pointer to the internal implementation of the Context class.
std::unique_ptr<Impl> pimpl;
};

Expand Down Expand Up @@ -531,9 +520,6 @@ class NonblockingFuture {
/// 6. All done; use connections and registered memories to build channels.
///
class Communicator {
protected:
struct Impl;

public:
/// Initializes the communicator with a given bootstrap implementation.
///
Expand All @@ -553,6 +539,14 @@ class Communicator {
/// @return std::shared_ptr<Context> The context held by this communicator.
std::shared_ptr<Context> context();

/// Register a region of GPU memory for use in this communicator's context.
///
/// @param ptr Base pointer to the memory.
/// @param size Size of the memory region in bytes.
/// @param transports Transport flags.
/// @return RegisteredMemory A handle to the buffer.
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

/// Send information of a registered memory to the remote side on setup.
///
/// This function registers a send to a remote process that will happen by a following call of @ref setup(). The send
Expand Down Expand Up @@ -587,6 +581,18 @@ class Communicator {
/// @return NonblockingFuture<std::shared_ptr<Connection>> A non-blocking future of shared pointer to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, Transport transport);

/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
/// @return The remote rank the connection is connected to.
int remoteRankOf(const Connection& connection);

/// Get the tag a connection was made with.
///
/// @param connection The connection to get the tag for.
/// @return The tag the connection was made with.
int tagOf(const Connection& connection);

/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///
/// @param setuppable A shared pointer to the Setuppable object.
Expand All @@ -599,10 +605,10 @@ class Communicator {
/// that have been registered after the (n-1)-th call.
void setup();

friend class IBConnection;
/// The internal implementation of the Communicator class.
struct Impl;

private:
/// Unique pointer to the implementation of the Communicator class.
/// Pointer to the internal implementation of the Communicator class.
std::unique_ptr<Impl> pimpl;
};

Expand Down
5 changes: 2 additions & 3 deletions python/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ void register_core(nb::module_& m) {
.def(nb::init<>())
.def("data", &RegisteredMemory::data)
.def("size", &RegisteredMemory::size)
.def("rank", &RegisteredMemory::rank)
.def("transports", &RegisteredMemory::transports)
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
Expand All @@ -119,8 +118,6 @@ void register_core(nb::module_& m) {
},
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush)
.def("remote_rank", &Connection::remoteRank)
.def("tag", &Connection::tag)
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);

Expand All @@ -140,6 +137,8 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
}

Expand Down
102 changes: 45 additions & 57 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,27 @@

namespace mscclpp {

Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap) : bootstrap_(bootstrap) {
rankToHash_.resize(bootstrap->getNranks());
auto hostHash = getHostHash();
INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
rankToHash_[bootstrap->getRank()] = hostHash;
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));

MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&ipcStream_, cudaStreamNonBlocking));
}

Communicator::Impl::~Impl() {
ibContexts_.clear();

cudaStreamDestroy(ipcStream_);
}

IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) {
// Find IB context or create it
auto it = ibContexts_.find(ibTransport);
if (it == ibContexts_.end()) {
auto ibDev = getIBDeviceName(ibTransport);
ibContexts_[ibTransport] = std::make_unique<IbCtx>(ibDev);
return ibContexts_[ibTransport].get();
Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context) : bootstrap_(bootstrap) {
if (!context) {
context_ = std::make_shared<Context>();
} else {
return it->second.get();
context_ = context;
}
}

cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; }
rankToHash_.resize(bootstrap->getNranks());
rankToHash_[bootstrap->getRank()] = context_->pimpl->hostHash_;
bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
}

MSCCLPP_API_CPP Communicator::~Communicator() = default;

MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap)
: pimpl(std::make_unique<Impl>(bootstrap)) {}
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context)
: pimpl(std::make_unique<Impl>(bootstrap, context)) {}

MSCCLPP_API_CPP std::shared_ptr<Bootstrap> Communicator::bootstrap() { return pimpl->bootstrap_; }

MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) {
return RegisteredMemory(
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
return context()->registerMemory(ptr, size, transports);
}

struct MemorySender : public Setuppable {
Expand Down Expand Up @@ -94,34 +75,41 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
}

MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) {
std::shared_ptr<ConnectionBase> conn;
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) {
std::stringstream ss;
ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex
<< pimpl->rankToHash_[remoteRank] << ") != " << pimpl->bootstrap_->getRank() << "(" << std::hex
<< pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")";
throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage);
}
auto cudaIpcConn = std::make_shared<CudaIpcConnection>(remoteRank, tag, pimpl->getIpcStream());
conn = cudaIpcConn;
INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created",
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank,
pimpl->rankToHash_[remoteRank]);
} else if (AllIBTransports.has(transport)) {
auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, *pimpl);
conn = ibConn;
INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]);
} else {
throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError);
struct Connector : public Setuppable {
Connector(Communicator& comm, int remoteRank, int tag, Transport transport) : comm_(comm), remoteRank_(remoteRank), tag_(tag), localEndpoint_(comm.context()->createEndpoint(transport)) {}

/// First setup phase: send the serialized local endpoint to the remote rank.
///
/// The message is addressed with this connector's remote rank and tag.
///
/// @param bootstrap The bootstrap used to deliver the serialized endpoint.
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_);
}
pimpl->connections_.push_back(conn);
onSetup(conn);
return conn;

/// Second setup phase: receive the remote endpoint and establish the connection.
///
/// Receives the peer's serialized endpoint (same remote rank and tag used in beginSetup), deserializes it, connects
/// the local and remote endpoints through the communicator's context, records the connection's remote rank and tag in
/// the communicator, and finally fulfills the promise handed out by connectOnSetup.
///
/// @param bootstrap The bootstrap used to receive the remote endpoint.
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
std::vector<char> data;
bootstrap->recv(data, remoteRank_, tag_);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = comm_.context()->connect(localEndpoint_, remoteEndpoint);
// Keyed by the raw connection pointer; this entry is what remoteRankOf()/tagOf() look up later.
comm_.pimpl->connectionInfos_[connection.get()] = {remoteRank_, tag_};
// Fulfill the promise last, so the bookkeeping above is already in place when the future becomes ready.
connectionPromise_.set_value(connection);
}

std::promise<std::shared_ptr<Connection>> connectionPromise_;
Communicator& comm_;
int remoteRank_;
int tag_;
Endpoint localEndpoint_;
};

/// Schedule the creation of a connection to a remote rank.
///
/// A Connector is created for the given remote rank, tag and transport and registered via onSetup(); the connection
/// itself is established by the Connector's beginSetup()/endSetup() hooks, and the returned future is fulfilled from
/// the Connector's promise.
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag used to match the endpoint exchange with the remote process.
/// @param transport The transport to create the local endpoint with.
/// @return A non-blocking future of a shared pointer to the connection.
MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::connectOnSetup(int remoteRank, int tag,
                                                                                            Transport transport) {
  using ConnectionPtr = std::shared_ptr<Connection>;

  auto pendingConnection = std::make_shared<Connector>(*this, remoteRank, tag, transport);
  onSetup(pendingConnection);
  return NonblockingFuture<ConnectionPtr>(pendingConnection->connectionPromise_.get_future());
}

/// Look up the remote rank recorded for a connection.
///
/// The lookup is by the address of @p connection in the communicator's connection bookkeeping, which is populated
/// when a connection requested through connectOnSetup() is established.
/// NOTE(review): the lookup uses at(); presumably this throws (std::out_of_range for standard containers) when the
/// connection is unknown to this communicator -- confirm against the container type of connectionInfos_.
///
/// @param connection The connection to get the remote rank for.
/// @return The remote rank the connection is connected to.
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
  const Connection* const key = &connection;
  const auto& info = pimpl->connectionInfos_.at(key);
  return info.remoteRank;
}

/// Look up the tag recorded for a connection.
///
/// The lookup is by the address of @p connection in the communicator's connection bookkeeping, which is populated
/// when a connection requested through connectOnSetup() is established.
/// NOTE(review): the lookup uses at(); presumably this throws (std::out_of_range for standard containers) when the
/// connection is unknown to this communicator -- confirm against the container type of connectionInfos_.
///
/// @param connection The connection to get the tag for.
/// @return The tag the connection was made with.
MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
  const Connection* const key = &connection;
  const auto& info = pimpl->connectionInfos_.at(key);
  return info.tag;
}

MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr<Setuppable> setuppable) {
Expand Down
23 changes: 9 additions & 14 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include "connection.hpp"

#include <sstream>
#include <mscclpp/utils.hpp>

#include "endpoint.hpp"
#include "debug.h"
#include "infiniband/verbs.h"
#include "npkit/npkit.h"
Expand All @@ -18,16 +20,10 @@ void validateTransport(RegisteredMemory mem, Transport transport) {
}
}

// Connection

std::shared_ptr<RegisteredMemory::Impl> Connection::getRegisteredMemoryImpl(RegisteredMemory& memory) {
return memory.pimpl;
}

// CudaIpcConnection

CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context::Impl& contextImpl)
: stream_(contextImpl->getIpcStream()) {
CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
: stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -83,16 +79,15 @@ void CudaIpcConnection::flush() {

// IBConnection

IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context::Impl& contextImpl)
IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
: transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
numSignaledSends(0),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = localEndpoint.pimpl->ibQp_;
qp->rtr(remoteEndpoint.pimpl->ibQpInfo_);
qp->rts();
dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
dummyAtomicSource_.get(), sizeof(uint64_t), transport_, contextImpl));
dummyAtomicSourceMem_ = context.registerMemory(dummyAtomicSource_.get(), sizeof(uint64_t), transport_);
INFO(MSCCLPP_NET, "IB connection via %s created", getIBDeviceName(transport_).c_str());
}

Expand All @@ -105,11 +100,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
validateTransport(dst, remoteTransport());
validateTransport(src, transport());

auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
auto dstTransportInfo = dst.pimpl->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage);
}
auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport());
auto srcTransportInfo = src.pimpl->getTransportInfo(transport());
if (!srcTransportInfo.ibLocal) {
throw Error("src is remote, which is not supported", ErrorCode::InvalidUsage);
}
Expand All @@ -129,7 +124,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem

void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
validateTransport(dst, remoteTransport());
auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport());
auto dstTransportInfo = dst.pimpl->getTransportInfo(remoteTransport());
if (dstTransportInfo.ibLocal) {
throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage);
}
Expand Down
Loading

0 comments on commit 491dc72

Please sign in to comment.