Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write old-format IVF_FLAT index for backward compatibility #111

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
}

// IVFFLAT_NM should load raw data
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_);
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);
if (index_type_ == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT && binary_set.GetByName("RAW_DATA") == nullptr) {
knowhere::BinaryPtr bin = std::make_shared<knowhere::Binary>();
bin->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_);
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);
}

index.Deserialize(binary_set, conf);
}
Expand Down
16 changes: 16 additions & 0 deletions include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@
#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#include "knowhere/object.h"
#include "knowhere/version.h"

namespace knowhere {

class IndexNode : public Object {
public:
IndexNode(const int32_t ver) : versoin_(ver) {
}

IndexNode() : versoin_(Version::GetDefaultVersion()) {
}

IndexNode(const IndexNode& other) : versoin_(other.versoin_) {
}

IndexNode(const IndexNode&& other) : versoin_(other.versoin_) {
}

virtual Status
Build(const DataSet& dataset, const Config& cfg) {
RETURN_IF_ERROR(Train(dataset, cfg));
Expand Down Expand Up @@ -92,6 +105,9 @@ class IndexNode : public Object {

/// Virtual so that derived index nodes are destroyed correctly through a base pointer.
virtual ~IndexNode() = default;

protected:
Version versoin_;
};

} // namespace knowhere
Expand Down
3 changes: 1 addition & 2 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ round_down(const T value, const T align) {
}

extern void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data,
const size_t raw_size);
ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data, const size_t raw_size);

bool
UseDiskLoad(const std::string& index_type, const int32_t& /*version*/);
Expand Down
6 changes: 3 additions & 3 deletions include/knowhere/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ namespace knowhere {
namespace {
static constexpr int32_t default_version = 0;
static constexpr int32_t minimal_version = 0;
static constexpr int32_t current_version = 0;
static constexpr int32_t current_version = 1;
} // namespace

class Version {
public:
explicit Version(const IndexVersion& version) : version_(version) {
explicit Version(const IndexVersion version) : version_(version) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the reference?

}

// used when version is not set
Expand Down Expand Up @@ -59,7 +59,7 @@ class Version {

// the version number
IndexVersion
VersionNumber() {
VersionNumber() const {
return version_;
}

Expand Down
42 changes: 13 additions & 29 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim) {
}

void
ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data,
const size_t raw_size) {
ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data, const size_t raw_size) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT};
auto binary = binset.GetByNames(names);
Expand All @@ -85,38 +84,23 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const MetricType metric_type, co
MemoryIOReader reader(binary->data.get(), binary->size);

try {
uint32_t h;
reader.read(&h, sizeof(h), 1);

// only read IVF_FLAT index header
std::unique_ptr<faiss::IndexIVFFlat> ivfl = std::make_unique<faiss::IndexIVFFlat>(faiss::IndexIVFFlat());
faiss::read_ivf_header(ivfl.get(), &reader);
ivfl->code_size = ivfl->d * sizeof(float);
std::unique_ptr<faiss::IndexIVFFlat> ivfl;
ivfl.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader)));

// is_cosine is not defined in IVF_FLAT_NM, so mark it from config
ivfl->is_cosine = IsMetricType(metric_type, knowhere::metric::COSINE);

auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) -
sizeof(ivfl->invlists->code_size);
auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t);
auto ids_size = ivfl->ntotal * sizeof(faiss::Index::idx_t);
// auto codes_size = ivfl->d * ivfl->ntotal * sizeof(float);

