Skip to content

Commit

Permalink
Add sq for hnsw
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 committed Dec 19, 2023
1 parent ff90766 commit 218e029
Show file tree
Hide file tree
Showing 22 changed files with 447 additions and 51 deletions.
2 changes: 2 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ";
constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
// Type name of the HNSW variant instantiated with USE_SQ = true; this is the
// value HnswIndexNode<true>::Type() selects.
constexpr const char* INDEX_HNSW_SQ = "HNSW_SQ";
constexpr const char* INDEX_DISKANN = "DISKANN";

} // namespace IndexEnum
Expand Down Expand Up @@ -117,6 +118,7 @@ constexpr const char* HNSW_M = "M";
constexpr const char* EF = "ef";
constexpr const char* SEED_EF = "seed_ef";
constexpr const char* OVERVIEW_LEVELS = "overview_levels";
// Key of the boolean HNSW_SQ option; spelled the same as the
// HnswSQConfig::use_refine field and the "use_refine" entry in the legal JSON keys.
constexpr const char* USE_REFINE = "use_refine";
} // namespace indexparam

using MetricType = std::string;
Expand Down
1 change: 1 addition & 0 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ static const std::unordered_set<std::string> ext_legal_json_keys = {"metric_type
"M", // HNSW param
"efConstruction", // HNSW param
"ef", // HNSW param
"use_refine", // HNSW_SQ param
"seed_ef", // HNSW iterator param
"level",
"index_type",
Expand Down
25 changes: 18 additions & 7 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "knowhere/utils.h"

namespace knowhere {
template <bool USE_SQ>
class HnswIndexNode : public IndexNode {
public:
HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) {
Expand All @@ -53,8 +54,12 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}
auto index = new (std::nothrow)
hnswlib::HierarchicalNSW<float>(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());
bool use_sq_refine = false;
if constexpr (USE_SQ) {
use_sq_refine = static_cast<const HnswSQConfig&>(cfg).use_refine.value();
}
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<float>(
space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value(), USE_SQ, use_sq_refine);
if (index == nullptr) {
LOG_KNOWHERE_WARNING_ << "memory malloc error.";
return Status::malloc_error;
Expand All @@ -64,6 +69,9 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index";
}
this->index_ = index;
if constexpr (USE_SQ) {
this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
}
return Status::success;
}

Expand Down Expand Up @@ -352,7 +360,7 @@ class HnswIndexNode : public IndexNode {

// Presumably reports whether raw vectors can be served from this index
// (inferred from the name); the metric_type argument is not used.
//
// NOTE(review): the two consecutive returns below look like the removed and
// added lines of the diff shown together. As written, only `return true;`
// can execute and the second statement is unreachable.
// NOTE(review): the second form dereferences index_, which the constructor
// initialises to nullptr. Confirm callers only ask after the index has been
// built or loaded, or add a null check.
bool
HasRawData(const std::string& metric_type) const override {
return true;
return !USE_SQ || index_->sq_refine_enabled_;
}

expected<DataSetPtr>
Expand Down Expand Up @@ -418,7 +426,7 @@ class HnswIndexNode : public IndexNode {

hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(reader);
index_->loadIndex(reader, USE_SQ);
LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_
<< " #ef_construction:" << index_->ef_construction_
Expand All @@ -438,7 +446,7 @@ class HnswIndexNode : public IndexNode {
try {
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(filename, config);
index_->loadIndex(filename, config, USE_SQ);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
return Status::hnsw_inner_error;
Expand Down Expand Up @@ -477,7 +485,7 @@ class HnswIndexNode : public IndexNode {

// Returns the index type name as a std::string.
//
// NOTE(review): the two consecutive returns below look like the removed and
// added lines of the diff shown together. As written, only the first executes
// and the result is always INDEX_HNSW. The second form would pick
// INDEX_HNSW_SQ when the USE_SQ template argument is true and INDEX_HNSW
// otherwise.
std::string
Type() const override {
return knowhere::IndexEnum::INDEX_HNSW;
return USE_SQ ? knowhere::IndexEnum::INDEX_HNSW_SQ : knowhere::IndexEnum::INDEX_HNSW;
}

~HnswIndexNode() override {
Expand Down Expand Up @@ -529,7 +537,10 @@ class HnswIndexNode : public IndexNode {
};

// Factory registration for the plain HNSW index.
// NOTE(review): the two returns inside this lambda look like the removed and
// added lines of the diff shown together. HnswIndexNode is declared as
// template <bool USE_SQ>, so only the HnswIndexNode<false> form names a
// concrete type - confirm the bare HnswIndexNode line is stale.
KNOWHERE_REGISTER_GLOBAL(HNSW, [](const int32_t& version, const Object& object) {
return Index<HnswIndexNode>::Create(version, object);
return Index<HnswIndexNode<false>>::Create(version, object);
});
// Factory registration for HNSW_SQ: same node class, instantiated with USE_SQ = true.
KNOWHERE_REGISTER_GLOBAL(HNSW_SQ, [](const int32_t& version, const Object& object) {
return Index<HnswIndexNode<true>>::Create(version, object);
});

} // namespace knowhere
8 changes: 8 additions & 0 deletions src/index/hnsw/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class HnswConfig : public BaseConfig {
}
};

// Configuration for the HNSW_SQ index: inherits every HnswConfig field and adds
// the boolean `use_refine` option.
class HnswSQConfig : public HnswConfig {
public:
// Boolean option; HnswIndexNode<true> reads use_refine.value() and passes it to
// the hnswlib::HierarchicalNSW constructor when the index is created.
CFG_BOOL use_refine;
// NOTE(review): the macro is spelled KNOHWERE_ here (vs KNOWHERE_ on the field
// line below). Presumably this matches the macro's definition elsewhere;
// confirm before "correcting" the spelling.
KNOHWERE_DECLARE_CONFIG(HnswSQConfig) {
// Defaults to false; for_train() presumably scopes the field to the
// train/build stage - TODO confirm against the config macro definitions.
KNOWHERE_CONFIG_DECLARE_FIELD(use_refine).description("hnswsq use refine").set_default(false).for_train();
}
};

} // namespace knowhere

#endif /* HNSW_CONFIG_H */
23 changes: 23 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,28 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Inner product of two int8 vectors of length d, accumulated in an int32.
// Despite the _avx suffix this is a plain scalar loop with no intrinsics;
// vectorization and unrolling are left to the compiler.
int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) {
    int32_t acc = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        acc += lhs * rhs;
    }
    return acc;
}

// Squared Euclidean distance between two int8 vectors of length d,
// accumulated in an int32. Each difference is widened to int32 before it is
// squared. Plain scalar loop; vectorization is left to the compiler.
int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) {
    int32_t acc = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t diff = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        acc += diff * diff;
    }
    return acc;
}

} // namespace faiss
#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

