Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduction of the raft::device_resources_snmg type #2487

Open
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@
} \
} while (0);

namespace raft::comms {
// Forward declaration only: the definition is not visible in this section.
// Takes a pointer to the resources handle to update, a NCCL communicator, the number of ranks
// and the rank index.
// NOTE(review): presumably declared here to avoid including the defining header — confirm.
void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank);
}

namespace raft::comms {
namespace raft::core {

struct nccl_clique {
using pool_mr = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
Expand Down Expand Up @@ -114,10 +110,6 @@ struct nccl_clique {

// create a device resource handle for each device
device_resources_.emplace_back();

// add NCCL communications to the device resource handle
raft::comms::build_comms_nccl_only(
&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank);
}

for (int rank = 0; rank < num_ranks_; rank++) {
Expand Down Expand Up @@ -153,4 +145,4 @@ struct nccl_clique {
std::vector<raft::device_resources> device_resources_;
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
};

} // namespace raft::comms
} // namespace raft::core
80 changes: 72 additions & 8 deletions cpp/include/raft/core/resource/nccl_clique.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
#pragma once

#include <raft/comms/nccl_clique.hpp>
#include <raft/core/nccl_clique.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

Expand All @@ -25,38 +25,102 @@ namespace raft::resource {

/**
 * Resource wrapper that owns a NCCL clique through a std::unique_ptr and exposes it via the
 * generic raft::resource interface (get_resource()).
 */
class nccl_clique_resource : public resource {
public:
// NOTE(review): this default constructor and the two-argument constructor below build the
// clique from different namespaces (raft::comms vs raft::core). They look like the "before"
// and "after" sides of a diff rendered together — confirm which one is current.
nccl_clique_resource() : clique_(std::make_unique<raft::comms::nccl_clique>()) {}
/**
 * @param[in] device_ids when it holds a value, the clique is constructed from these ids and
 * percent_of_free_memory; otherwise it is constructed from percent_of_free_memory alone
 * @param[in] percent_of_free_memory forwarded to the nccl_clique constructor
 */
nccl_clique_resource(std::optional<std::vector<int>>& device_ids, int percent_of_free_memory)
{
if (device_ids.has_value()) {
clique_ = std::make_unique<raft::core::nccl_clique>(*device_ids, percent_of_free_memory);
} else {
clique_ = std::make_unique<raft::core::nccl_clique>(percent_of_free_memory);
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
}
}

~nccl_clique_resource() override {}
// Returns a non-owning, type-erased pointer to the clique; ownership stays with clique_.
void* get_resource() override { return clique_.get(); }

private:
// NOTE(review): two declarations of clique_ with different pointee namespaces; same
// before/after pairing as the constructors above — confirm which one is current.
std::unique_ptr<raft::comms::nccl_clique> clique_;
std::unique_ptr<raft::core::nccl_clique> clique_;
};

/**
 * Factory that knows how to construct a specific raft::resource to populate the res_t.
 *
 * Stores a copy of the optional device id selection and the memory percentage given at
 * construction and passes both to the nccl_clique_resource it creates.
 */
class nccl_clique_resource_factory : public resource_factory {
public:
/**
 * @param[in] device_ids optional selection of device ids, copied into the factory
 * @param[in] percent_of_free_memory value later forwarded to nccl_clique_resource
 */
nccl_clique_resource_factory(const std::optional<std::vector<int>>& device_ids,
int percent_of_free_memory)
: device_ids(device_ids), percent_of_free_memory(percent_of_free_memory)
{
}

// Identifies the resource slot this factory populates.
resource_type get_resource_type() override { return resource_type::NCCL_CLIQUE; }
// NOTE(review): two make_resource() definitions follow (argument-less vs parameterised);
// they look like the before/after sides of a diff shown together — confirm which is current.
// Both return a raw pointer from `new`; presumably the caller takes ownership — confirm.
resource* make_resource() override { return new nccl_clique_resource(); }
resource* make_resource() override
{
return new nccl_clique_resource(this->device_ids, this->percent_of_free_memory);
}

// Configuration captured at construction time and handed to make_resource().
std::optional<std::vector<int>> device_ids;
int percent_of_free_memory;
};

// Registers a NCCL clique factory on `res` unless one is already registered, then returns the
// clique resource held by `res`.
//
//   res                     raft resources object that is queried for, and possibly given, the
//                           NCCL_CLIQUE resource factory
//   device_ids              optional device id selection handed to the new factory
//   percent_of_free_memory  memory percentage handed to the new factory
//
// When a factory is already present the two configuration arguments are not used and a
// warning is logged instead.
inline const raft::core::nccl_clique& build_nccl_clique(
  resources const& res,
  const std::optional<std::vector<int>>& device_ids,
  int percent_of_free_memory)
{
  const bool factory_present = res.has_resource_factory(resource_type::NCCL_CLIQUE);
  if (factory_present) {
    RAFT_LOG_WARN("Attempted re-initialize the NCCL clique on a RAFT resource.");
  } else {
    auto clique_factory =
      std::make_shared<nccl_clique_resource_factory>(device_ids, percent_of_free_memory);
    res.add_resource_factory(clique_factory);
  }
  return *res.get_resource<raft::core::nccl_clique>(resource_type::NCCL_CLIQUE);
}

/**
* @defgroup nccl_clique_resource resource functions
* @{
*/

/**
* Retrieves a NCCL clique from raft res if it exists, otherwise initializes it and returns it.
* Initializes a NCCL clique and sets it into a raft resource instance
*
* @param[in] res the raft resources object
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as a memory pool on
* each GPU
* @return NCCL clique
*/
inline const raft::core::nccl_clique& initialize_nccl_clique(resources const& res,
                                                             int percent_of_free_memory = 80)
{
  // No device selection is given in this overload, so an empty optional is forwarded.
  return build_nccl_clique(res, std::nullopt, percent_of_free_memory);
}

/**
* Initializes a NCCL clique and sets it into a raft resource instance
*
* @param[in] res the raft resources object
* @param[in] device_ids selection of GPUs to initialize the clique on
* @param[in] percent_of_free_memory percentage of device memory to pre-allocate as a memory pool on
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
* each GPU
* @return NCCL clique
*/
inline const raft::core::nccl_clique& initialize_nccl_clique(
  resources const& res, std::optional<std::vector<int>> device_ids, int percent_of_free_memory = 80)
{
  // device_ids may itself be empty (std::nullopt); it is forwarded unchanged either way.
  return build_nccl_clique(res, device_ids, percent_of_free_memory);
}

/**
* Retrieves a NCCL clique from raft resource instance, initializes one with default parameters if
* absent
*
* @param[in] res the raft resources object
* @return NCCL clique
*/
inline const raft::comms::nccl_clique& get_nccl_clique(resources const& res)
inline const raft::core::nccl_clique& get_nccl_clique(resources const& res)
{
// NOTE(review): the paired lines in this function (raft::comms vs raft::core return types,
// add_resource_factory vs initialize_nccl_clique) look like the before/after sides of a diff
// shown together; both are kept verbatim — confirm which side is current.
// Only set up a clique when no NCCL_CLIQUE factory has been registered on res yet.
if (!res.has_resource_factory(resource_type::NCCL_CLIQUE)) {
res.add_resource_factory(std::make_shared<nccl_clique_resource_factory>());
// Called with its defaults: percent_of_free_memory = 80 and no device id selection.
raft::resource::initialize_nccl_clique(res);
}
return *res.get_resource<raft::comms::nccl_clique>(resource_type::NCCL_CLIQUE);
return *res.get_resource<raft::core::nccl_clique>(resource_type::NCCL_CLIQUE);
};

/**
Expand Down
Loading