Skip to content

Commit

Permalink
Add sq for hnsw (zilliztech#311)
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 authored Mar 8, 2024
1 parent 51051c3 commit e44ba8d
Show file tree
Hide file tree
Showing 21 changed files with 452 additions and 56 deletions.
2 changes: 2 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ";
constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8";
constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE";
constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
Expand Down
9 changes: 9 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ hash_vec(const float* x, size_t d) {
return h;
}

inline uint64_t
hash_u8_vec(const uint8_t* x, size_t d) {
uint64_t h = seed;
for (size_t i = 0; i < d; ++i) {
h = h * 13331 + *(x + i);
}
return h;
}

inline uint64_t
hash_binary_vec(const uint8_t* x, size_t d) {
size_t len = (d + 7) / 8;
Expand Down
43 changes: 31 additions & 12 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
#include "knowhere/utils.h"

namespace knowhere {
template <typename DataType>

using hnswlib::QuantType;

template <typename DataType, QuantType quant_type = QuantType::None>
class HnswIndexNode : public IndexNode {
static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
"HnswIndexNode only support float/bianry");
Expand Down Expand Up @@ -60,8 +63,8 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}
auto index = new (std::nothrow)
hnswlib::HierarchicalNSW<DistType>(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
hnsw_cfg.efConstruction.value());
if (index == nullptr) {
LOG_KNOWHERE_WARNING_ << "memory malloc error.";
return Status::malloc_error;
Expand All @@ -71,6 +74,9 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index";
}
this->index_ = index;
if constexpr (quant_type != QuantType::None) {
this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
}
return Status::success;
}

Expand Down Expand Up @@ -219,7 +225,7 @@ class HnswIndexNode : public IndexNode {
private:
class iterator : public IndexNode::iterator {
public:
iterator(const hnswlib::HierarchicalNSW<DistType>* index, const char* query, const bool transform,
iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf)
: index_(index),
transform_(transform),
Expand Down Expand Up @@ -252,7 +258,7 @@ class HnswIndexNode : public IndexNode {
has_next_ = false;
}
}
const hnswlib::HierarchicalNSW<DistType>* index_;
const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
const bool transform_;
std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
bool has_next_;
Expand Down Expand Up @@ -396,7 +402,7 @@ class HnswIndexNode : public IndexNode {

bool
HasRawData(const std::string& metric_type) const override {
return true;
return quant_type == QuantType::None || quant_type == QuantType::SQ8Refine;
}

expected<DataSetPtr>
Expand Down Expand Up @@ -460,8 +466,8 @@ class HnswIndexNode : public IndexNode {

MemoryIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType>(space);
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
index_->loadIndex(reader);
LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_
Expand All @@ -480,8 +486,8 @@ class HnswIndexNode : public IndexNode {
delete index_;
}
try {
hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType>(space);
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
index_->loadIndex(filename, config);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
Expand Down Expand Up @@ -521,7 +527,14 @@ class HnswIndexNode : public IndexNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_HNSW;
if constexpr (quant_type == QuantType::SQ8) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8;
} else if constexpr (quant_type == QuantType::SQ8Refine) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE;

} else {
return knowhere::IndexEnum::INDEX_HNSW;
}
}

~HnswIndexNode() override {
Expand Down Expand Up @@ -568,12 +581,18 @@ class HnswIndexNode : public IndexNode {
}

private:
hnswlib::HierarchicalNSW<DistType>* index_;
hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
std::shared_ptr<ThreadPool> search_pool_;
};

KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
} // namespace knowhere
23 changes: 23 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,28 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

// trust the compiler to unroll this properly
int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];
res += tmp * tmp;
}
return res;
}

} // namespace faiss
#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX_H */
23 changes: 23 additions & 0 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,29 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

// trust the compiler to unroll this properly
int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];
res += tmp * tmp;
}
return res;
}

} // namespace faiss

#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX512_H */
23 changes: 23 additions & 0 deletions src/simd/distances_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,5 +635,28 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl
return ans_;
}

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

// trust the compiler to unroll this properly
int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];
res += tmp * tmp;
}
return res;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_NEON_H
#define DISTANCES_NEON_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -50,6 +51,12 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c);

int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_NEON_H */
21 changes: 21 additions & 0 deletions src/simd/distances_ref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,25 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f
dis3 = d3;
}

int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];
res += tmp * tmp;
}
return res;
}

} // namespace faiss
7 changes: 7 additions & 0 deletions src/simd/distances_ref.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef DISTANCES_REF_H
#define DISTANCES_REF_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -71,6 +72,12 @@ void
fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_REF_H */
23 changes: 23 additions & 0 deletions src/simd/distances_sse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,28 @@ fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, flo
return _mm_cvtsi128_si32(imin4);
}

// trust the compiler to unroll this properly
int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

// trust the compiler to unroll this properly
int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];
res += tmp * tmp;
}
return res;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_SSE_H
#define DISTANCES_SSE_H

#include <cstdint>
#include <cstdio>
namespace faiss {

Expand Down Expand Up @@ -46,6 +47,12 @@ fvec_madd_sse(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, float* c);

int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_SSE_H */
Loading

0 comments on commit e44ba8d

Please sign in to comment.