// IVF_FLAT_NM format, need convert to new format
if (remains == invlist_size + ids_size) {
faiss::read_InvertedLists_nm(ivfl.get(), &reader);
ivfl->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(ivfl.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT, rows " << ivfl->ntotal << ", dim "
<< ivfl->d;
}
ivfl->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(ivfl.get(), &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT, rows " << ivfl->ntotal << ", dim " << ivfl->d;
} catch (...) {
// not IVF_FLAT_NM format, do nothing
return;
Expand Down
8 changes: 6 additions & 2 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace knowhere {
template <typename T>
class FlatIndexNode : public IndexNode {
public:
FlatIndexNode(const int32_t& version, const Object& object) : index_(nullptr) {
FlatIndexNode(const int32_t version, const Object& object) : index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexFlat>::value || std::is_same<T, faiss::IndexBinaryFlat>::value,
"not support");
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();
Expand Down Expand Up @@ -237,7 +237,11 @@ class FlatIndexNode : public IndexNode {
bool
HasRawData(const std::string& metric_type) const override {
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
return true;
if (versoin_ <= Version::GetMinimalVersion()) {
return !IsMetricType(metric_type, metric::COSINE);
} else {
return true;
}
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
return true;
Expand Down
52 changes: 48 additions & 4 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct QuantizerT<faiss::IndexBinaryIVF> {
template <typename T>
class IvfIndexNode : public IndexNode {
public:
IvfIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) {
IvfIndexNode(const int32_t version, const Object& object) : IndexNode(version), index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value ||
std::is_same<T, faiss::IndexIVFPQ>::value ||
std::is_same<T, faiss::IndexIVFScalarQuantizer>::value ||
Expand Down Expand Up @@ -675,6 +675,50 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
}
}

// Serializes the IVF_FLAT index into `binset`.
//
// When the requested version is at or below Version::GetMinimalVersion(), the
// index is written with faiss::write_index_nm and a second binary named
// "RAW_DATA" is appended, holding the vectors copied out of the inverted lists
// and placed at the offset given by each vector's id. Otherwise the index is
// written with faiss::write_index and nothing else is appended.
//
// Returns Status::success on success. Returns Status::faiss_inner_error when an
// exception is caught, or when an id stored in the inverted lists lies outside
// [0, ntotal); in that case nothing is appended to `binset`.
template <>
Status
IvfIndexNode<faiss::IndexIVFFlat>::Serialize(BinarySet& binset) const {
    try {
        // Evaluate the version test once; it selects both the index format and
        // whether the raw data has to be shipped alongside it.
        const bool legacy_format = versoin_ <= Version::GetMinimalVersion();

        MemoryIOWriter writer;
        LOG_KNOWHERE_INFO_ << "request version " << versoin_.VersionNumber();
        if (legacy_format) {
            faiss::write_index_nm(index_.get(), &writer);
            LOG_KNOWHERE_INFO_ << "write IVF_FLAT_NM, file size " << writer.tellg();
        } else {
            faiss::write_index(index_.get(), &writer);
            LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg();
        }
        std::shared_ptr<uint8_t[]> index_data_ptr(writer.data());

        if (!legacy_format) {
            binset.Append(Type(), index_data_ptr, writer.tellg());
            return Status::success;
        }

        // Rebuild the raw data for backward compatibility. The buffer is
        // zero-initialized so that an id gap never leaks uninitialized bytes.
        const size_t dim = index_->d;
        const size_t rows = index_->ntotal;
        const size_t vec_bytes = dim * sizeof(float);
        const size_t raw_data_size = rows * vec_bytes;
        std::shared_ptr<uint8_t[]> raw_data_ptr(new uint8_t[raw_data_size]());
        uint8_t* raw_data = raw_data_ptr.get();

        for (size_t i = 0; i < index_->nlist; i++) {
            const size_t list_size = index_->invlists->list_size(i);
            const faiss::idx_t* ids = index_->invlists->get_ids(i);
            const uint8_t* codes = index_->invlists->get_codes(i);
            for (size_t j = 0; j < list_size; j++) {
                const faiss::idx_t id = ids[j];
                // The id is used as a row offset into raw_data; reject anything
                // that would write outside the buffer.
                if (id < 0 || static_cast<size_t>(id) >= rows) {
                    LOG_KNOWHERE_WARNING_ << "invalid id " << id << " in IVF_FLAT inverted list " << i << ", rows "
                                          << rows;
                    return Status::faiss_inner_error;
                }
                memcpy(raw_data + static_cast<size_t>(id) * vec_bytes, codes + j * vec_bytes, vec_bytes);
            }
        }

        // Append only after the raw data was built, so a failure above does
        // not leave a half-populated binary set behind.
        binset.Append(Type(), index_data_ptr, writer.tellg());
        binset.Append("RAW_DATA", raw_data_ptr, raw_data_size);
        LOG_KNOWHERE_INFO_ << "append raw data for IVF_FLAT_NM, size " << raw_data_size;
        return Status::success;
    } catch (const std::exception& e) {
        LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
        return Status::faiss_inner_error;
    }
}

template <typename T>
Status
IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
Expand All @@ -690,10 +734,10 @@ IvfIndexNode<T>::Deserialize(const BinarySet& binset, const Config& config) {
MemoryIOReader reader(binary->data.get(), binary->size);
try {
if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto raw_binary = binset.GetByName("RAW_DATA");
if (raw_binary != nullptr) {
if (versoin_ <= Version::GetMinimalVersion()) {
auto raw_binary = binset.GetByName("RAW_DATA");
const BaseConfig& base_cfg = static_cast<const BaseConfig&>(config);
ConvertIVFFlatIfNeeded(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size);
ConvertIVFFlat(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size);
// after conversion, binary size and data will be updated
reader.data_ = binary->data.get();
reader.total_ = binary->size;
Expand Down
4 changes: 2 additions & 2 deletions thirdparty/faiss/faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,10 @@ static void read_direct_map(DirectMap* dm, IOReader* f) {
}
}

void read_ivf_header(
static void read_ivf_header(
IndexIVF* ivf,
IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids) {
std::vector<std::vector<Index::idx_t>>* ids = nullptr) {
read_index_header(ivf, f);
READ1(ivf->nlist);
READ1(ivf->nprobe);
Expand Down
67 changes: 67 additions & 0 deletions thirdparty/faiss/faiss/impl/index_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,60 @@ void write_InvertedLists(const InvertedLists* ils, IOWriter* f) {
}
}

// write inverted lists for offset-only index
void write_InvertedLists_nm(const InvertedLists *ils, IOWriter *f) {
if (ils == nullptr) {
uint32_t h = fourcc("il00");
WRITE1(h);
} else if (const auto & ails =
dynamic_cast<const ArrayInvertedLists *>(ils)) {
uint32_t h = fourcc("ilar");
WRITE1(h);
WRITE1(ails->nlist);
WRITE1(ails->code_size);
// here we store either as a full or a sparse data buffer
size_t n_non0 = 0;
for (size_t i = 0; i < ails->nlist; i++) {
if (ails->ids[i].size() > 0)
n_non0++;
}
if (n_non0 > ails->nlist / 2) {
uint32_t list_type = fourcc("full");
WRITE1(list_type);
std::vector<size_t> sizes;
for (size_t i = 0; i < ails->nlist; i++) {
sizes.push_back (ails->ids[i].size());
}
WRITEVECTOR(sizes);
} else {
int list_type = fourcc("sprs"); // sparse
WRITE1(list_type);
std::vector<size_t> sizes;
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
sizes.push_back (i);
sizes.push_back (n);
}
}
WRITEVECTOR(sizes);
}
// make a single contiguous data buffer (useful for mmapping)
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
// WRITEANDCHECK (ails->codes[i].data(), n * ails->code_size);
WRITEANDCHECK(ails->ids[i].data(), n);
}
}
} else {
fprintf(stderr, "WARN! write_InvertedLists: unsupported invlist type, "
"saving null invlist\n");
uint32_t h = fourcc("il00");
WRITE1(h);
}
}

void write_ProductQuantizer(const ProductQuantizer* pq, const char* fname) {
FileIOWriter writer(fname);
write_ProductQuantizer(pq, &writer);
Expand Down Expand Up @@ -729,6 +783,19 @@ void write_index(const Index* idx, const char* fname) {
write_index(idx, &writer);
}

// write index for offset-only index
void write_index_nm(const Index *idx, IOWriter *f) {
if(const IndexIVFFlat * ivfl =
dynamic_cast<const IndexIVFFlat *> (idx)) {
uint32_t h = fourcc("IwFl");
WRITE1(h);
write_ivf_header(ivfl, f);
write_InvertedLists_nm(ivfl->invlists, f);
} else {
FAISS_THROW_MSG("don't know how to serialize this type of index");
}
}

void write_VectorTransform(const VectorTransform* vt, const char* fname) {
FileIOWriter writer(fname);
write_VectorTransform(vt, &writer);
Expand Down
5 changes: 2 additions & 3 deletions thirdparty/faiss/faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ void write_InvertedLists(const InvertedLists* ils, IOWriter* f);
InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0);

// for backward compatibility
void read_ivf_header(IndexIVF* ivf, IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids = nullptr);
void read_InvertedLists_nm(IndexIVF *ivf, IOReader *f, int io_flags = 0);
Index *read_index_nm(IOReader *f, int io_flags = 0);
void write_index_nm(const Index* idx, IOWriter* writer);
} // namespace faiss

#endif
Loading