Skip to content

Commit

Permalink
Append raw_data for IVF_FLAT_NM
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain committed Sep 25, 2023
1 parent 3b99047 commit 989a164
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 52 deletions.
10 changes: 6 additions & 4 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
}

// IVFFLAT_NM should load raw data
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_);
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);
if (index_type_ == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT && binary_set.GetByName("RAW_DATA") == nullptr) {
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_);
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);
}

index.Deserialize(binary_set, conf);
}
Expand Down
3 changes: 1 addition & 2 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ round_down(const T value, const T align) {
}

extern void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data,
const size_t raw_size);
ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data, const size_t raw_size);

bool
UseDiskLoad(const std::string& index_type, const std::string& /*version*/);
Expand Down
42 changes: 13 additions & 29 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
}

void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data,
const size_t raw_size) {
ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data, const size_t raw_size) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
Expand All @@ -85,38 +84,23 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, co
MemoryIOReader reader(binary->data.get(), binary->size);

try {
uint32_t h;
reader.read(&h, sizeof(h), 1);

// only read IVF_FLAT index header
std::unique_ptr<faiss::IndexIVFFlat> ivfl = std::make_unique<faiss::IndexIVFFlat>(faiss::IndexIVFFlat());
faiss::read_ivf_header(ivfl.get(), &reader);
ivfl->code_size = ivfl->d * sizeof(float);
std::unique_ptr<faiss::IndexIVFFlat> ivfl;
ivfl.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader)));

// is_cosine is not defined in IVF_FLAT_NM, so mark it from config
ivfl->is_cosine = IsMetricType(metric_type, knowhere::metric::COSINE);

auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) -
sizeof(ivfl->invlists->code_size);
auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t);
auto ids_size = ivfl->ntotal * sizeof(faiss::Index::idx_t);
// auto codes_size = ivfl->d * ivfl->ntotal * sizeof(float);

// IVF_FLAT_NM format, need convert to new format
if (remains == invlist_size + ids_size) {
faiss::read_InvertedLists_nm(ivfl.get(), &reader);
ivfl->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(ivfl.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT, rows " << ivfl->ntotal << ", dim "
<< ivfl->d;
}
ivfl->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(ivfl.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT, rows " << ivfl->ntotal << ", dim " << ivfl->d;
} catch (...) {
// not IVF_FLAT_NM format, do nothing
return;
Expand Down
59 changes: 47 additions & 12 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -663,15 +663,6 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
MemoryIOWriter writer;
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
faiss::write_index_binary(index_.get(), &writer);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
LOG_KNOWHERE_INFO_ << "request version " << versoin_.VersionCode();
if (versoin_ <= Version::GetMinimalSupport()) {
faiss::write_index_nm(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT_NM, file size " << writer.tellg();
} else {
faiss::write_index(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg();
}
} else {
faiss::write_index(index_.get(), &writer);
}
Expand All @@ -684,6 +675,50 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
}
}

template <>
Status
IvfIndexNode<faiss::IndexIVFFlat>::Serialize(BinarySet& binset) const {
try {
MemoryIOWriter writer;
LOG_KNOWHERE_INFO_ << "request version " << versoin_.VersionCode();
if (versoin_ <= Version::GetMinimalSupport()) {
faiss::write_index_nm(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT_NM, file size " << writer.tellg();
} else {
faiss::write_index(index_.get(), &writer);
LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg();
}
std::shared_ptr<uint8_t[]> index_data_ptr(writer.data());
binset.Append(Type(), index_data_ptr, writer.tellg());

// append raw data for backward compatible
if (versoin_ <= Version::GetMinimalSupport()) {
size_t dim = index_->d;
size_t rows = index_->ntotal;
size_t raw_data_size = dim * rows * sizeof(float);
uint8_t* raw_data = new uint8_t[raw_data_size];
std::shared_ptr<uint8_t[]> raw_data_ptr(raw_data);
for (size_t i = 0; i < index_->nlist; i++) {
size_t list_size = index_->invlists->list_size(i);
const faiss::idx_t* ids = index_->invlists->get_ids(i);
const uint8_t* codes = index_->invlists->get_codes(i);
for (size_t j = 0; j < list_size; j++) {
faiss::idx_t id = ids[j];
const uint8_t* src = codes + j * dim * sizeof(float);
uint8_t* dst = raw_data + id * dim * sizeof(float);
memcpy(dst, src, dim * sizeof(float));
}
}
binset.Append("RAW_DATA", raw_data_ptr, raw_data_size);
LOG_KNOWHERE_INFO_ << "append raw data for IVF_FLAT_NM, size " << raw_data_size;
}
return Status::success;
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return Status::faiss_inner_error;
}
}

template <typename T>
Status
IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
Expand All @@ -699,10 +734,10 @@ IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
MemoryIOReader reader(binary->data.get(), binary->size);
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto raw_binary = binset.GetByName("RAW_DATA");
if (raw_binary != nullptr) {
if (versoin_ <= Version::GetMinimalSupport()) {
auto raw_binary = binset.GetByName("RAW_DATA");
const BaseConfig& base_cfg = static_cast<const BaseConfig&>(config);
ConvertIVFFlatIfNeeded(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size);
ConvertIVFFlat(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size);
// after conversion, binary size and data will be updated
reader.data_ = binary->data.get();
reader.total_ = binary->size;
Expand Down
4 changes: 2 additions & 2 deletions thirdparty/faiss/faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,10 @@ static void read_direct_map(DirectMap* dm, IOReader* f) {
}
}

void read_ivf_header(
static void read_ivf_header(
IndexIVF* ivf,
IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids) {
std::vector<std::vector<Index::idx_t>>* ids = nullptr) {
read_index_header(ivf, f);
READ1(ivf->nlist);
READ1(ivf->nprobe);
Expand Down
4 changes: 1 addition & 3 deletions thirdparty/faiss/faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ void write_InvertedLists(const InvertedLists* ils, IOWriter* f);
InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0);

// for backward compatibility
void read_ivf_header(IndexIVF* ivf, IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids = nullptr);
void read_InvertedLists_nm(IndexIVF *ivf, IOReader *f, int io_flags = 0);
Index *read_index_nm(IOReader *f, int io_flags = 0);
void write_index_nm(const Index* idx, IOWriter* writer);
} // namespace faiss

Expand Down

0 comments on commit 989a164

Please sign in to comment.