multi cuda gpu support (#83)
Signed-off-by: Yusheng.Ma <[email protected]>
Presburger authored Sep 18, 2023
1 parent 9aa3e21 commit 8280cfc
Showing 7 changed files with 173 additions and 133 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -97,7 +97,8 @@ if(WITH_COVERAGE)
endif()

knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc
src/io/*.cc src/index/*.cu src/common/raft/*.cu)
src/io/*.cc src/index/*.cu src/common/raft/*.cu
src/common/raft/*.cc)

set(KNOWHERE_LINKER_LIBS "")

@@ -116,7 +117,7 @@ list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS})
if(NOT WITH_RAFT)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/ivf_raft/*.cc
src/index/ivf_raft/*.cu src/index/cagra/*.cu
src/common/raft/*.cu)
src/common/raft/*.cu src/common/raft/*.cc)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS})
endif()

36 changes: 36 additions & 0 deletions src/common/raft/raft_utils.cc
@@ -0,0 +1,36 @@
#include "raft_utils.h"

namespace raft_utils {
int
gpu_device_manager::random_choose() const {
srand(time(NULL));
return rand() % memory_load_.size();
}

int
gpu_device_manager::choose_with_load(size_t load) {
std::lock_guard<std::mutex> lock(mtx_);

auto it = std::min_element(memory_load_.begin(), memory_load_.end());
*it += load;
return std::distance(memory_load_.begin(), it);
}

gpu_device_manager::gpu_device_manager() {
int device_counts;
try {
RAFT_CUDA_TRY(cudaGetDeviceCount(&device_counts));
} catch (const raft::exception& e) {
LOG_KNOWHERE_FATAL_ << e.what();
}
memory_load_.resize(device_counts);
std::fill(memory_load_.begin(), memory_load_.end(), 0);
}

gpu_device_manager&
gpu_device_manager::instance() {
static gpu_device_manager mgr;
return mgr;
}

} // namespace raft_utils
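
The new raft_utils.cc above gives the gpu_device_manager singleton two selection strategies: random_choose() picks an arbitrary visible GPU, while choose_with_load() takes the mutex, picks the GPU with the smallest accumulated memory load, and charges the new load to it. A minimal standalone sketch of that load-balancing step (plain C++, not the Knowhere code; the two-GPU setup and the example loads are illustrative assumptions):

#include <algorithm>
#include <cstdio>
#include <vector>

// Same idea as gpu_device_manager::choose_with_load(): find the least-loaded
// device, charge the new load to it, and return its index.
int choose_with_load(std::vector<size_t>& memory_load, size_t load) {
    auto it = std::min_element(memory_load.begin(), memory_load.end());
    *it += load;
    return static_cast<int>(std::distance(memory_load.begin(), it));
}

int main() {
    std::vector<size_t> memory_load(2, 0);                    // pretend two idle GPUs
    std::printf("%d\n", choose_with_load(memory_load, 100));  // 0: both idle, first wins
    std::printf("%d\n", choose_with_load(memory_load, 10));   // 1: GPU 0 already carries 100
    std::printf("%d\n", choose_with_load(memory_load, 10));   // 1: still lighter (20 vs 100)
    return 0;
}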
24 changes: 24 additions & 0 deletions src/common/raft/raft_utils.h
@@ -148,6 +148,21 @@ init_gpu_resources(std::optional<std::size_t> streams_per_device = std::nullopt,
get_gpu_resources(streams_per_device).init(device_id);
}

class gpu_device_manager {
public:
static gpu_device_manager&
instance();
int
random_choose() const;
int
choose_with_load(size_t load);

private:
gpu_device_manager();
std::vector<size_t> memory_load_;
mutable std::mutex mtx_;
};

inline auto&
get_raft_resources(int device_id = get_current_device()) {
thread_local auto all_resources = std::map<int, std::unique_ptr<raft::device_resources>>{};
@@ -168,3 +183,12 @@ set_mem_pool_size(size_t init_size, size_t max_size) {
}

}; // namespace raft_utils

#define RANDOM_CHOOSE_DEVICE_WITH_ASSIGN(x) \
do { \
x = raft_utils::gpu_device_manager::instance().random_choose(); \
} while (0)
#define MIN_LOAD_CHOOSE_DEVICE_WITH_ASSIGN(x, load) \
do { \
x = raft_utils::gpu_device_manager::instance().choose_with_load(load); \
} while (0)
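
Both assignment macros wrap their body in do { ... } while (0) so that each expands to a single statement and composes safely with unbraced if/else at the call sites; Train() and Deserialize() below use them exactly this way. A standalone sketch of that idiom with a stand-in picker (plain C++; nothing from CUDA or RAFT is assumed):

#include <cstdio>

int pick_device() { return 0; }      // stand-in for gpu_device_manager::instance().random_choose()

#define CHOOSE_DEVICE_WITH_ASSIGN(x) \
    do {                             \
        x = pick_device();           \
    } while (0)

int main() {
    int device_id = -1;
    bool already_trained = false;
    if (!already_trained)
        CHOOSE_DEVICE_WITH_ASSIGN(device_id);  // expands to one statement; the trailing ';' stays legal
    else
        std::puts("index is already trained");
    std::printf("device_id = %d\n", device_id);
    return 0;
}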
2 changes: 1 addition & 1 deletion src/index/ivf_raft/ivf_raft.cu
@@ -19,7 +19,7 @@
#include "knowhere/factory.h"
#include "knowhere/index_node_thread_pool_wrapper.h"

constexpr uint32_t cuda_concurrent_size = 16;
constexpr uint32_t cuda_concurrent_size = 32;

namespace knowhere {

222 changes: 109 additions & 113 deletions src/index/ivf_raft/ivf_raft.cuh
@@ -233,7 +233,7 @@ struct KnowhereConfigType<detail::raft_ivf_pq_index> {
template <typename T>
class RaftIvfIndexNode : public IndexNode {
public:
RaftIvfIndexNode(const Object& object) : devs_{}, gpu_index_{} {
RaftIvfIndexNode(const Object& object) : device_id_{-1}, gpu_index_{} {
}

Status
@@ -242,124 +242,115 @@ class RaftIvfIndexNode : public IndexNode {
if (gpu_index_) {
LOG_KNOWHERE_WARNING_ << "index is already trained";
return Status::index_already_trained;
} else if (ivf_raft_cfg.gpu_ids.value().size() == 1) {
try {
auto scoped_device = raft_utils::device_setter{*ivf_raft_cfg.gpu_ids.value().begin()};
raft_utils::init_gpu_resources();

auto metric = Str2RaftMetricType(ivf_raft_cfg.metric_type.value());
if (!metric.has_value()) {
LOG_KNOWHERE_WARNING_ << "please check metric value: " << ivf_raft_cfg.metric_type.value();
return metric.error();
}
if (metric.value() != raft::distance::DistanceType::L2Expanded &&
metric.value() != raft::distance::DistanceType::InnerProduct) {
LOG_KNOWHERE_WARNING_ << "selected metric not supported in RAFT IVF indexes: "
<< ivf_raft_cfg.metric_type.value();
return Status::invalid_metric_type;
}
devs_.insert(devs_.begin(), ivf_raft_cfg.gpu_ids.value().begin(), ivf_raft_cfg.gpu_ids.value().end());
auto& res = raft_utils::get_raft_resources();

auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

auto data_gpu = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, res.get_stream().value()));
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto build_params = raft::neighbors::ivf_flat::index_params{};
build_params.metric = metric.value();
build_params.n_lists = ivf_raft_cfg.nlist.value();
build_params.kmeans_n_iters = ivf_raft_cfg.kmeans_n_iters.value();
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction.value();
build_params.adaptive_centers = ivf_raft_cfg.adaptive_centers.value();
gpu_index_ =
raft::neighbors::ivf_flat::build<float, std::int64_t>(res, build_params, data_gpu.view());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto build_params = raft::neighbors::ivf_pq::index_params{};
build_params.metric = metric.value();
build_params.n_lists = ivf_raft_cfg.nlist.value();
build_params.pq_bits = ivf_raft_cfg.nbits.value();
build_params.kmeans_n_iters = ivf_raft_cfg.kmeans_n_iters.value();
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction.value();
build_params.pq_dim = ivf_raft_cfg.m.value();
auto codebook_kind = detail::str_to_codebook_gen(ivf_raft_cfg.codebook_kind.value());
if (!codebook_kind.has_value()) {
LOG_KNOWHERE_WARNING_ << "please check codebook kind: " << ivf_raft_cfg.codebook_kind.value();
return codebook_kind.error();
}
build_params.codebook_kind = codebook_kind.value();
build_params.force_random_rotation = ivf_raft_cfg.force_random_rotation.value();
gpu_index_ =
raft::neighbors::ivf_pq::build<float, std::int64_t>(res, build_params, data_gpu.view());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
dim_ = dim;
counts_ = rows;
res.sync_stream();
}
try {
RANDOM_CHOOSE_DEVICE_WITH_ASSIGN(this->device_id_);
raft_utils::device_setter with_this_device(this->device_id_);
raft_utils::init_gpu_resources();

auto metric = Str2RaftMetricType(ivf_raft_cfg.metric_type.value());
if (!metric.has_value()) {
LOG_KNOWHERE_WARNING_ << "please check metric value: " << ivf_raft_cfg.metric_type.value();
return metric.error();
}
if (metric.value() != raft::distance::DistanceType::L2Expanded &&
metric.value() != raft::distance::DistanceType::InnerProduct) {
LOG_KNOWHERE_WARNING_ << "selected metric not supported in RAFT IVF indexes: "
<< ivf_raft_cfg.metric_type.value();
return Status::invalid_metric_type;
}
auto& res = raft_utils::get_raft_resources();

} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
return Status::raft_inner_error;
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

auto data_gpu = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, res.get_stream().value()));
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto build_params = raft::neighbors::ivf_flat::index_params{};
build_params.metric = metric.value();
build_params.n_lists = ivf_raft_cfg.nlist.value();
build_params.kmeans_n_iters = ivf_raft_cfg.kmeans_n_iters.value();
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction.value();
build_params.adaptive_centers = ivf_raft_cfg.adaptive_centers.value();
gpu_index_ = raft::neighbors::ivf_flat::build<float, std::int64_t>(res, build_params, data_gpu.view());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto build_params = raft::neighbors::ivf_pq::index_params{};
build_params.metric = metric.value();
build_params.n_lists = ivf_raft_cfg.nlist.value();
build_params.pq_bits = ivf_raft_cfg.nbits.value();
build_params.kmeans_n_iters = ivf_raft_cfg.kmeans_n_iters.value();
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction.value();
build_params.pq_dim = ivf_raft_cfg.m.value();
auto codebook_kind = detail::str_to_codebook_gen(ivf_raft_cfg.codebook_kind.value());
if (!codebook_kind.has_value()) {
LOG_KNOWHERE_WARNING_ << "please check codebook kind: " << ivf_raft_cfg.codebook_kind.value();
return codebook_kind.error();
}
build_params.codebook_kind = codebook_kind.value();
build_params.force_random_rotation = ivf_raft_cfg.force_random_rotation.value();
gpu_index_ = raft::neighbors::ivf_pq::build<float, std::int64_t>(res, build_params, data_gpu.view());
}
} else {
LOG_KNOWHERE_WARNING_ << "RAFT IVF implementation is single-GPU only";
dim_ = dim;
counts_ = rows;
res.sync_stream();

} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
return Status::raft_inner_error;
}
return Status::success;
}

Status
Add(const DataSet& dataset, const Config& cfg) override {
auto result = Status::success;
if (!gpu_index_) {
result = Status::index_not_trained;
} else {
try {
auto scoped_device = raft_utils::device_setter{devs_[0]};
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

raft_utils::init_gpu_resources();
auto& res = raft_utils::get_raft_resources();

// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
auto data_gpu = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, res.get_stream().value()));

auto indices = rmm::device_uvector<std::int64_t>(rows, res.get_stream());
thrust::sequence(res.get_thrust_policy(), indices.begin(), indices.end(), gpu_index_->size());

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
raft::neighbors::ivf_flat::extend<float, std::int64_t>(
res, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(
raft::make_device_vector_view<const std::int64_t, std::int64_t>(indices.data(), rows)),
gpu_index_.value());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
raft::neighbors::ivf_pq::extend<float, std::int64_t>(
res, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(
raft::make_device_matrix_view<const std::int64_t, std::int64_t>(indices.data(), rows, 1)),
gpu_index_.value());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
dim_ = dim;
counts_ = rows;
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
result = Status::raft_inner_error;
return Status::index_not_trained;
}
try {
RAFT_EXPECTS(this->device_id_ != -1, "call data add before index train.");
raft_utils::device_setter with_this_device{this->device_id_};
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

raft_utils::init_gpu_resources();
auto& res = raft_utils::get_raft_resources();

// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
auto data_gpu = raft::make_device_matrix<float, std::int64_t>(res, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, res.get_stream().value()));

auto indices = rmm::device_uvector<std::int64_t>(rows, res.get_stream());
thrust::sequence(res.get_thrust_policy(), indices.begin(), indices.end(), gpu_index_->size());

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
raft::neighbors::ivf_flat::extend<float, std::int64_t>(
res, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(
raft::make_device_vector_view<const std::int64_t, std::int64_t>(indices.data(), rows)),
gpu_index_.value());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
raft::neighbors::ivf_pq::extend<float, std::int64_t>(
res, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(
raft::make_device_matrix_view<const std::int64_t, std::int64_t>(indices.data(), rows, 1)),
gpu_index_.value());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
dim_ = dim;
counts_ = rows;
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
return Status::raft_inner_error;
}

return result;
return Status::success;
}

expected<DataSetPtr>
@@ -372,7 +363,8 @@ class RaftIvfIndexNode : public IndexNode {
auto ids = std::unique_ptr<std::int64_t[]>(new std::int64_t[output_size]);
auto dis = std::unique_ptr<float[]>(new float[output_size]);
try {
auto scoped_device = raft_utils::device_setter{devs_[0]};
RAFT_EXPECTS(this->device_id_ != -1, "device id is -1, when call search");
raft_utils::device_setter with_this_device{this->device_id_};
auto& res_ = raft_utils::get_raft_resources();

// TODO(wphicks): Clean up transfer with raft
@@ -480,15 +472,17 @@ class RaftIvfIndexNode : public IndexNode {
LOG_KNOWHERE_ERROR_ << "Can not serialize empty RaftIvfIndex.";
return Status::empty_index;
}

RAFT_EXPECTS(this->device_id_ != -1, "index serialize before trained.");
std::stringbuf buf;

std::ostream os(&buf);

os.write((char*)(&this->dim_), sizeof(this->dim_));
os.write((char*)(&this->counts_), sizeof(this->counts_));
os.write((char*)(&this->devs_[0]), sizeof(this->devs_[0]));
os.write((char*)(&this->device_id_), sizeof(this->device_id_));

auto scoped_device = raft_utils::device_setter{devs_[0]};
raft_utils::device_setter with_this_device{device_id_};
auto& res = raft_utils::get_raft_resources();

if constexpr (std::is_same_v<T, detail::raft_ivf_flat_index>) {
@@ -520,9 +514,11 @@ class RaftIvfIndexNode : public IndexNode {

is.read((char*)(&this->dim_), sizeof(this->dim_));
is.read((char*)(&this->counts_), sizeof(this->counts_));
this->devs_.resize(1);
is.read((char*)(&this->devs_[0]), sizeof(this->devs_[0]));
auto scoped_device = raft_utils::device_setter{devs_[0]};
// The device_id stored in the binset is ignored; a fresh device id is
// chosen below from the global load state.
is.read((char*)(&this->device_id_), sizeof(this->device_id_));
MIN_LOAD_CHOOSE_DEVICE_WITH_ASSIGN(this->device_id_, binary->size);
raft_utils::device_setter with_this_device{this->device_id_};

raft_utils::init_gpu_resources();
auto& res = raft_utils::get_raft_resources();
@@ -580,7 +576,7 @@ class RaftIvfIndexNode : public IndexNode {
}

private:
std::vector<int32_t> devs_;
int device_id_ = -1;
int64_t dim_ = 0;
int64_t counts_ = 0;
std::optional<T> gpu_index_;
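
Taken together, the ivf_raft.cuh changes replace the old devs_ vector (and the single-GPU-only code path) with a single device_id_ that is chosen once, randomly at train time or by least load at deserialize time, and every later operation re-enters that device through raft_utils::device_setter. A hedged standalone sketch of that affinity pattern using stub types (not the Knowhere or RAFT API):

#include <cstdio>
#include <stdexcept>

// Stand-in for raft_utils::device_setter: an RAII guard that makes one GPU current.
struct device_setter {
    explicit device_setter(int id) { std::printf("using device %d\n", id); }
};

struct toy_index {
    int device_id_ = -1;                  // mirrors the new RaftIvfIndexNode::device_id_ member

    void train() {
        device_id_ = 0;                   // stand-in for RANDOM_CHOOSE_DEVICE_WITH_ASSIGN(device_id_)
        device_setter guard{device_id_};
        // ... build the index on the chosen device ...
    }

    void search() const {
        if (device_id_ == -1)
            throw std::runtime_error("search called before train");
        device_setter guard{device_id_};  // later calls pin the same device
        // ... run the query on the chosen device ...
    }
};

int main() {
    toy_index idx;
    idx.train();
    idx.search();
    return 0;
}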