Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index Initialization Alloc Method #1933

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add functionality to iteratively insert vectors into a faiss index to improve the memory footprint during indexing. [#1840](https://github.com/opensearch-project/k-NN/pull/1840)
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Fixed and abstracted functionality for allocating index memory [#1933](https://github.com/opensearch-project/k-NN/pull/1933)
### Infrastructure
### Documentation
### Maintenance
Expand Down
5 changes: 4 additions & 1 deletion jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class IndexService {
virtual void writeIndex(std::string indexPath, jlong idMapAddress);
virtual ~IndexService() = default;
protected:
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors);
std::unique_ptr<FaissMethods> faissMethods;
};

Expand Down Expand Up @@ -120,10 +121,12 @@ class BinaryIndexService : public IndexService {
*/
virtual void writeIndex(std::string indexPath, jlong idMapAddress) override;
virtual ~BinaryIndexService() = default;
protected:
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
74 changes: 26 additions & 48 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,

IndexService::IndexService(std::unique_ptr<FaissMethods> faissMethods) : faissMethods(std::move(faissMethods)) {}

void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
if(auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) {
if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) {
indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors);
}
return;
}
if(auto * indexHNSW = dynamic_cast<faiss::IndexHNSW *>(index)) {
if(auto * indexFlat = dynamic_cast<faiss::IndexFlat *>(indexHNSW->storage)) {
indexFlat->codes.reserve(indexFlat->code_size * numVectors);
}
return;
}
}

jlong IndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
Expand All @@ -83,36 +98,9 @@ jlong IndexService::initIndex(
throw std::runtime_error("Index is not trained");
}

// Add vectors
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(indexWriter.get()));

/*
* NOTE: The process of memory allocation is currently only implemented for HNSW.
* This technique of checking the types of the index and subindices should be generalized into
* another function.
*/

// Check to see if the current index is HNSW
faiss::IndexHNSWFlat * hnsw = dynamic_cast<faiss::IndexHNSWFlat *>(idMap->index);
if(hnsw != NULL) {
// Check to see if the HNSW storage is IndexFlat
faiss::IndexFlat * storage = dynamic_cast<faiss::IndexFlat *>(hnsw->storage);
if(storage != NULL) {
// Allocate enough memory for all of the vectors we plan on inserting
// We do this to avoid unnecessary memory allocations during insert
storage->codes.reserve(dim * numVectors * 4);
}
}
faiss::IndexHNSWSQ * hnswSq = dynamic_cast<faiss::IndexHNSWSQ *>(idMap->index);
if(hnswSq != NULL) {
// Check to see if the HNSW storage is IndexFlat
faiss::IndexFlat * storage = dynamic_cast<faiss::IndexFlat *>(hnswSq->storage);
if(storage != NULL) {
// Allocate enough memory for all of the vectors we plan on inserting
// We do this to avoid unnecessary memory allocations during insert
storage->codes.reserve(dim * numVectors * 2);
}
}
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
indexWriter.release();
return reinterpret_cast<jlong>(idMap.release());
}
Expand Down Expand Up @@ -168,6 +156,14 @@ void IndexService::writeIndex(

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
if(auto * indexBinaryHNSW = dynamic_cast<faiss::IndexBinaryHNSW *>(index)) {
auto * indexBinaryFlat = dynamic_cast<faiss::IndexBinaryFlat *>(indexBinaryHNSW->storage);
indexBinaryFlat->xb.reserve(dim * numVectors / 8);
return;
}
}

jlong BinaryIndexService::initIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
Expand All @@ -194,27 +190,9 @@ jlong BinaryIndexService::initIndex(
throw std::runtime_error("Index is not trained");
}

// Add vectors
std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));

/*
* NOTE: The process of memory allocation is currently only implemented for HNSW.
* This technique of checking the types of the index and subindices should be generalized into
* another function.
*/

// Check to see if the current index is BinaryHNSW
faiss::IndexBinaryHNSW * hnsw = dynamic_cast<faiss::IndexBinaryHNSW *>(idMap->index);

if(hnsw != NULL) {
// Check to see if the HNSW storage is IndexBinaryFlat
faiss::IndexBinaryFlat * storage = dynamic_cast<faiss::IndexBinaryFlat *>(hnsw->storage);
if(storage != NULL) {
// Allocate enough memory for all of the vectors we plan on inserting
// We do this to avoid unnecessary memory allocations during insert
storage->xb.reserve(dim / 8 * numVectors);
}
}
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
indexWriter.release();
return reinterpret_cast<jlong>(idMap.release());
}
Expand Down Expand Up @@ -271,4 +249,4 @@ void BinaryIndexService::writeIndex(
}

} // namespace faiss_wrapper
} // namesapce knn_jni
} // namesapce knn_jni
Loading