Skip to content

Commit

Permalink
Remove rank from RegisteredMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
olsaarik committed Aug 23, 2023
1 parent 07165ea commit 002c056
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ MSCCLPP_API_CPP Context::Context() : pimpl(std::make_unique<Impl>()) {}
MSCCLPP_API_CPP Context::~Context() = default;

MSCCLPP_API_CPP RegisteredMemory Context::registerMemory(void* ptr, size_t size, TransportFlags transports) {
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, -1, transports, *pimpl));
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, transports, *pimpl));
}

MSCCLPP_API_CPP Endpoint Context::createEndpoint(Transport transport, int ibMaxCqSize, int ibMaxCqPollNum,
Expand Down
3 changes: 1 addition & 2 deletions src/include/registered_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ struct TransportInfo {
struct RegisteredMemory::Impl {
void* data;
size_t size;
int rank;
uint64_t hostHash;
TransportFlags transports;
std::vector<TransportInfo> transportInfos;

Impl(void* data, size_t size, int rank, TransportFlags transports, Context::Impl& contextImpl);
Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl);
Impl(const std::vector<char>& data);

const TransportInfo& getTransportInfo(Transport transport) const;
Expand Down
7 changes: 2 additions & 5 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

namespace mscclpp {

RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Context::Impl& contextImpl)
: data(data), size(size), rank(rank), hostHash(contextImpl.hostHash_), transports(transports) {
RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl)
: data(data), size(size), hostHash(contextImpl.hostHash_), transports(transports) {
if (transports.has(Transport::CudaIpc)) {
TransportInfo transportInfo;
transportInfo.transport = Transport::CudaIpc;
Expand Down Expand Up @@ -66,7 +66,6 @@ MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->tr
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
std::vector<char> result;
std::copy_n(reinterpret_cast<char*>(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
if (pimpl->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
Expand Down Expand Up @@ -98,8 +97,6 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
auto it = serialization.begin();
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
it += sizeof(this->size);
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
it += sizeof(this->rank);
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
it += sizeof(this->hostHash);
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
Expand Down

0 comments on commit 002c056

Please sign in to comment.