// Inner product of two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

// Squared L2 distance between two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX_H */
23 changes: 23 additions & 0 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,29 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Inner product of two int8 vectors of length d, accumulated in an int32.
// No AVX-512 intrinsics are used here: this is a scalar loop and the compiler
// is relied on to vectorize and unroll it.
int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) {
    int32_t total = 0;
    size_t pos = 0;
    while (pos < d) {
        total += static_cast<int32_t>(x[pos]) * y[pos];
        ++pos;
    }
    return total;
}

// Squared Euclidean distance between two int8 vectors of length d,
// accumulated in an int32. Scalar loop without intrinsics; the compiler is
// relied on to vectorize and unroll it.
int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) {
    int32_t total = 0;
    size_t pos = 0;
    while (pos < d) {
        // Widen before subtracting so the full [-255, 255] range is representable.
        const int32_t delta = static_cast<int32_t>(x[pos]) - static_cast<int32_t>(y[pos]);
        total += delta * delta;
        ++pos;
    }
    return total;
}

} // namespace faiss

#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

// Inner product of two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);

// Squared L2 distance between two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX512_H */
23 changes: 23 additions & 0 deletions src/simd/distances_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,5 +635,28 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl
return ans_;
}

// Inner product of two int8 vectors of length d, accumulated in an int32.
// Written as a scalar loop with no NEON intrinsics; vectorization is left to
// the compiler.
int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t idx = 0; idx != d; ++idx) {
        const int32_t a = x[idx];
        const int32_t b = y[idx];
        sum += a * b;
    }
    return sum;
}

// Squared Euclidean distance between two int8 vectors of length d,
// accumulated in an int32. Scalar loop with no NEON intrinsics; vectorization
// is left to the compiler.
int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t idx = 0; idx != d; ++idx) {
        const int32_t gap = static_cast<int32_t>(x[idx]) - static_cast<int32_t>(y[idx]);
        sum += gap * gap;
    }
    return sum;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_NEON_H
#define DISTANCES_NEON_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -50,6 +51,12 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c);

// Inner product of two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d);

// Squared L2 distance between two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_NEON_H */
21 changes: 21 additions & 0 deletions src/simd/distances_ref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,25 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f
dis3 = d3;
}

// Reference (portable scalar) inner product of two int8 vectors of length d.
// Each product is formed in int32 and summed into an int32 accumulator.
int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) {
    int32_t acc = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        acc += lhs * rhs;
    }
    return acc;
}

// Reference (portable scalar) squared Euclidean distance between two int8
// vectors of length d. Differences are widened to int32 before squaring and
// summed into an int32 accumulator.
int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) {
    int32_t acc = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t diff = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        acc += diff * diff;
    }
    return acc;
}

} // namespace faiss
7 changes: 7 additions & 0 deletions src/simd/distances_ref.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef DISTANCES_REF_H
#define DISTANCES_REF_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -71,6 +72,12 @@ void
fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

// Inner product of two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d);

// Squared L2 distance between two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_REF_H */
23 changes: 23 additions & 0 deletions src/simd/distances_sse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,28 @@ fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, flo
return _mm_cvtsi128_si32(imin4);
}

// Inner product of two int8 vectors of length d, accumulated in an int32.
// No SSE intrinsics are used: this is a scalar loop and the compiler is
// relied on to vectorize and unroll it.
int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) {
    int32_t total = 0;
    size_t pos = 0;
    while (pos < d) {
        total += static_cast<int32_t>(x[pos]) * y[pos];
        ++pos;
    }
    return total;
}

// Squared Euclidean distance between two int8 vectors of length d,
// accumulated in an int32. Scalar loop without SSE intrinsics; the compiler
// is relied on to vectorize and unroll it.
int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) {
    int32_t total = 0;
    size_t pos = 0;
    while (pos < d) {
        // Widen before subtracting so the full [-255, 255] range is representable.
        const int32_t delta = static_cast<int32_t>(x[pos]) - static_cast<int32_t>(y[pos]);
        total += delta * delta;
        ++pos;
    }
    return total;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_SSE_H
#define DISTANCES_SSE_H

#include <cstdint>
#include <cstdio>
namespace faiss {

Expand Down Expand Up @@ -46,6 +47,12 @@ fvec_madd_sse(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, float* c);

// Inner product of two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d);

// Squared L2 distance between two int8 vectors x and y of length d, returned as int32.
int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_SSE_H */
Loading

0 comments on commit 218e029

Please sign in to comment.