From 218e029a7575b9617871ee847e8179d7adfe23f3 Mon Sep 17 00:00:00 2001 From: zh Wang Date: Thu, 14 Dec 2023 14:31:18 +0800 Subject: [PATCH] Add sq for hnsw Signed-off-by: zh Wang --- include/knowhere/comp/index_param.h | 2 + src/common/config.cc | 1 + src/index/hnsw/hnsw.cc | 25 ++- src/index/hnsw/hnsw_config.h | 8 + src/simd/distances_avx.cc | 23 +++ src/simd/distances_avx.h | 6 + src/simd/distances_avx512.cc | 23 +++ src/simd/distances_avx512.h | 6 + src/simd/distances_neon.cc | 23 +++ src/simd/distances_neon.h | 7 + src/simd/distances_ref.cc | 21 ++ src/simd/distances_ref.h | 7 + src/simd/distances_sse.cc | 23 +++ src/simd/distances_sse.h | 7 + src/simd/hook.cc | 21 ++ src/simd/hook.h | 4 + tests/ut/test_search.cc | 27 ++- thirdparty/hnswlib/hnswlib/hnswalg.h | 223 ++++++++++++++++++---- thirdparty/hnswlib/hnswlib/hnswlib.h | 5 + thirdparty/hnswlib/hnswlib/space_cosine.h | 12 ++ thirdparty/hnswlib/hnswlib/space_ip.h | 12 ++ thirdparty/hnswlib/hnswlib/space_l2.h | 12 ++ 22 files changed, 447 insertions(+), 51 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 82f748145..2e1379681 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -42,6 +42,7 @@ constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ"; constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA"; constexpr const char* INDEX_HNSW = "HNSW"; +constexpr const char* INDEX_HNSW_SQ = "HNSW_SQ"; constexpr const char* INDEX_DISKANN = "DISKANN"; } // namespace IndexEnum @@ -117,6 +118,7 @@ constexpr const char* HNSW_M = "M"; constexpr const char* EF = "ef"; constexpr const char* SEED_EF = "seed_ef"; constexpr const char* OVERVIEW_LEVELS = "overview_levels"; +constexpr const char* USE_REFINE = "use_refine"; } // namespace indexparam using MetricType = std::string; diff --git a/src/common/config.cc b/src/common/config.cc index 21444d9b3..3cfe54715 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -25,6 +25,7 @@ static const std::unordered_set ext_legal_json_keys = {"metric_type "M", // HNSW param "efConstruction", // HNSW param "ef", // HNSW param + "use_refine", // HNSW_SQ param "seed_ef", // HNSW iterator param "level", "index_type", diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index d82dcd581..dffeb2dda 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -27,6 +27,7 @@ #include "knowhere/utils.h" namespace knowhere { +template class HnswIndexNode : public IndexNode { public: HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) { @@ -53,8 +54,12 @@ class HnswIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value(); return Status::invalid_metric_type; } - auto index = new (std::nothrow) - hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); + bool use_sq_refine = false; + if constexpr (USE_SQ) { + use_sq_refine = static_cast(cfg).use_refine.value(); + } + auto index = new (std::nothrow) hnswlib::HierarchicalNSW( + space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value(), USE_SQ, use_sq_refine); if (index == nullptr) { LOG_KNOWHERE_WARNING_ << "memory malloc error."; return Status::malloc_error; @@ -64,6 +69,9 @@ class HnswIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index"; } this->index_ = index; + if constexpr (USE_SQ) { + this->index_->trainSQuant((const float*)dataset.GetTensor(), rows); + } return Status::success; } @@ -352,7 +360,7 @@ class HnswIndexNode : public IndexNode { bool HasRawData(const std::string& metric_type) const override { - return true; + return !USE_SQ || index_->sq_refine_enabled_; } expected @@ -418,7 +426,7 @@ class HnswIndexNode : public IndexNode { hnswlib::SpaceInterface* space = nullptr; index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); - index_->loadIndex(reader); + index_->loadIndex(reader, USE_SQ); LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_ @@ -438,7 +446,7 @@ class HnswIndexNode : public IndexNode { try { hnswlib::SpaceInterface* space = nullptr; index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); - index_->loadIndex(filename, config); + index_->loadIndex(filename, config, USE_SQ); } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); return Status::hnsw_inner_error; @@ -477,7 +485,7 @@ class HnswIndexNode : public IndexNode { std::string Type() const override { - return knowhere::IndexEnum::INDEX_HNSW; + return USE_SQ ? knowhere::IndexEnum::INDEX_HNSW_SQ : knowhere::IndexEnum::INDEX_HNSW; } ~HnswIndexNode() override { @@ -529,7 +537,10 @@ class HnswIndexNode : public IndexNode { }; KNOWHERE_REGISTER_GLOBAL(HNSW, [](const int32_t& version, const Object& object) { - return Index::Create(version, object); + return Index>::Create(version, object); +}); +KNOWHERE_REGISTER_GLOBAL(HNSW_SQ, [](const int32_t& version, const Object& object) { + return Index>::Create(version, object); }); } // namespace knowhere diff --git a/src/index/hnsw/hnsw_config.h b/src/index/hnsw/hnsw_config.h index f69e612db..48b64ac02 100644 --- a/src/index/hnsw/hnsw_config.h +++ b/src/index/hnsw/hnsw_config.h @@ -85,6 +85,14 @@ class HnswConfig : public BaseConfig { } }; +class HnswSQConfig : public HnswConfig { + public: + CFG_BOOL use_refine; + KNOHWERE_DECLARE_CONFIG(HnswSQConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(use_refine).description("hnswsq use refine").set_default(false).for_train(); + } +}; + } // namespace knowhere #endif /* HNSW_CONFIG_H */ diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc index 1e3ec0f04..bfb85f5f2 100644 --- a/src/simd/distances_avx.cc +++ b/src/simd/distances_avx.cc @@ -215,5 +215,28 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h index c89b964d6..f786dc693 100644 --- a/src/simd/distances_avx.h +++ b/src/simd/distances_avx.h @@ -44,6 +44,12 @@ void fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_AVX_H */ diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index bbf1ce80c..50cfbd153 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -242,6 +242,29 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index bfde69af1..79654319f 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -43,6 +43,12 @@ void fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_AVX512_H */ diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc index cb304927f..eb90c9ae7 100644 --- a/src/simd/distances_neon.cc +++ b/src/simd/distances_neon.cc @@ -635,5 +635,28 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl return ans_; } +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h index fdfc79ad9..c3150d161 100644 --- a/src/simd/distances_neon.h +++ b/src/simd/distances_neon.h @@ -12,6 +12,7 @@ #ifndef DISTANCES_NEON_H #define DISTANCES_NEON_H +#include #include namespace faiss { @@ -50,6 +51,12 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c); int fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c); +int32_t +ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_NEON_H */ diff --git a/src/simd/distances_ref.cc b/src/simd/distances_ref.cc index 84fb049a9..2ff6e82f6 100644 --- a/src/simd/distances_ref.cc +++ b/src/simd/distances_ref.cc @@ -212,4 +212,25 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f dis3 = d3; } +int32_t +ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +int32_t +ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss diff --git a/src/simd/distances_ref.h b/src/simd/distances_ref.h index fefb04999..2ca812051 100644 --- a/src/simd/distances_ref.h +++ b/src/simd/distances_ref.h @@ -1,6 +1,7 @@ #ifndef DISTANCES_REF_H #define DISTANCES_REF_H +#include #include namespace faiss { @@ -71,6 +72,12 @@ void fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_REF_H */ diff --git a/src/simd/distances_sse.cc b/src/simd/distances_sse.cc index 30e8fa3d1..600ec931b 100644 --- a/src/simd/distances_sse.cc +++ b/src/simd/distances_sse.cc @@ -377,5 +377,28 @@ fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, flo return _mm_cvtsi128_si32(imin4); } +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_sse.h b/src/simd/distances_sse.h index 48f2f7653..a8089c119 100644 --- a/src/simd/distances_sse.h +++ b/src/simd/distances_sse.h @@ -12,6 +12,7 @@ #ifndef DISTANCES_SSE_H #define DISTANCES_SSE_H +#include #include namespace faiss { @@ -46,6 +47,12 @@ fvec_madd_sse(size_t n, const float* a, float bf, const float* b, float* c); int fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, float* c); +int32_t +ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_SSE_H */ diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 20f887406..5e40b3f05 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -59,6 +59,9 @@ decltype(fvec_L2sqr_ny_transposed) fvec_L2sqr_ny_transposed = fvec_L2sqr_ny_tran decltype(fvec_inner_product_batch_4) fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; decltype(fvec_L2sqr_batch_4) fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; +decltype(ivec_inner_product) ivec_inner_product = ivec_inner_product_ref; +decltype(ivec_L2sqr) ivec_L2sqr = ivec_L2sqr_ref; + #if defined(__x86_64__) bool cpu_support_avx512() { @@ -99,6 +102,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx512; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx512; + ivec_inner_product = ivec_inner_product_avx512; + ivec_L2sqr = ivec_L2sqr_avx512; + simd_type = "AVX512"; support_pq_fast_scan = true; } else if (use_avx2 && cpu_support_avx2()) { @@ -116,6 +122,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx; + ivec_inner_product = ivec_inner_product_avx; + ivec_L2sqr = ivec_L2sqr_avx; + simd_type = "AVX2"; support_pq_fast_scan = true; } else if (use_sse4_2 && cpu_support_sse4_2()) { @@ -133,6 +142,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; + ivec_inner_product = ivec_inner_product_sse; + ivec_L2sqr = ivec_L2sqr_sse; + simd_type = "SSE4_2"; support_pq_fast_scan = false; } else { @@ -150,6 +162,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; + ivec_inner_product = ivec_inner_product_ref; + ivec_L2sqr = ivec_L2sqr_ref; + simd_type = "GENERIC"; support_pq_fast_scan = false; } @@ -167,6 +182,9 @@ fvec_hook(std::string& simd_type) { fvec_madd = fvec_madd_neon; fvec_madd_and_argmin = fvec_madd_and_argmin_neon; + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; + simd_type = "NEON"; support_pq_fast_scan = true; @@ -185,6 +203,9 @@ fvec_hook(std::string& simd_type) { fvec_madd = fvec_madd_ref; fvec_madd_and_argmin = fvec_madd_and_argmin_ref; + ivec_inner_product = ivec_inner_product_ref; + ivec_L2sqr = ivec_L2sqr_ref; + simd_type = "GENERIC"; support_pq_fast_scan = false; #endif diff --git a/src/simd/hook.h b/src/simd/hook.h index 07763ea8d..69d84e9b4 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -70,6 +70,10 @@ extern void (*fvec_inner_product_batch_4)(const float*, const float*, const floa extern void (*fvec_L2sqr_batch_4)(const float*, const float*, const float*, const float*, const float*, const size_t, float&, float&, float&, float&); +extern int32_t (*ivec_inner_product)(const int8_t*, const int8_t*, size_t); + +extern int32_t (*ivec_L2sqr)(const int8_t*, const int8_t*, size_t); + #if defined(__x86_64__) extern bool use_avx512; extern bool use_avx2; diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 4d177376d..218ac1ccd 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -24,7 +24,7 @@ namespace { constexpr float kKnnRecallThreshold = 0.6f; -constexpr float kBruteForceRecallThreshold = 0.99f; +constexpr float kBruteForceRecallThreshold = 0.95f; } // namespace TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { @@ -93,6 +93,18 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { return json; }; + auto hnsw_sq_gen = [hnsw_gen]() { + knowhere::Json json = hnsw_gen(); + json[knowhere::indexparam::USE_REFINE] = false; + return json; + }; + + auto hnsw_sq_with_refine_gen = [hnsw_gen]() { + knowhere::Json json = hnsw_gen(); + json[knowhere::indexparam::USE_REFINE] = true; + return json; + }; + const auto train_ds = GenDataSet(nb, dim); const auto query_ds = GenDataSet(nq, dim); @@ -113,6 +125,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_with_refine_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -138,7 +152,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { if (metric == knowhere::metric::COSINE) { if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && - !scann_without_raw_data) { + name != knowhere::IndexEnum::INDEX_HNSW_SQ && !scann_without_raw_data) { REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001)); } } @@ -155,6 +169,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_with_refine_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -181,7 +197,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { if (metric == knowhere::metric::COSINE) { if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && - !scann_without_raw_data) { + name != knowhere::IndexEnum::INDEX_HNSW_SQ && !scann_without_raw_data) { REQUIRE(CheckDistanceInScope(*results.value(), -1.00001, 1.00001)); } } @@ -216,6 +232,9 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_with_refine_gen, + hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -255,6 +274,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_with_refine_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index 203a341c1..ad9200430 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -27,8 +27,8 @@ #include "hnswlib.h" #include "io/memory_io.h" #include "knowhere/config.h" -#include "knowhere/utils.h" #include "knowhere/heap.h" +#include "knowhere/utils.h" #include "neighbor.h" #include "visited_list_pool.h" @@ -67,11 +67,13 @@ class HierarchicalNSW : public AlgorithmInterface { } HierarchicalNSW(SpaceInterface* s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, - size_t random_seed = 100) + bool use_sq = false, bool use_sq_refine = false, size_t random_seed = 100) : link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { space_ = s; + sq_enabled_ = use_sq; + sq_refine_enabled_ = use_sq_refine; if (auto x = dynamic_cast(s)) { metric_type_ = Metric::L2; } else if (auto x = dynamic_cast(s)) { @@ -91,6 +93,9 @@ class HierarchicalNSW : public AlgorithmInterface { num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); + if (sq_enabled_) { + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } dist_func_param_ = s->get_dist_func_param(); M_ = M; maxM_ = M_; @@ -102,8 +107,20 @@ class HierarchicalNSW : public AlgorithmInterface { update_probability_generator_.seed(random_seed + 1); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_; // + sizeof(labeltype); + size_data_per_element_ = size_links_level0_; // + sizeof(labeltype); + if (!sq_enabled_ || sq_refine_enabled_) { + size_data_per_element_ += data_size_; + } + if (sq_enabled_) { + size_data_per_element_ += *(size_t*)dist_func_param_ * sizeof(int8_t); + } offsetData_ = size_links_level0_; + if (use_sq) { + offsetSQData_ = offsetData_; + if (sq_refine_enabled_) { + offsetSQData_ += data_size_; + } + } // label_offset_ = size_links_level0_ + data_size_; offsetLevel0_ = 0; @@ -190,7 +207,7 @@ class HierarchicalNSW : public AlgorithmInterface { tableint enterpoint_node_; size_t size_links_level0_; - size_t offsetData_, offsetLevel0_; + size_t offsetData_, offsetSQData_, offsetLevel0_; char* data_level0_memory_; float* data_norm_l2_; // vector's l2 norm @@ -201,6 +218,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t label_offset_; DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; void* dist_func_param_; std::default_random_engine level_generator_; @@ -210,8 +228,53 @@ class HierarchicalNSW : public AlgorithmInterface { char* map_; size_t map_size_; + bool sq_enabled_ = false; + bool sq_refine_enabled_ = false; + float alpha_ = 0.0f; + mutable knowhere::lru_cache lru_cache; + void + trainSQuant(const float* train_data, size_t ntrain) { + size_t dim = *(size_t*)dist_func_param_; + for (size_t i = 0; i < ntrain; ++i) { + const float* vec = train_data + i * dim; + std::unique_ptr vec_norm = nullptr; + if (metric_type_ == Metric::COSINE) { + vec_norm = knowhere::CopyAndNormalizeVecs(vec, 1, dim); + vec = vec_norm.get(); + } + for (size_t j = 0; j < dim; ++j) { + alpha_ = std::max(alpha_, std::abs(vec[j])); + } + } + } + + void + encodeSQuant(const float* from, int8_t* to) const { + size_t dim = *(size_t*)dist_func_param_; + std::unique_ptr data_norm = nullptr; + if (metric_type_ == Metric::COSINE) { + data_norm = knowhere::CopyAndNormalizeVecs(from, 1, dim); + from = data_norm.get(); + } + for (size_t i = 0; i < dim; ++i) { + float x = from[i] / alpha_; + if (x > 1.0f) { + x = 1.0f; + } + if (x < -1.0f) { + x = -1.0f; + } + to[i] = std::round(x * 127.0f); + } + } + + inline char* + getSQDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetSQData_); + } + inline char* getDataByInternalId(tableint internal_id) const { return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); @@ -226,15 +289,34 @@ class HierarchicalNSW : public AlgorithmInterface { inline dist_t calcDistance(const tableint id1, const tableint id2) const { - dist_t dist = fstdistfunc_(getDataByInternalId(id1), getDataByInternalId(id2), dist_func_param_); - if (metric_type_ == Metric::COSINE) { - dist /= (data_norm_l2_[id1] * data_norm_l2_[id2]); + if (sq_enabled_) { + return fstdistfunc_sq_(getSQDataByInternalId(id1), getSQDataByInternalId(id2), dist_func_param_) * alpha_ * + alpha_ / 127.0f / 127.0f; + } else { + dist_t dist = fstdistfunc_(getDataByInternalId(id1), getDataByInternalId(id2), dist_func_param_); + if (metric_type_ == Metric::COSINE) { + dist /= (data_norm_l2_[id1] * data_norm_l2_[id2]); + } + return dist; } - return dist; } inline dist_t calcDistance(const void* vec, const tableint id) const { + if (sq_enabled_) { + return fstdistfunc_sq_(vec, getSQDataByInternalId(id), dist_func_param_) * alpha_ * alpha_ / 127.0f / + 127.0f; + } else { + dist_t dist = fstdistfunc_(vec, getDataByInternalId(id), dist_func_param_); + if (metric_type_ == Metric::COSINE) { + dist /= data_norm_l2_[id]; + } + return dist; + } + } + + inline dist_t + calcRefineDistance(const void* vec, const tableint id) const { dist_t dist = fstdistfunc_(vec, getDataByInternalId(id), dist_func_param_); if (metric_type_ == Metric::COSINE) { dist /= data_norm_l2_[id]; @@ -242,6 +324,17 @@ class HierarchicalNSW : public AlgorithmInterface { return dist; } + void + prefetchData(const tableint id) const { +#if defined(USE_PREFETCH) + if (sq_enabled_) { + _mm_prefetch(getSQDataByInternalId(id), _MM_HINT_T0); + } else { + _mm_prefetch(getDataByInternalId(id), _MM_HINT_T0); + } +#endif + } + std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, tableint cur_c, int layer) { auto& visited = visited_list_pool_->getFreeVisitedList(); @@ -278,11 +371,9 @@ class HierarchicalNSW : public AlgorithmInterface { } size_t size = getListCount((linklistsizeint*)data); tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (size_t j = 0; j < size; ++j) { - _mm_prefetch(getDataByInternalId(datal[j]), _MM_HINT_T0); + prefetchData(datal[j]); } -#endif for (size_t j = 0; j < size; j++) { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; @@ -294,9 +385,7 @@ class HierarchicalNSW : public AlgorithmInterface { dist_t dist1 = calcDistance(cur_c, candidate_id); if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { candidateSet.emplace(-dist1, candidate_id); -#if defined(USE_PREFETCH) - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); -#endif + prefetchData(candidateSet.top().second); top_candidates.emplace(dist1, candidate_id); @@ -329,11 +418,9 @@ class HierarchicalNSW : public AlgorithmInterface { metric_distance_computations += size; } for (size_t i = 1; i <= size; ++i) { -#if defined(USE_PREFETCH) if (i + 1 <= size) { - _mm_prefetch(getDataByInternalId(list[i + 1]), _MM_HINT_T0); + prefetchData(list[i + 1]); } -#endif tableint v = list[i]; if (visited[v]) { if (feder_result != nullptr) { @@ -444,8 +531,8 @@ class HierarchicalNSW : public AlgorithmInterface { } std::vector> - getNeighboursWithinRadius(NeighborSet& top_candidates, const void* data_point, - float radius, const knowhere::BitsetView& bitset) const { + getNeighboursWithinRadius(NeighborSet& top_candidates, const void* data_point, float radius, + const knowhere::BitsetView& bitset) const { std::vector> result; auto& visited = visited_list_pool_->getFreeVisitedList(); @@ -468,11 +555,9 @@ class HierarchicalNSW : public AlgorithmInterface { int* data = (int*)get_linklist0(current_id); size_t size = getListCount((linklistsizeint*)data); -#if defined(USE_PREFETCH) for (size_t j = 1; j <= size; ++j) { - _mm_prefetch(getDataByInternalId(data[j]), _MM_HINT_T0); + prefetchData(data[j]); } -#endif for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); if (!visited[candidate_id]) { @@ -662,7 +747,7 @@ class HierarchicalNSW : public AlgorithmInterface { } void - loadIndex(const std::string& location, const knowhere::Config& config, size_t max_elements_i = 0) { + loadIndex(const std::string& location, const knowhere::Config& config, bool use_sq, size_t max_elements_i = 0) { using knowhere::readBinaryPOD; auto cfg = static_cast(config); @@ -694,6 +779,12 @@ class HierarchicalNSW : public AlgorithmInterface { } fstdistfunc_ = space_->get_dist_func(); dist_func_param_ = space_->get_dist_func_param(); + if (use_sq) { + sq_enabled_ = true; + readBinaryPOD(input, sq_refine_enabled_); + readBinaryPOD(input, alpha_); + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } readBinaryPOD(input, offsetLevel0_); readBinaryPOD(input, max_elements_); @@ -707,6 +798,13 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, size_data_per_element_); readBinaryPOD(input, label_offset_); readBinaryPOD(input, offsetData_); + if (use_sq) { + offsetSQData_ = offsetData_; + if (sq_refine_enabled_) { + offsetSQData_ += data_size_; + } + } + readBinaryPOD(input, offsetSQData_); readBinaryPOD(input, maxlevel_); readBinaryPOD(input, enterpoint_node_); @@ -778,6 +876,10 @@ class HierarchicalNSW : public AlgorithmInterface { writeBinaryPOD(output, metric_type_); writeBinaryPOD(output, data_size_); writeBinaryPOD(output, *((size_t*)dist_func_param_)); + if (sq_enabled_) { + writeBinaryPOD(output, sq_refine_enabled_); + writeBinaryPOD(output, alpha_); + } writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); @@ -806,11 +908,12 @@ class HierarchicalNSW : public AlgorithmInterface { if (linkListSize) output.write(linkLists_[i], linkListSize); } + // output.close(); } void - loadIndex(knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + loadIndex(knowhere::MemoryIOReader& input, bool use_sq, size_t max_elements_i = 0) { using knowhere::readBinaryPOD; // linxj: init with metrictype size_t dim; @@ -832,6 +935,12 @@ class HierarchicalNSW : public AlgorithmInterface { } fstdistfunc_ = space_->get_dist_func(); dist_func_param_ = space_->get_dist_func_param(); + if (use_sq) { + sq_enabled_ = true; + readBinaryPOD(input, sq_refine_enabled_); + readBinaryPOD(input, alpha_); + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } readBinaryPOD(input, offsetLevel0_); readBinaryPOD(input, max_elements_); @@ -845,6 +954,12 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, size_data_per_element_); readBinaryPOD(input, label_offset_); readBinaryPOD(input, offsetData_); + if (use_sq) { + offsetSQData_ = offsetData_; + if (sq_refine_enabled_) { + offsetSQData_ += data_size_; + } + } readBinaryPOD(input, maxlevel_); readBinaryPOD(input, enterpoint_node_); @@ -1008,11 +1123,9 @@ class HierarchicalNSW : public AlgorithmInterface { data = get_linklist_at_level(currObj, level); int size = getListCount(data); tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (int i = 0; i < size; ++i) { - _mm_prefetch(getDataByInternalId(datal[i]), _MM_HINT_T0); + prefetchData(datal[i]); } -#endif for (int i = 0; i < size; i++) { tableint cand = datal[i]; dist_t d = calcDistance(dataPoint, cand); @@ -1086,13 +1199,16 @@ class HierarchicalNSW : public AlgorithmInterface { tableint enterpoint_copy = enterpoint_node_; memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); - - if (metric_type_ == Metric::COSINE) { - data_norm_l2_[cur_c] = - std::sqrt(faiss::fvec_norm_L2sqr((const float*)data_point, *(size_t*)(dist_func_param_))); + if (!(sq_enabled_ && !sq_refine_enabled_)) { + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (metric_type_ == Metric::COSINE) { + data_norm_l2_[cur_c] = + std::sqrt(faiss::fvec_norm_L2sqr((const float*)data_point, *(size_t*)(dist_func_param_))); + } + } + if (sq_enabled_) { + encodeSQuant((const float*)data_point, (int8_t*)getSQDataByInternalId(cur_c)); } - if (curlevel) { linkLists_[cur_c] = (char*)malloc(size_links_per_element_ * curlevel + 1); if (linkLists_[cur_c] == nullptr) @@ -1198,11 +1314,9 @@ class HierarchicalNSW : public AlgorithmInterface { metric_hops++; metric_distance_computations += size; tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (int i = 0; i < size; ++i) { - _mm_prefetch(getDataByInternalId(datal[i]), _MM_HINT_T0); + prefetchData(datal[i]); } -#endif for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) @@ -1238,6 +1352,13 @@ class HierarchicalNSW : public AlgorithmInterface { query_data_norm = knowhere::CopyAndNormalizeVecs((const float*)query_data, 1, *(size_t*)dist_func_param_); query_data = query_data_norm.get(); } + std::unique_ptr query_data_sq; + const float* raw_data = (const float*)query_data; + if (sq_enabled_) { + query_data_sq = std::make_unique(*(size_t*)dist_func_param_); + encodeSQuant((const float*)query_data, query_data_sq.get()); + query_data = query_data_sq.get(); + } // do bruteforce search when topk is super large if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) { @@ -1247,7 +1368,8 @@ class HierarchicalNSW : public AlgorithmInterface { // do bruteforce search when delete rate high if (!bitset.empty()) { const size_t filtered_out_num = bitset.count(); - if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || + k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchKnnBF(query_data, k, bitset); } } @@ -1265,8 +1387,19 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector> result; size_t len = std::min(k, retset.size()); result.reserve(len); - for (int i = 0; i < len; ++i) { - result.emplace_back(retset[i].distance, (labeltype)retset[i].id); + if (sq_enabled_ && sq_refine_enabled_) { + knowhere::ResultMaxHeap max_heap(len); + for (int i = 0; i < retset.size(); ++i) { + max_heap.Push(calcRefineDistance(raw_data, retset[i].id), retset[i].id); + } + for (int64_t i = len - 1; i >= 0; --i) { + const auto op = max_heap.Pop(); + result.emplace_back(op.value()); + } + } else { + for (int i = 0; i < len; ++i) { + result.emplace_back(retset[i].distance, (labeltype)retset[i].id); + } } if (len > 0) { lru_cache.put(vec_hash, result[0].second); @@ -1376,6 +1509,13 @@ class HierarchicalNSW : public AlgorithmInterface { query_data = query_data_norm.get(); } + std::unique_ptr query_data_sq; + if (sq_enabled_) { + query_data_sq = std::make_unique(*(size_t*)dist_func_param_); + encodeSQuant((const float*)query_data, query_data_sq.get()); + query_data = query_data_sq.get(); + } + // do bruteforce range search when ef is super large size_t ef = param ? param->ef_ : this->ef_; if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) { @@ -1385,7 +1525,8 @@ class HierarchicalNSW : public AlgorithmInterface { // do bruteforce range search when delete rate high if (!bitset.empty()) { const size_t filtered_out_num = bitset.count(); - if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || + ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchRangeBF(query_data, radius, bitset); } } @@ -1456,7 +1597,7 @@ class HierarchicalNSW : public AlgorithmInterface { } } if (metric_type_ == Metric::COSINE) { - ret += max_elements_ * sizeof(float); + ret += max_elements_ * sizeof(float); } return ret; } diff --git a/thirdparty/hnswlib/hnswlib/hnswlib.h b/thirdparty/hnswlib/hnswlib/hnswlib.h index 741043819..1b8daa6d2 100644 --- a/thirdparty/hnswlib/hnswlib/hnswlib.h +++ b/thirdparty/hnswlib/hnswlib/hnswlib.h @@ -161,6 +161,11 @@ class SpaceInterface { virtual DISTFUNC get_dist_func() = 0; + virtual DISTFUNC + get_dist_func_sq() { + throw std::runtime_error("Not implemented\n"); + } + virtual void* get_dist_func_param() = 0; diff --git a/thirdparty/hnswlib/hnswlib/space_cosine.h b/thirdparty/hnswlib/hnswlib/space_cosine.h index 237222c34..bd831eef7 100644 --- a/thirdparty/hnswlib/hnswlib/space_cosine.h +++ b/thirdparty/hnswlib/hnswlib/space_cosine.h @@ -15,14 +15,21 @@ CosineDistance(const void* pVect1, const void* pVect2, const void* qty_ptr) { return -1.0f * Cosine(pVect1, pVect2, qty_ptr); } +static float +CosineSQ8Distance(const void* pVect1, const void* pVect2, const void* qty_ptr) { + return -1.0f * faiss::ivec_inner_product((const int8_t*)pVect1, (const int8_t*)pVect2, *(size_t*)qty_ptr); +} + class CosineSpace : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: CosineSpace(size_t dim) { fstdistfunc_ = CosineDistance; + fstdistfunc_sq_ = CosineSQ8Distance; dim_ = dim; data_size_ = dim * sizeof(float); } @@ -37,6 +44,11 @@ class CosineSpace : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_; diff --git a/thirdparty/hnswlib/hnswlib/space_ip.h b/thirdparty/hnswlib/hnswlib/space_ip.h index 7383ff3a2..c3c8c33da 100644 --- a/thirdparty/hnswlib/hnswlib/space_ip.h +++ b/thirdparty/hnswlib/hnswlib/space_ip.h @@ -24,6 +24,11 @@ InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr return -1.0f * InnerProduct(pVect1, pVect2, qty_ptr); } +static float +InnerProductSQ8Distance(const void* pVect1, const void* pVect2, const void* qty_ptr) { + return -1.0f * faiss::ivec_inner_product((const int8_t*)pVect1, (const int8_t*)pVect2, *(size_t*)qty_ptr); +} + #if defined(USE_AVX) // Favor using AVX if available. @@ -321,12 +326,14 @@ InnerProductDistanceSIMD4ExtResiduals(const void* pVect1v, const void* pVect2v, class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: InnerProductSpace(size_t dim) { fstdistfunc_ = InnerProductDistance; + fstdistfunc_sq_ = InnerProductSQ8Distance; #if 0 /* use FAISS distance calculation algorithm instead */ #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) #if defined(USE_AVX512) @@ -374,6 +381,11 @@ class InnerProductSpace : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_; diff --git a/thirdparty/hnswlib/hnswlib/space_l2.h b/thirdparty/hnswlib/hnswlib/space_l2.h index aa255ecab..dfdc15504 100644 --- a/thirdparty/hnswlib/hnswlib/space_l2.h +++ b/thirdparty/hnswlib/hnswlib/space_l2.h @@ -25,6 +25,11 @@ L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { #endif } +static float +L2SqrSQ8(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + return faiss::ivec_L2sqr((const int8_t*)pVect1v, (const int8_t*)pVect2v, *(size_t*)qty_ptr); +} + #if defined(USE_AVX512) // Favor using AVX512 if available. @@ -210,12 +215,14 @@ L2SqrSIMD4ExtResiduals(const void* pVect1v, const void* pVect2v, const void* qty class L2Space : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: L2Space(size_t dim) { fstdistfunc_ = L2Sqr; + fstdistfunc_sq_ = L2SqrSQ8; #if 0 /* use FAISS distance calculation algorithm instead */ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) #if defined(USE_AVX512) @@ -252,6 +259,11 @@ class L2Space : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_;