Skip to content

Commit

Permalink
Support IVF_FLAT backward compatibility for the cosine metric
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain committed Sep 20, 2023
1 parent 8280cfc commit 1dc4a8a
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 11 deletions.
8 changes: 4 additions & 4 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
}

void
read_index(knowhere::Index<knowhere::IndexNode>& index, const std::string& filename) {
read_index(knowhere::Index<knowhere::IndexNode>& index, const std::string& filename, const knowhere::Json& conf) {
FileIOReader reader(filename);
int64_t file_size = reader.size();
if (file_size < 0) {
Expand Down Expand Up @@ -79,7 +79,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);

index.Deserialize(binary_set);
index.Deserialize(binary_set, conf);
}

std::string
Expand All @@ -98,7 +98,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {

try {
printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str());
read_index(index_, index_file_name);
read_index(index_, index_file_name, conf);
} catch (...) {
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_);
Expand All @@ -120,7 +120,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {

try {
printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str());
read_index(golden_index_, golden_index_file_name);
read_index(golden_index_, golden_index_file_name, conf);
} catch (...) {
printf("[%.3f s] Building golden index on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = knowhere::GenDataSet(nb_, dim_, xb_);
Expand Down
6 changes: 5 additions & 1 deletion include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,11 @@ class BaseConfig : public Config {
CFG_BOOL enable_mmap;
CFG_BOOL for_tuning;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type).set_default("L2").description("metric type").for_train_and_search();
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
.description("metric type")
.for_train_and_search()
.for_deserialize();
KNOWHERE_CONFIG_DECLARE_FIELD(k)
.set_default(10)
.description("search for top k similar vector.")
Expand Down
5 changes: 4 additions & 1 deletion include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ round_down(const T value, const T align) {
}

// Backward-compatibility conversion for IVF_FLAT index binaries.
//
// binset      - binary set holding the serialized index; the implementation looks
//               the index up under the names "IVF" (knowhere-1.x) or
//               INDEX_FAISS_IVFFLAT.
// metric_type - metric taken from the deserialize config; the implementation uses
//               it to set is_cosine on the index, because that flag is not stored
//               in the IVF_FLAT_NM format.
// raw_data    - pointer to the raw vector bytes; the IVF_FLAT deserialize path
//               passes the contents of the "RAW_DATA" binary here.
// raw_size    - size of raw_data in bytes.
//
// The caller comments that after conversion the index binary's data and size are
// updated, so readers must be re-pointed at the binary afterwards.
extern void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size);
ConvertIVFFlatIfNeeded(const BinarySet& binset,
const MetricType metric_type,
const uint8_t* raw_data,
const size_t raw_size);

} // namespace knowhere
8 changes: 7 additions & 1 deletion src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
}

