Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sq for hnsw #311

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ";
constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8";
constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE";
constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
Expand Down
9 changes: 9 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ hash_vec(const float* x, size_t d) {
return h;
}

// Hash `d` bytes starting at `x`: begin from `seed` and, for every byte in
// order, multiply the running value by 13331 and add the byte.
inline uint64_t
hash_u8_vec(const uint8_t* x, size_t d) {
    uint64_t acc = seed;
    for (size_t pos = 0; pos != d; ++pos) {
        acc = acc * 13331 + x[pos];
    }
    return acc;
}

inline uint64_t
hash_binary_vec(const uint8_t* x, size_t d) {
size_t len = (d + 7) / 8;
Expand Down
43 changes: 31 additions & 12 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
#include "knowhere/utils.h"

namespace knowhere {
template <typename DataType>

using hnswlib::QuantType;

template <typename DataType, QuantType quant_type = QuantType::None>
class HnswIndexNode : public IndexNode {
static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
"HnswIndexNode only support float/bianry");
Expand Down Expand Up @@ -60,8 +63,8 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}
auto index = new (std::nothrow)
hnswlib::HierarchicalNSW<DistType>(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value());
auto index = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space, rows, hnsw_cfg.M.value(),
hnsw_cfg.efConstruction.value());
if (index == nullptr) {
LOG_KNOWHERE_WARNING_ << "memory malloc error.";
return Status::malloc_error;
Expand All @@ -71,6 +74,9 @@ class HnswIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index";
}
this->index_ = index;
if constexpr (quant_type != QuantType::None) {
this->index_->trainSQuant((const float*)dataset.GetTensor(), rows);
}
return Status::success;
}

Expand Down Expand Up @@ -219,7 +225,7 @@ class HnswIndexNode : public IndexNode {
private:
class iterator : public IndexNode::iterator {
public:
iterator(const hnswlib::HierarchicalNSW<DistType>* index, const char* query, const bool transform,
iterator(const hnswlib::HierarchicalNSW<DistType, quant_type>* index, const char* query, const bool transform,
const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf)
: index_(index),
transform_(transform),
Expand Down Expand Up @@ -252,7 +258,7 @@ class HnswIndexNode : public IndexNode {
has_next_ = false;
}
}
const hnswlib::HierarchicalNSW<DistType>* index_;
const hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
const bool transform_;
std::unique_ptr<hnswlib::IteratorWorkspace> workspace_;
bool has_next_;
Expand Down Expand Up @@ -396,7 +402,7 @@ class HnswIndexNode : public IndexNode {

bool
HasRawData(const std::string& metric_type) const override {
return true;
return quant_type == QuantType::None || quant_type == QuantType::SQ8Refine;
}

expected<DataSetPtr>
Expand Down Expand Up @@ -460,8 +466,8 @@ class HnswIndexNode : public IndexNode {

MemoryIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType>(space);
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
index_->loadIndex(reader);
LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_
Expand All @@ -480,8 +486,8 @@ class HnswIndexNode : public IndexNode {
delete index_;
}
try {
hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType>(space);
hnswlib::SpaceInterface<float>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DistType, quant_type>(space);
index_->loadIndex(filename, config);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
Expand Down Expand Up @@ -521,7 +527,14 @@ class HnswIndexNode : public IndexNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_HNSW;
if constexpr (quant_type == QuantType::SQ8) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8;
} else if constexpr (quant_type == QuantType::SQ8Refine) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE;

} else {
return knowhere::IndexEnum::INDEX_HNSW;
}
}

~HnswIndexNode() override {
Expand Down Expand Up @@ -568,12 +581,18 @@ class HnswIndexNode : public IndexNode {
}

private:
hnswlib::HierarchicalNSW<DistType>* index_;
hnswlib::HierarchicalNSW<DistType, quant_type>* index_;
std::shared_ptr<ThreadPool> search_pool_;
};

KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp32, QuantType::SQ8);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp32, QuantType::SQ8Refine);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, fp16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, fp16, QuantType::SQ8Refine);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8, HnswIndexNode, bf16, QuantType::SQ8);
KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, bf16, QuantType::SQ8Refine);
} // namespace knowhere
23 changes: 23 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,5 +215,28 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
int32_t
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please confirm that it is int32_t, not uint32_t

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be no big difference between int32_t and uint32_t

ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) {
size_t i;
int32_t res = 0;
for (i = 0; i < d; i++) {
res += (int32_t)x[i] * y[i];
}
return res;
}

