Skip to content

Commit

Permalink
Write old format IVF_FLAT index for backward compatible
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain committed Sep 22, 2023
1 parent b724230 commit 82607d5
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 9 deletions.
16 changes: 16 additions & 0 deletions include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@
#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#include "knowhere/object.h"
#include "knowhere/version.h"

namespace knowhere {

class IndexNode : public Object {
public:
IndexNode(const std::string& str) : versoin_(str) {
}

IndexNode() : versoin_(Version::GetDefaultVersion()) {
}

IndexNode(const IndexNode& other) : versoin_(other.versoin_) {
}

IndexNode(const IndexNode&& other) : versoin_(other.versoin_) {
}

virtual Status
Build(const DataSet& dataset, const Config& cfg) {
RETURN_IF_ERROR(Train(dataset, cfg));
Expand Down Expand Up @@ -92,6 +105,9 @@ class IndexNode : public Object {

virtual ~IndexNode() {
}

protected:
Version versoin_;
};

} // namespace knowhere
Expand Down
20 changes: 15 additions & 5 deletions include/knowhere/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ namespace {
static const std::regex version_regex(R"(^knowhere-v(\d+)$)");
static constexpr const char* default_version = "knowhere-v0";
static constexpr const char* minimal_vesion = "knowhere-v0";
static constexpr const char* current_version = "knowhere-v0";
static constexpr const char* current_version = "knowhere-v1";
} // namespace

class Version {
public:
explicit Version(const std::string& version_code_) : version_code(version_code_) {
explicit Version(const std::string& version_code) : version_code_(version_code) {
try {
std::smatch matches;
if (std::regex_match(version_code_, matches, version_regex)) {
Expand All @@ -39,14 +39,24 @@ class Version {
}
}

Version(const Version& other) {
version_code_ = other.version_code_;
version_ = other.version_;
}

Version(const Version&& other) {
version_code_ = other.version_code_;
version_ = other.version_;
}

bool
Valid() {
return version_ != unexpected_version_num;
};

const std::string&
VersionCode() const {
return version_code;
return version_code_;
}

static bool
Expand Down Expand Up @@ -79,7 +89,7 @@ class Version {

static inline bool
VersionSupport(const Version& version) {
return VersionCheck(version.version_code) && GetMinimalSupport() <= version && version <= GetCurrentVersion();
return VersionCheck(version.version_code_) && GetMinimalSupport() <= version && version <= GetCurrentVersion();
}

friend bool
Expand All @@ -89,7 +99,7 @@ class Version {

private:
static constexpr int32_t unexpected_version_num = -1;
const std::string version_code;
std::string version_code_;
int32_t version_ = unexpected_version_num;
};

Expand Down
8 changes: 6 additions & 2 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace knowhere {
template <typename T>
class FlatIndexNode : public IndexNode {
public:
FlatIndexNode(const std::string& version, const Object& object) : index_(nullptr) {
FlatIndexNode(const std::string& version, const Object& object) : IndexNode(version), index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexFlat>::value || std::is_same<T, faiss::IndexBinaryFlat>::value,
"not support");
search_pool_ = ThreadPool::GetGlobalSearchThreadPool();
Expand Down Expand Up @@ -237,7 +237,11 @@ class FlatIndexNode : public IndexNode {
bool
HasRawData(const std::string& metric_type) const override {
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
return true;
if (versoin_ <= Version::GetMinimalSupport()) {
return !IsMetricType(metric_type, metric::COSINE);
} else {
return true;
}
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
return true;
Expand Down
8 changes: 6 additions & 2 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct QuantizerT<faiss::IndexBinaryIVF> {
template <typename T>
class IvfIndexNode : public IndexNode {
public:
IvfIndexNode(const std::string& /*version*/, const Object& object) : index_(nullptr) {
IvfIndexNode(const std::string& version, const Object& object) : IndexNode(version), index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexIVFFlat>::value || std::is_same<T, faiss::IndexIVFFlatCC>::value ||
std::is_same<T, faiss::IndexIVFPQ>::value ||
std::is_same<T, faiss::IndexIVFScalarQuantizer>::value ||
Expand Down Expand Up @@ -664,7 +664,11 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const {
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
faiss::write_index_binary(index_.get(), &writer);
} else {
faiss::write_index(index_.get(), &writer);
if (versoin_ <= Version::GetMinimalSupport()) {
faiss::write_index_nm(index_.get(), &writer);
} else {
faiss::write_index(index_.get(), &writer);
}
}
std::shared_ptr<uint8_t[]> data(writer.data());
binset.Append(Type(), data, writer.tellg());
Expand Down
67 changes: 67 additions & 0 deletions thirdparty/faiss/faiss/impl/index_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,60 @@ void write_InvertedLists(const InvertedLists* ils, IOWriter* f) {
}
}

// write inverted lists for offset-only index
void write_InvertedLists_nm(const InvertedLists *ils, IOWriter *f) {
if (ils == nullptr) {
uint32_t h = fourcc("il00");
WRITE1(h);
} else if (const auto & ails =
dynamic_cast<const ArrayInvertedLists *>(ils)) {
uint32_t h = fourcc("ilar");
WRITE1(h);
WRITE1(ails->nlist);
WRITE1(ails->code_size);
// here we store either as a full or a sparse data buffer
size_t n_non0 = 0;
for (size_t i = 0; i < ails->nlist; i++) {
if (ails->ids[i].size() > 0)
n_non0++;
}
if (n_non0 > ails->nlist / 2) {
uint32_t list_type = fourcc("full");
WRITE1(list_type);
std::vector<size_t> sizes;
for (size_t i = 0; i < ails->nlist; i++) {
sizes.push_back (ails->ids[i].size());
}
WRITEVECTOR(sizes);
} else {
int list_type = fourcc("sprs"); // sparse
WRITE1(list_type);
std::vector<size_t> sizes;
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
sizes.push_back (i);
sizes.push_back (n);
}
}
WRITEVECTOR(sizes);
}
// make a single contiguous data buffer (useful for mmapping)
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
// WRITEANDCHECK (ails->codes[i].data(), n * ails->code_size);
WRITEANDCHECK(ails->ids[i].data(), n);
}
}
} else {
fprintf(stderr, "WARN! write_InvertedLists: unsupported invlist type, "
"saving null invlist\n");
uint32_t h = fourcc("il00");
WRITE1(h);
}
}

void write_ProductQuantizer(const ProductQuantizer* pq, const char* fname) {
FileIOWriter writer(fname);
write_ProductQuantizer(pq, &writer);
Expand Down Expand Up @@ -729,6 +783,19 @@ void write_index(const Index* idx, const char* fname) {
write_index(idx, &writer);
}

// write index for offset-only index
void write_index_nm(const Index *idx, IOWriter *f) {
if(const IndexIVFFlat * ivfl =
dynamic_cast<const IndexIVFFlat *> (idx)) {
uint32_t h = fourcc("IwFl");
WRITE1(h);
write_ivf_header(ivfl, f);
write_InvertedLists_nm(ivfl->invlists, f);
} else {
FAISS_THROW_MSG("don't know how to serialize this type of index");
}
}

void write_VectorTransform(const VectorTransform* vt, const char* fname) {
FileIOWriter writer(fname);
write_VectorTransform(vt, &writer);
Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0);
void read_ivf_header(IndexIVF* ivf, IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids = nullptr);
void read_InvertedLists_nm(IndexIVF *ivf, IOReader *f, int io_flags = 0);
void write_index_nm(const Index* idx, IOWriter* writer);
} // namespace faiss

#endif

0 comments on commit 82607d5

Please sign in to comment.