void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const size_t raw_size) {
ConvertIVFFlatIfNeeded(const BinarySet& binset,
const MetricType metric_type,
const uint8_t* raw_data,
const size_t raw_size) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
Expand All @@ -92,6 +95,9 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const s
faiss::read_ivf_header(ivfl.get(), &reader);
ivfl->code_size = ivfl->d * sizeof(float);

// is_cosine is not defined in IVF_FLAT_NM, so mark it from config
ivfl->is_cosine = IsMetricType(metric_type, knowhere::metric::COSINE);

auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) -
sizeof(ivfl->invlists->code_size);
auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t);
Expand Down
3 changes: 2 additions & 1 deletion src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,8 @@ IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto raw_binary = binset.GetByName("RAW_DATA");
if (raw_binary != nullptr) {
ConvertIVFFlatIfNeeded(binset, raw_binary->data.get(), raw_binary->size);
const BaseConfig& base_cfg = static_cast<const BaseConfig&>(config);
ConvertIVFFlatIfNeeded(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size);
// after conversion, binary size and data will be updated
reader.data_ = binary->data.get();
reader.total_ = binary->size;
Expand Down
2 changes: 1 addition & 1 deletion thirdparty/faiss/faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void IndexIVFFlat::restore_codes(
const uint8_t* raw_data,
const size_t raw_size) {
auto ails = dynamic_cast<faiss::ArrayInvertedLists*>(invlists);
ails->restore_codes(raw_data, raw_size);
ails->restore_codes(raw_data, raw_size, is_cosine);
}

void IndexIVFFlat::train(idx_t n, const float* x) {
Expand Down
18 changes: 17 additions & 1 deletion thirdparty/faiss/faiss/invlists/InvertedLists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <numeric>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/utils.h>

//TODO: refactor to decouple dependency between CPU and Cuda, or upgrade faiss
Expand Down Expand Up @@ -273,21 +274,36 @@ void ArrayInvertedLists::resize(size_t list_no, size_t new_size) {
codes[list_no].resize(new_size * code_size);
}

// Temporary code for IVF_FLAT_NM backward compatibility.
//
// Refills `codes` for every inverted list from one flat buffer of raw entries.
// For list i, entry j is copied from raw_data at byte offset
// code_size * ids[i][j], i.e. the stored id is used as the entry's index in
// raw_data. When is_cosine is true, `with_norm` is set and `code_norms` is
// filled with one value per copied entry.
//
// raw_data  - flat buffer of code_size-byte entries, indexed by id.
// raw_size  - size of raw_data in bytes; only used by the trailing assert.
// is_cosine - whether per-entry norms must be computed alongside the codes.
//
// NOTE(review): ids are used as offsets into raw_data without a bounds check,
// and the only size validation is the final assert, which is compiled out when
// NDEBUG is defined — confirm callers guarantee consistent inputs.
void ArrayInvertedLists::restore_codes(
const uint8_t* raw_data,
const size_t raw_size) {
const size_t raw_size,
const bool is_cosine) {
// Running count of entries across all lists, checked against raw_size at the end.
size_t total = 0;
with_norm = is_cosine;
codes.resize(nlist);
if (is_cosine) {
code_norms.resize(nlist);
}
for (size_t i = 0; i < nlist; i++) {
// The number of ids already present in the list decides how many entries to restore.
auto list_size = ids[i].size();
total += list_size;
codes[i].resize(list_size * code_size);
if (is_cosine) {
code_norms[i].resize(list_size);
}
// Gather the entries of this list into contiguous storage, in id-list order.
uint8_t* dst = codes[i].data();
for (size_t j = 0; j < list_size; j++) {
const uint8_t* src = raw_data + code_size * ids[i][j];
std::copy_n(src, code_size, dst);
dst += code_size;
}
if (is_cosine) {
// The copied bytes are reinterpreted as floats: code_size / sizeof(float)
// components per entry, list_size entries. Presumably this writes the L2 norm
// of each entry into code_norms[i] — confirm against faiss/utils/distances.h.
fvec_norms_L2(code_norms[i].data(),
(const float*)codes[i].data(),
code_size / sizeof(float),
list_size);
}
}
// Sanity check (debug builds only): the lists together must account for exactly raw_size bytes.
assert(total * code_size == raw_size);
}
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/invlists/InvertedLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ struct ArrayInvertedLists : InvertedLists {
const uint8_t* get_codes(size_t list_no) const override;
const idx_t* get_ids(size_t list_no) const override;

// Temporary hook for IVF_FLAT_NM backward compatibility: refills `codes` for
// every list from a flat raw buffer, using the ids already stored in each list
// as entry indices. When is_cosine is true the implementation also sets
// `with_norm` and fills `code_norms`.
void restore_codes(const uint8_t* raw_data, const size_t raw_size);
void restore_codes(const uint8_t* raw_data,
const size_t raw_size,
const bool is_cosine);

const float* get_code_norms(size_t list_no, size_t offset) const override;
void release_code_norms(size_t list_no, const float* codes) const override;
Expand Down

0 comments on commit 1dc4a8a

Please sign in to comment.