Skip to content

Commit

Permalink
Support for specifying RDMA devices when multiple RDMA devices are pr…
Browse files Browse the repository at this point in the history
…esent.

Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Oct 29, 2024
1 parent 728477c commit 78e19c8
Show file tree
Hide file tree
Showing 15 changed files with 231 additions and 111 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--enable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
Expand All @@ -719,6 +720,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--disable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
Expand Down
51 changes: 34 additions & 17 deletions src/client/rpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,18 @@ Status RPCClient::Connect(const std::string& rpc_endpoint) {
Status RPCClient::Connect(const std::string& rpc_endpoint,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
return this->Connect(rpc_endpoint, RootSessionID(), username, password,
rdma_endpoint);
rdma_endpoint, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& rpc_endpoint,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
size_t pos = rpc_endpoint.find(":");
std::string host, port;
if (pos == std::string::npos) {
Expand All @@ -125,28 +127,32 @@ Status RPCClient::Connect(const std::string& rpc_endpoint,

return this->Connect(host, static_cast<uint32_t>(std::stoul(port)),
session_id, username, password, rdma_host,
static_cast<uint32_t>(std::stoul(rdma_port)));
static_cast<uint32_t>(std::stoul(rdma_port)),
src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), "", "", rdma_host,
rdma_port);
rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), username, password,
rdma_host, rdma_port);
rdma_host, rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);
std::string rpc_endpoint = host + ":" + std::to_string(port);
RETURN_ON_ASSERT(!connected_ || rpc_endpoint == rpc_endpoint_);
Expand Down Expand Up @@ -183,7 +189,8 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
instance_id_ = UnspecifiedInstanceID() - 1;

if (rdma_host.length() > 0) {
Status status = ConnectRDMA(rdma_host, rdma_port);
src_rdma_endpoint_ = src_rdma_ednpoint;
Status status = ConnectRDMA(rdma_host, rdma_port, src_rdma_ednpoint);
if (status.ok()) {
rdma_endpoint_ = rdma_host + ":" + std::to_string(rdma_port);
std::cout << "Connected to RPC server: " << rpc_endpoint
Expand All @@ -192,33 +199,38 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
} else {
std::cout << "Connect RDMA server failed! Fall back to RPC mode. Error:"
<< status.message() << std::endl;
std::cout << "Failed src_rdma_ednpoint: " << src_rdma_ednpoint
<< std::endl;
}
}

return Status::OK();
}

Status RPCClient::ConnectRDMA(const std::string& rdma_host,
uint32_t rdma_port) {
Status RPCClient::ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint) {
if (this->rdma_connected_) {
return Status::OK();
}

RETURN_ON_ERROR(RDMAClientCreator::Create(this->rdma_client_, rdma_host,
static_cast<int>(rdma_port)));
static_cast<int>(rdma_port),
src_rdma_endpoint));

int retry = 0;
do {
if (this->rdma_client_->Connect().ok()) {
Status status = this->rdma_client_->Connect();
if (status.ok()) {
break;
}
if (retry == 10) {
return Status::Invalid("Failed to connect to RDMA server.");
}
retry++;
usleep(300 * 1000);
std::cout << "Connect rdma server failed! retry: " << retry << " times."
<< std::endl;
std::cout << "Connect rdma server failed! Error:" + status.message() +
"retry: "
<< retry << " times." << std::endl;
} while (true);
this->rdma_connected_ = true;
return Status::OK();
Expand Down Expand Up @@ -272,6 +284,9 @@ Status RPCClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) {

Status RPCClient::StopRDMA() {
if (!rdma_connected_) {
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));
return Status::OK();
}
rdma_connected_ = false;
Expand All @@ -285,7 +300,9 @@ Status RPCClient::StopRDMA() {

RETURN_ON_ERROR(rdma_client_->Stop());
RETURN_ON_ERROR(rdma_client_->Close());
RETURN_ON_ERROR(RDMAClientCreator::Release(rdma_endpoint_));
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));

return Status::OK();
}
Expand Down
19 changes: 13 additions & 6 deletions src/client/rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& rpc_endpoint, std::string const& username,
std::string const& password,
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP endpoint `rpc_endpoint`.
Expand All @@ -104,7 +105,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& rpc_endpoint, const SessionID session_id,
std::string const& username = "",
std::string const& password = "",
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -117,7 +119,8 @@ class RPCClient final : public ClientBase {
* @return Status that indicates whether the connect has succeeded.
*/
Status Connect(const std::string& host, uint32_t port,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -131,7 +134,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& host, uint32_t port,
std::string const& username, std::string const& password,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -147,7 +151,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& host, uint32_t port,
const SessionID session_id, std::string const& username = "",
std::string const& password = "",
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Create a new client using self endpoint.
Expand Down Expand Up @@ -436,7 +441,8 @@ class RPCClient final : public ClientBase {
const std::string rdma_endpoint() { return rdma_endpoint_; }

private:
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port);
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint = "");

Status StopRDMA();

Expand Down Expand Up @@ -479,6 +485,7 @@ class RPCClient final : public ClientBase {
std::string rdma_endpoint_;
std::shared_ptr<RDMAClient> rdma_client_;
mutable bool rdma_connected_ = false;
std::string src_rdma_endpoint_ = "";

friend class Client;
};
Expand Down
23 changes: 19 additions & 4 deletions src/common/rdma/rdma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Status IRDMA::RegisterMemory(fid_mr** mr, fid_domain* domain, void* address,
return Status::IOError("Failed to register memory region:" +
std::to_string(ret));
}
CHECK_ERROR(!ret, "Failed to register memory region:" + std::to_string(ret));
CHECK_ERROR(ret, "Failed to register memory region:" + std::to_string(ret));

mr_desc = fi_mr_desc(*mr);

Expand Down Expand Up @@ -177,10 +177,25 @@ int IRDMA::GetCompletion(fid_cq* cq, int timeout, void** context) {
return ret < 0 ? ret : 0;
}

void IRDMA::FreeInfo(fi_info* info) {
if (info) {
fi_freeinfo(info);
void IRDMA::FreeInfo(fi_info* info, bool is_hints) {
if (!info) {
return;
}

if (is_hints) {
if (info->src_addr) {
free(info->src_addr);
info->src_addr = nullptr;
info->src_addrlen = 0;
}
if (info->dest_addr) {
free(info->dest_addr);
info->dest_addr = nullptr;
info->dest_addrlen = 0;
}
}

fi_freeinfo(info);
}

} // namespace vineyard
Expand Down
2 changes: 1 addition & 1 deletion src/common/rdma/rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class IRDMA {

static int GetCompletion(fid_cq* cq, int timeout, void** context);

static void FreeInfo(fi_info* info);
static void FreeInfo(fi_info* info, bool is_hints);

template <typename FIDType>
static Status CloseResource(FIDType* res, const char* resource_name) {
Expand Down
Loading

0 comments on commit 78e19c8

Please sign in to comment.