// Squared L2 distance between two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Each element is widened to int32_t before subtracting, and the sum is
// accumulated in int32_t.
int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t delta = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        sum += delta * delta;
    }
    return sum;
}

} // namespace faiss
#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX_H */
23 changes: 23 additions & 0 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,29 @@ fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, cons
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// Inner product of two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Operands are widened to int32_t before multiplying, and the sum is
// accumulated in int32_t.
int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) {
    int32_t dot = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        dot += lhs * rhs;
    }
    return dot;
}

// Squared L2 distance between two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Each element is widened to int32_t before subtracting, and the sum is
// accumulated in int32_t.
int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t delta = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        sum += delta * delta;
    }
    return sum;
}

} // namespace faiss

#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_AVX512_H */
23 changes: 23 additions & 0 deletions src/simd/distances_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -635,5 +635,28 @@ fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, fl
return ans_;
}

// Inner product of two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Operands are widened to int32_t before multiplying, and the sum is
// accumulated in int32_t.
int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) {
    int32_t dot = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        dot += lhs * rhs;
    }
    return dot;
}

// Squared L2 distance between two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Each element is widened to int32_t before subtracting, and the sum is
// accumulated in int32_t.
int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t delta = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        sum += delta * delta;
    }
    return sum;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_NEON_H
#define DISTANCES_NEON_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -50,6 +51,12 @@ fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c);

int32_t
ivec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_NEON_H */
21 changes: 21 additions & 0 deletions src/simd/distances_ref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,25 @@ fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const f
dis3 = d3;
}

// Inner product of two int8 vectors of length d, computed with a plain
// scalar loop. Operands are widened to int32_t before multiplying, and the
// sum is accumulated in int32_t.
int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) {
    int32_t dot = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        dot += lhs * rhs;
    }
    return dot;
}

// Squared L2 distance between two int8 vectors of length d, computed with a
// plain scalar loop. Each element is widened to int32_t before subtracting,
// and the sum is accumulated in int32_t.
int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t delta = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        sum += delta * delta;
    }
    return sum;
}

} // namespace faiss
7 changes: 7 additions & 0 deletions src/simd/distances_ref.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef DISTANCES_REF_H
#define DISTANCES_REF_H

#include <cstdint>
#include <cstdio>

namespace faiss {
Expand Down Expand Up @@ -71,6 +72,12 @@ void
fvec_L2sqr_batch_4_ref(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_REF_H */
23 changes: 23 additions & 0 deletions src/simd/distances_sse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,28 @@ fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, flo
return _mm_cvtsi128_si32(imin4);
}

// Inner product of two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Operands are widened to int32_t before multiplying, and the sum is
// accumulated in int32_t.
int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) {
    int32_t dot = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t lhs = x[k];
        const int32_t rhs = y[k];
        dot += lhs * rhs;
    }
    return dot;
}

// Squared L2 distance between two int8 vectors of length d.
// Plain scalar loop with no intrinsics: the compiler is trusted to unroll it.
// Each element is widened to int32_t before subtracting, and the sum is
// accumulated in int32_t.
int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) {
    int32_t sum = 0;
    for (size_t k = 0; k != d; ++k) {
        const int32_t delta = static_cast<int32_t>(x[k]) - static_cast<int32_t>(y[k]);
        sum += delta * delta;
    }
    return sum;
}

} // namespace faiss
#endif
7 changes: 7 additions & 0 deletions src/simd/distances_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef DISTANCES_SSE_H
#define DISTANCES_SSE_H

#include <cstdint>
#include <cstdio>
namespace faiss {

Expand Down Expand Up @@ -46,6 +47,12 @@ fvec_madd_sse(size_t n, const float* a, float bf, const float* b, float* c);
int
fvec_madd_and_argmin_sse(size_t n, const float* a, float bf, const float* b, float* c);

int32_t
ivec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d);

} // namespace faiss

#endif /* DISTANCES_SSE_H */
Loading
Loading