diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 13c6f6e03..00eac2c92 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -43,6 +43,8 @@ constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ"; constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA"; constexpr const char* INDEX_HNSW = "HNSW"; +constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8"; +constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE"; constexpr const char* INDEX_DISKANN = "DISKANN"; constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index b14beb5a9..783242137 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -59,6 +59,15 @@ hash_vec(const float* x, size_t d) { return h; } +inline uint64_t +hash_u8_vec(const uint8_t* x, size_t d) { + uint64_t h = seed; + for (size_t i = 0; i < d; ++i) { + h = h * 13331 + *(x + i); + } + return h; +} + inline uint64_t hash_binary_vec(const uint8_t* x, size_t d) { size_t len = (d + 7) / 8; diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index cb3f8ac5c..cd1d46fbe 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -29,7 +29,10 @@ #include "knowhere/utils.h" namespace knowhere { -template + +using hnswlib::QuantType; + +template class HnswIndexNode : public IndexNode { static_assert(std::is_same_v || std::is_same_v, "HnswIndexNode only support float/bianry"); @@ -60,8 +63,8 @@ class HnswIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value(); return Status::invalid_metric_type; } - auto index = new (std::nothrow) - hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); + auto index = new (std::nothrow) hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), + hnsw_cfg.efConstruction.value()); if (index == nullptr) { LOG_KNOWHERE_WARNING_ << "memory malloc error."; return Status::malloc_error; @@ -71,6 +74,9 @@ class HnswIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index"; } this->index_ = index; + if constexpr (quant_type != QuantType::None) { + this->index_->trainSQuant((const float*)dataset.GetTensor(), rows); + } return Status::success; } @@ -219,7 +225,7 @@ class HnswIndexNode : public IndexNode { private: class iterator : public IndexNode::iterator { public: - iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, + iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf) : index_(index), transform_(transform), @@ -252,7 +258,7 @@ class HnswIndexNode : public IndexNode { has_next_ = false; } } - const hnswlib::HierarchicalNSW* index_; + const hnswlib::HierarchicalNSW* index_; const bool transform_; std::unique_ptr workspace_; bool has_next_; @@ -396,7 +402,7 @@ class HnswIndexNode : public IndexNode { bool HasRawData(const std::string& metric_type) const override { - return true; + return quant_type == QuantType::None || quant_type == QuantType::SQ8Refine; } expected @@ -460,8 +466,8 @@ class HnswIndexNode : public IndexNode { MemoryIOReader reader(binary->data.get(), binary->size); - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(reader); LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ @@ -480,8 +486,8 @@ class HnswIndexNode : public IndexNode { delete index_; } try { - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(filename, config); } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); @@ -521,7 +527,14 @@ class HnswIndexNode : public IndexNode { std::string Type() const override { - return knowhere::IndexEnum::INDEX_HNSW; + if constexpr (quant_type == QuantType::SQ8) { + return knowhere::IndexEnum::INDEX_HNSW_SQ8; + } else if constexpr (quant_type == QuantType::SQ8Refine) { + return knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE; + + } else { + return knowhere::IndexEnum::INDEX_HNSW; + } } ~HnswIndexNode() override { @@ -568,12 +581,18 @@ class HnswIndexNode : public IndexNode { } private: - hnswlib::HierarchicalNSW* index_; + hnswlib::HierarchicalNSW* index_; std::shared_ptr search_pool_; }; KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine); KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1); KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine); KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine); } // namespace knowhere diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc index 1e3ec0f04..bfb85f5f2 100644 --- a/src/simd/distances_avx.cc +++ b/src/simd/distances_avx.cc @@ -215,5 +215,28 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h index c89b964d6..f786dc693 100644 --- a/src/simd/distances_avx.h +++ b/src/simd/distances_avx.h @@ -44,6 +44,12 @@ void fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_AVX_H */ diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index bbf1ce80c..50cfbd153 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -242,6 +242,29 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index bfde69af1..79654319f 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -43,6 +43,12 @@ void fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_AVX512_H */ diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc index cb304927f..eb90c9ae7 100644 --- a/src/simd/distances_neon.cc +++ b/src/simd/distances_neon.cc @@ -635,5 +635,28 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl return ans_; } +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h index fdfc79ad9..c3150d161 100644 --- a/src/simd/distances_neon.h +++ b/src/simd/distances_neon.h @@ -12,6 +12,7 @@ #ifndef DISTANCES_NEON_H #define DISTANCES_NEON_H +#include #include namespace faiss { @@ -50,6 +51,12 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c); int fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c); +int32_t +ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_NEON_H */ diff --git a/src/simd/distances_ref.cc b/src/simd/distances_ref.cc index 84fb049a9..2ff6e82f6 100644 --- a/src/simd/distances_ref.cc +++ b/src/simd/distances_ref.cc @@ -212,4 +212,25 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f dis3 = d3; } +int32_t +ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +int32_t +ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss diff --git a/src/simd/distances_ref.h b/src/simd/distances_ref.h index fefb04999..2ca812051 100644 --- a/src/simd/distances_ref.h +++ b/src/simd/distances_ref.h @@ -1,6 +1,7 @@ #ifndef DISTANCES_REF_H #define DISTANCES_REF_H +#include #include namespace faiss { @@ -71,6 +72,12 @@ void fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +int32_t +ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_REF_H */ diff --git a/src/simd/distances_sse.cc b/src/simd/distances_sse.cc index 30e8fa3d1..600ec931b 100644 --- a/src/simd/distances_sse.cc +++ b/src/simd/distances_sse.cc @@ -377,5 +377,28 @@ fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, flo return _mm_cvtsi128_si32(imin4); } +// trust the compiler to unroll this properly +int32_t +ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + res += (int32_t)x[i] * y[i]; + } + return res; +} + +// trust the compiler to unroll this properly +int32_t +ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) { + size_t i; + int32_t res = 0; + for (i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return res; +} + } // namespace faiss #endif diff --git a/src/simd/distances_sse.h b/src/simd/distances_sse.h index 48f2f7653..a8089c119 100644 --- a/src/simd/distances_sse.h +++ b/src/simd/distances_sse.h @@ -12,6 +12,7 @@ #ifndef DISTANCES_SSE_H #define DISTANCES_SSE_H +#include #include namespace faiss { @@ -46,6 +47,12 @@ fvec_madd_sse(size_t n, const float* a, float bf, const float* b, float* c); int fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, float* c); +int32_t +ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d); + +int32_t +ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d); + } // namespace faiss #endif /* DISTANCES_SSE_H */ diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 20f887406..5e40b3f05 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -59,6 +59,9 @@ decltype(fvec_L2sqr_ny_transposed) fvec_L2sqr_ny_transposed = fvec_L2sqr_ny_tran decltype(fvec_inner_product_batch_4) fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; decltype(fvec_L2sqr_batch_4) fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; +decltype(ivec_inner_product) ivec_inner_product = ivec_inner_product_ref; +decltype(ivec_L2sqr) ivec_L2sqr = ivec_L2sqr_ref; + #if defined(__x86_64__) bool cpu_support_avx512() { @@ -99,6 +102,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx512; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx512; + ivec_inner_product = ivec_inner_product_avx512; + ivec_L2sqr = ivec_L2sqr_avx512; + simd_type = "AVX512"; support_pq_fast_scan = true; } else if (use_avx2 && cpu_support_avx2()) { @@ -116,6 +122,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_avx; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_avx; + ivec_inner_product = ivec_inner_product_avx; + ivec_L2sqr = ivec_L2sqr_avx; + simd_type = "AVX2"; support_pq_fast_scan = true; } else if (use_sse4_2 && cpu_support_sse4_2()) { @@ -133,6 +142,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; + ivec_inner_product = ivec_inner_product_sse; + ivec_L2sqr = ivec_L2sqr_sse; + simd_type = "SSE4_2"; support_pq_fast_scan = false; } else { @@ -150,6 +162,9 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_ref; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref; + ivec_inner_product = ivec_inner_product_ref; + ivec_L2sqr = ivec_L2sqr_ref; + simd_type = "GENERIC"; support_pq_fast_scan = false; } @@ -167,6 +182,9 @@ fvec_hook(std::string& simd_type) { fvec_madd = fvec_madd_neon; fvec_madd_and_argmin = fvec_madd_and_argmin_neon; + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; + simd_type = "NEON"; support_pq_fast_scan = true; @@ -185,6 +203,9 @@ fvec_hook(std::string& simd_type) { fvec_madd = fvec_madd_ref; fvec_madd_and_argmin = fvec_madd_and_argmin_ref; + ivec_inner_product = ivec_inner_product_ref; + ivec_L2sqr = ivec_L2sqr_ref; + simd_type = "GENERIC"; support_pq_fast_scan = false; #endif diff --git a/src/simd/hook.h b/src/simd/hook.h index 07763ea8d..69d84e9b4 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -70,6 +70,10 @@ extern void (*fvec_inner_product_batch_4)(const float*, const float*, const floa extern void (*fvec_L2sqr_batch_4)(const float*, const float*, const float*, const float*, const float*, const size_t, float&, float&, float&, float&); +extern int32_t (*ivec_inner_product)(const int8_t*, const int8_t*, size_t); + +extern int32_t (*ivec_L2sqr)(const int8_t*, const int8_t*, size_t); + #if defined(__x86_64__) extern bool use_avx512; extern bool use_avx2; diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index a6f6cf6c5..136285f45 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -24,7 +24,7 @@ namespace { constexpr float kKnnRecallThreshold = 0.6f; -constexpr float kBruteForceRecallThreshold = 0.99f; +constexpr float kBruteForceRecallThreshold = 0.95f; } // namespace TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { @@ -113,6 +113,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -138,6 +140,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { if (metric == knowhere::metric::COSINE) { if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE && !scann_without_raw_data) { REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001)); } @@ -155,6 +158,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -181,6 +186,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { if (metric == knowhere::metric::COSINE) { if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + name != knowhere::IndexEnum::INDEX_HNSW_SQ8 && name != knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE && !scann_without_raw_data) { REQUIRE(CheckDistanceInScope(*results.value(), -1.00001, 1.00001)); } @@ -268,6 +274,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -307,6 +315,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8, hnsw_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE, hnsw_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index 48c1a572d..631fa4cea 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -54,10 +54,16 @@ enum Metric { UNKNOWN = 100, }; -template +enum QuantType { None = 0, SQ8 = 1, SQ8Refine = 2 }; + +template class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; + + static constexpr bool sq_enabled = quant_type != QuantType::None; + static constexpr bool has_raw_data = quant_type == QuantType::None || quant_type == QuantType::SQ8Refine; + HierarchicalNSW(SpaceInterface* s) { } @@ -91,6 +97,9 @@ class HierarchicalNSW : public AlgorithmInterface { num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); + if constexpr (sq_enabled) { + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } dist_func_param_ = s->get_dist_func_param(); M_ = M; maxM_ = M_; @@ -102,8 +111,20 @@ class HierarchicalNSW : public AlgorithmInterface { update_probability_generator_.seed(random_seed + 1); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_; // + sizeof(labeltype); + size_data_per_element_ = size_links_level0_; // + sizeof(labeltype); + if constexpr (has_raw_data) { + size_data_per_element_ += data_size_; + } + if constexpr (sq_enabled) { + size_data_per_element_ += *(size_t*)dist_func_param_ * sizeof(int8_t); + } offsetData_ = size_links_level0_; + if constexpr (sq_enabled) { + offsetSQData_ = offsetData_; + if constexpr (has_raw_data) { + offsetSQData_ += data_size_; + } + } // label_offset_ = size_links_level0_ + data_size_; offsetLevel0_ = 0; @@ -190,7 +211,7 @@ class HierarchicalNSW : public AlgorithmInterface { tableint enterpoint_node_; size_t size_links_level0_; - size_t offsetData_, offsetLevel0_; + size_t offsetData_, offsetSQData_, offsetLevel0_; char* data_level0_memory_; float* data_norm_l2_; // vector's l2 norm @@ -201,6 +222,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t label_offset_; DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; void* dist_func_param_; std::default_random_engine level_generator_; @@ -210,8 +232,53 @@ class HierarchicalNSW : public AlgorithmInterface { char* map_; size_t map_size_; + float alpha_ = 0.0f; + mutable knowhere::lru_cache lru_cache; + // Symmetric quantization to encode float value from [-alpha, alpha] to [-127, 127] + void + trainSQuant(const float* train_data, size_t ntrain) { + alpha_ = 0.0f; + size_t dim = *(size_t*)dist_func_param_; + for (size_t i = 0; i < ntrain; ++i) { + const float* vec = train_data + i * dim; + std::unique_ptr vec_norm = nullptr; + if (metric_type_ == Metric::COSINE) { + vec_norm = knowhere::CopyAndNormalizeVecs(vec, 1, dim); + vec = vec_norm.get(); + } + for (size_t j = 0; j < dim; ++j) { + alpha_ = std::max(alpha_, std::abs(vec[j])); + } + } + } + + void + encodeSQuant(const float* from, int8_t* to) const { + size_t dim = *(size_t*)dist_func_param_; + std::unique_ptr data_norm = nullptr; + if (metric_type_ == Metric::COSINE) { + data_norm = knowhere::CopyAndNormalizeVecs(from, 1, dim); + from = data_norm.get(); + } + for (size_t i = 0; i < dim; ++i) { + float x = from[i] / alpha_; + if (x > 1.0f) { + x = 1.0f; + } + if (x < -1.0f) { + x = -1.0f; + } + to[i] = std::round(x * 127.0f); + } + } + + inline char* + getSQDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetSQData_); + } + inline char* getDataByInternalId(tableint internal_id) const { return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); @@ -226,15 +293,34 @@ class HierarchicalNSW : public AlgorithmInterface { inline dist_t calcDistance(const tableint id1, const tableint id2) const { - dist_t dist = fstdistfunc_(getDataByInternalId(id1), getDataByInternalId(id2), dist_func_param_); - if (metric_type_ == Metric::COSINE) { - dist /= (data_norm_l2_[id1] * data_norm_l2_[id2]); + if constexpr (sq_enabled) { + return fstdistfunc_sq_(getSQDataByInternalId(id1), getSQDataByInternalId(id2), dist_func_param_) * alpha_ * + alpha_ / 127.0f / 127.0f; + } else { + dist_t dist = fstdistfunc_(getDataByInternalId(id1), getDataByInternalId(id2), dist_func_param_); + if (metric_type_ == Metric::COSINE) { + dist /= (data_norm_l2_[id1] * data_norm_l2_[id2]); + } + return dist; } - return dist; } inline dist_t calcDistance(const void* vec, const tableint id) const { + if constexpr (sq_enabled) { + return fstdistfunc_sq_(vec, getSQDataByInternalId(id), dist_func_param_) * alpha_ * alpha_ / 127.0f / + 127.0f; + } else { + dist_t dist = fstdistfunc_(vec, getDataByInternalId(id), dist_func_param_); + if (metric_type_ == Metric::COSINE) { + dist /= data_norm_l2_[id]; + } + return dist; + } + } + + inline dist_t + calcRefineDistance(const void* vec, const tableint id) const { dist_t dist = fstdistfunc_(vec, getDataByInternalId(id), dist_func_param_); if (metric_type_ == Metric::COSINE) { dist /= data_norm_l2_[id]; @@ -242,6 +328,17 @@ class HierarchicalNSW : public AlgorithmInterface { return dist; } + void + prefetchData(const tableint id) const { +#if defined(USE_PREFETCH) + if constexpr (sq_enabled) { + _mm_prefetch(getSQDataByInternalId(id), _MM_HINT_T0); + } else { + _mm_prefetch(getDataByInternalId(id), _MM_HINT_T0); + } +#endif + } + std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, tableint cur_c, int layer) { auto& visited = visited_list_pool_->getFreeVisitedList(); @@ -278,11 +375,9 @@ class HierarchicalNSW : public AlgorithmInterface { } size_t size = getListCount((linklistsizeint*)data); tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (size_t j = 0; j < size; ++j) { - _mm_prefetch(getDataByInternalId(datal[j]), _MM_HINT_T0); + prefetchData(datal[j]); } -#endif for (size_t j = 0; j < size; j++) { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; @@ -294,9 +389,7 @@ class HierarchicalNSW : public AlgorithmInterface { dist_t dist1 = calcDistance(cur_c, candidate_id); if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { candidateSet.emplace(-dist1, candidate_id); -#if defined(USE_PREFETCH) - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); -#endif + prefetchData(candidateSet.top().second); top_candidates.emplace(dist1, candidate_id); @@ -330,11 +423,9 @@ class HierarchicalNSW : public AlgorithmInterface { } float kAlpha = bitset.filter_ratio() / 2.0f; for (size_t i = 1; i <= size; ++i) { -#if defined(USE_PREFETCH) if (i + 1 <= size) { - _mm_prefetch(getDataByInternalId(list[i + 1]), _MM_HINT_T0); + prefetchData(list[i + 1]); } -#endif tableint v = list[i]; if (visited[v]) { if (feder_result != nullptr) { @@ -473,11 +564,9 @@ class HierarchicalNSW : public AlgorithmInterface { int* data = (int*)get_linklist0(current_id); size_t size = getListCount((linklistsizeint*)data); -#if defined(USE_PREFETCH) for (size_t j = 1; j <= size; ++j) { - _mm_prefetch(getDataByInternalId(data[j]), _MM_HINT_T0); + prefetchData(data[j]); } -#endif for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); if (!visited[candidate_id]) { @@ -699,6 +788,10 @@ class HierarchicalNSW : public AlgorithmInterface { } fstdistfunc_ = space_->get_dist_func(); dist_func_param_ = space_->get_dist_func_param(); + if constexpr (sq_enabled) { + readBinaryPOD(input, alpha_); + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } readBinaryPOD(input, offsetLevel0_); readBinaryPOD(input, max_elements_); @@ -712,6 +805,12 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, size_data_per_element_); readBinaryPOD(input, label_offset_); readBinaryPOD(input, offsetData_); + if constexpr (sq_enabled) { + offsetSQData_ = offsetData_; + if constexpr (has_raw_data) { + offsetSQData_ += data_size_; + } + } readBinaryPOD(input, maxlevel_); readBinaryPOD(input, enterpoint_node_); @@ -783,6 +882,9 @@ class HierarchicalNSW : public AlgorithmInterface { writeBinaryPOD(output, metric_type_); writeBinaryPOD(output, data_size_); writeBinaryPOD(output, *((size_t*)dist_func_param_)); + if constexpr (sq_enabled) { + writeBinaryPOD(output, alpha_); + } writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); @@ -811,6 +913,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (linkListSize) output.write(linkLists_[i], linkListSize); } + // output.close(); } @@ -837,6 +940,10 @@ class HierarchicalNSW : public AlgorithmInterface { } fstdistfunc_ = space_->get_dist_func(); dist_func_param_ = space_->get_dist_func_param(); + if constexpr (sq_enabled) { + readBinaryPOD(input, alpha_); + fstdistfunc_sq_ = space_->get_dist_func_sq(); + } readBinaryPOD(input, offsetLevel0_); readBinaryPOD(input, max_elements_); @@ -850,6 +957,12 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, size_data_per_element_); readBinaryPOD(input, label_offset_); readBinaryPOD(input, offsetData_); + if constexpr (sq_enabled) { + offsetSQData_ = offsetData_; + if constexpr (has_raw_data) { + offsetSQData_ += data_size_; + } + } readBinaryPOD(input, maxlevel_); readBinaryPOD(input, enterpoint_node_); @@ -1013,11 +1126,9 @@ class HierarchicalNSW : public AlgorithmInterface { data = get_linklist_at_level(currObj, level); int size = getListCount(data); tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (int i = 0; i < size; ++i) { - _mm_prefetch(getDataByInternalId(datal[i]), _MM_HINT_T0); + prefetchData(datal[i]); } -#endif for (int i = 0; i < size; i++) { tableint cand = datal[i]; dist_t d = calcDistance(dataPoint, cand); @@ -1091,13 +1202,16 @@ class HierarchicalNSW : public AlgorithmInterface { tableint enterpoint_copy = enterpoint_node_; memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); - - if (metric_type_ == Metric::COSINE) { - data_norm_l2_[cur_c] = - std::sqrt(faiss::fvec_norm_L2sqr((const float*)data_point, *(size_t*)(dist_func_param_))); + if constexpr (has_raw_data) { + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (metric_type_ == Metric::COSINE) { + data_norm_l2_[cur_c] = + std::sqrt(faiss::fvec_norm_L2sqr((const float*)data_point, *(size_t*)(dist_func_param_))); + } + } + if constexpr (sq_enabled) { + encodeSQuant((const float*)data_point, (int8_t*)getSQDataByInternalId(cur_c)); } - if (curlevel) { linkLists_[cur_c] = (char*)malloc(size_links_per_element_ * curlevel + 1); if (linkLists_[cur_c] == nullptr) @@ -1183,7 +1297,11 @@ class HierarchicalNSW : public AlgorithmInterface { if (metric_type_ == Metric::HAMMING || metric_type_ == Metric::JACCARD) { vec_hash = knowhere::hash_binary_vec((const uint8_t*)query_data, *(size_t*)dist_func_param_); } else { - vec_hash = knowhere::hash_vec((const float*)query_data, *(size_t*)dist_func_param_); + if constexpr (sq_enabled) { + vec_hash = knowhere::hash_u8_vec((const uint8_t*)query_data, *(size_t*)dist_func_param_); + } else { + vec_hash = knowhere::hash_vec((const float*)query_data, *(size_t*)dist_func_param_); + } } // for tuning, do not use cache if ((param && param->for_tuning) || !lru_cache.try_get(vec_hash, currObj)) { @@ -1203,11 +1321,9 @@ class HierarchicalNSW : public AlgorithmInterface { metric_hops++; metric_distance_computations += size; tableint* datal = (tableint*)(data + 1); -#if defined(USE_PREFETCH) for (int i = 0; i < size; ++i) { - _mm_prefetch(getDataByInternalId(datal[i]), _MM_HINT_T0); + prefetchData(datal[i]); } -#endif for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) @@ -1243,6 +1359,13 @@ class HierarchicalNSW : public AlgorithmInterface { query_data_norm = knowhere::CopyAndNormalizeVecs((const float*)query_data, 1, *(size_t*)dist_func_param_); query_data = query_data_norm.get(); } + std::unique_ptr query_data_sq; + const float* raw_data = (const float*)query_data; + if constexpr (sq_enabled) { + query_data_sq = std::make_unique(*(size_t*)dist_func_param_); + encodeSQuant((const float*)query_data, query_data_sq.get()); + query_data = query_data_sq.get(); + } // do bruteforce search when topk is super large if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) { @@ -1256,7 +1379,8 @@ class HierarchicalNSW : public AlgorithmInterface { double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || + k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchKnnBF(query_data, k, bitset); } } @@ -1274,8 +1398,19 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector> result; size_t len = std::min(k, retset.size()); result.reserve(len); - for (int i = 0; i < len; ++i) { - result.emplace_back(retset[i].distance, (labeltype)retset[i].id); + if constexpr (sq_enabled && has_raw_data) { + knowhere::ResultMaxHeap max_heap(len); + for (int i = 0; i < retset.size(); ++i) { + max_heap.Push(calcRefineDistance(raw_data, retset[i].id), retset[i].id); + } + for (int64_t i = len - 1; i >= 0; --i) { + const auto op = max_heap.Pop(); + result.emplace_back(op.value()); + } + } else { + for (int i = 0; i < len; ++i) { + result.emplace_back(retset[i].distance, (labeltype)retset[i].id); + } } if (len > 0) { lru_cache.put(vec_hash, result[0].second); @@ -1385,6 +1520,13 @@ class HierarchicalNSW : public AlgorithmInterface { query_data = query_data_norm.get(); } + std::unique_ptr query_data_sq; + if constexpr (sq_enabled) { + query_data_sq = std::make_unique(*(size_t*)dist_func_param_); + encodeSQuant((const float*)query_data, query_data_sq.get()); + query_data = query_data_sq.get(); + } + // do bruteforce range search when ef is super large size_t ef = param ? param->ef_ : this->ef_; if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) { @@ -1398,7 +1540,8 @@ class HierarchicalNSW : public AlgorithmInterface { double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || + ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchRangeBF(query_data, radius, bitset); } } @@ -1464,9 +1607,10 @@ class HierarchicalNSW : public AlgorithmInterface { for (tableint i = 0; i < cur_element_count; ++i) { if (element_levels_[i] >= level) { if (!visited[i]) { - if (level > 0) { // for upper level, directly add edges since nodes num is usually small and fast to search its neighbors + if (level > 0) { // for upper level, directly add edges since nodes num is usually small and + // fast to search its neighbors repairGraphConnectivity(i, level); - } else { // for base level, collect the unreachable nodes and repair them concurrently + } else { // for base level, collect the unreachable nodes and repair them concurrently unreached.push_back(i); } } @@ -1513,8 +1657,8 @@ class HierarchicalNSW : public AlgorithmInterface { } } } - std::priority_queue, std::vector>, CompareByFirst> candidates = searchBaseLayer( - currObj, cur_c, level); + std::priority_queue, std::vector>, CompareByFirst> + candidates = searchBaseLayer(currObj, cur_c, level); // get sorted id std::vector top_candidate_ids(candidates.size()); @@ -1531,10 +1675,10 @@ class HierarchicalNSW : public AlgorithmInterface { // try to connect candidate to the element // add an edge if there is space - std::unique_lock lock(link_list_locks_[cand_id]); - linklistsizeint *ll_cand = get_linklist_at_level(cand_id, level); + std::unique_lock lock(link_list_locks_[cand_id]); + linklistsizeint* ll_cand = get_linklist_at_level(cand_id, level); size_t size = getListCount(ll_cand); - tableint *data_cand = (tableint *) (ll_cand + 1); + tableint* data_cand = (tableint*)(ll_cand + 1); if (size < m_max) { data_cand[size] = cur_c; setListCount(ll_cand, size + 1); diff --git a/thirdparty/hnswlib/hnswlib/hnswlib.h b/thirdparty/hnswlib/hnswlib/hnswlib.h index 741043819..1b8daa6d2 100644 --- a/thirdparty/hnswlib/hnswlib/hnswlib.h +++ b/thirdparty/hnswlib/hnswlib/hnswlib.h @@ -161,6 +161,11 @@ class SpaceInterface { virtual DISTFUNC get_dist_func() = 0; + virtual DISTFUNC + get_dist_func_sq() { + throw std::runtime_error("Not implemented\n"); + } + virtual void* get_dist_func_param() = 0; diff --git a/thirdparty/hnswlib/hnswlib/space_cosine.h b/thirdparty/hnswlib/hnswlib/space_cosine.h index 237222c34..bd831eef7 100644 --- a/thirdparty/hnswlib/hnswlib/space_cosine.h +++ b/thirdparty/hnswlib/hnswlib/space_cosine.h @@ -15,14 +15,21 @@ CosineDistance(const void* pVect1, const void* pVect2, const void* qty_ptr) { return -1.0f * Cosine(pVect1, pVect2, qty_ptr); } +static float +CosineSQ8Distance(const void* pVect1, const void* pVect2, const void* qty_ptr) { + return -1.0f * faiss::ivec_inner_product((const int8_t*)pVect1, (const int8_t*)pVect2, *(size_t*)qty_ptr); +} + class CosineSpace : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: CosineSpace(size_t dim) { fstdistfunc_ = CosineDistance; + fstdistfunc_sq_ = CosineSQ8Distance; dim_ = dim; data_size_ = dim * sizeof(float); } @@ -37,6 +44,11 @@ class CosineSpace : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_; diff --git a/thirdparty/hnswlib/hnswlib/space_ip.h b/thirdparty/hnswlib/hnswlib/space_ip.h index 7383ff3a2..c3c8c33da 100644 --- a/thirdparty/hnswlib/hnswlib/space_ip.h +++ b/thirdparty/hnswlib/hnswlib/space_ip.h @@ -24,6 +24,11 @@ InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr return -1.0f * InnerProduct(pVect1, pVect2, qty_ptr); } +static float +InnerProductSQ8Distance(const void* pVect1, const void* pVect2, const void* qty_ptr) { + return -1.0f * faiss::ivec_inner_product((const int8_t*)pVect1, (const int8_t*)pVect2, *(size_t*)qty_ptr); +} + #if defined(USE_AVX) // Favor using AVX if available. @@ -321,12 +326,14 @@ InnerProductDistanceSIMD4ExtResiduals(const void* pVect1v, const void* pVect2v, class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: InnerProductSpace(size_t dim) { fstdistfunc_ = InnerProductDistance; + fstdistfunc_sq_ = InnerProductSQ8Distance; #if 0 /* use FAISS distance calculation algorithm instead */ #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) #if defined(USE_AVX512) @@ -374,6 +381,11 @@ class InnerProductSpace : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_; diff --git a/thirdparty/hnswlib/hnswlib/space_l2.h b/thirdparty/hnswlib/hnswlib/space_l2.h index aa255ecab..dfdc15504 100644 --- a/thirdparty/hnswlib/hnswlib/space_l2.h +++ b/thirdparty/hnswlib/hnswlib/space_l2.h @@ -25,6 +25,11 @@ L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { #endif } +static float +L2SqrSQ8(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + return faiss::ivec_L2sqr((const int8_t*)pVect1v, (const int8_t*)pVect2v, *(size_t*)qty_ptr); +} + #if defined(USE_AVX512) // Favor using AVX512 if available. @@ -210,12 +215,14 @@ L2SqrSIMD4ExtResiduals(const void* pVect1v, const void* pVect2v, const void* qty class L2Space : public SpaceInterface { DISTFUNC fstdistfunc_; + DISTFUNC fstdistfunc_sq_; size_t data_size_; size_t dim_; public: L2Space(size_t dim) { fstdistfunc_ = L2Sqr; + fstdistfunc_sq_ = L2SqrSQ8; #if 0 /* use FAISS distance calculation algorithm instead */ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) #if defined(USE_AVX512) @@ -252,6 +259,11 @@ class L2Space : public SpaceInterface { return fstdistfunc_; } + DISTFUNC + get_dist_func_sq() { + return fstdistfunc_sq_; + } + void* get_dist_func_param() { return &dim_;