Skip to content

Commit

Permalink
Add binary format support for Faiss IVF
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 9, 2024
1 parent e62fb76 commit 648f342
Show file tree
Hide file tree
Showing 16 changed files with 360 additions and 42 deletions.
19 changes: 19 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ namespace knn_jni {
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
Expand Down Expand Up @@ -80,6 +86,12 @@ namespace knn_jni {
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Execute a query against the binary index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer, jboolean isBinaryIndexJ);

Expand All @@ -96,6 +108,13 @@ namespace knn_jni {
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
// the vector of floats located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndexFromTemplate
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: trainBinaryIndex
* Signature: (Ljava/util/Map;IJ)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
(JNIEnv *, jclass, jobject, jint, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectors
Expand Down
133 changes: 133 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Train a binary index with data provided
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

Expand Down Expand Up @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand Down Expand Up @@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}

faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
if (spaceType == knn_jni::L2) {
return faiss::METRIC_L2;
Expand Down Expand Up @@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
28 changes: 28 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
jint dimJ,
jstring indexPathJ,
jbyteArray templateIndexJ,
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand Down Expand Up @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
return nullptr;
}

JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls,
jobject parametersJ,
jint dimensionJ,
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import com.google.common.collect.ImmutableMap;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.ChecksumIndexInput;
Expand Down Expand Up @@ -112,10 +111,22 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
}

private VectorTransfer getVectorTransfer(FieldInfo field) {
if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) {
boolean isBinary = false;

// Check if the field has a model ID and retrieve the model's vector data type
if (field.attributes().containsKey(MODEL_ID)) {
Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID));
isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY;
} else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) {
isBinary = true;
}

// Return the appropriate VectorTransfer instance based on the vector data type
if (isBinary) {
return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
} else {
return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}
return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
Expand Down Expand Up @@ -154,7 +165,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath);
} else {
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}
Expand Down Expand Up @@ -188,18 +199,25 @@ private void recordRefreshStats() {
KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
);
private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = new HashMap<>();
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));

// Update index description of Faiss for binary data type
if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) {
parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase()
);
}

AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(
pair.docs,
pair.getVectorAddress(),
pair.getDimension(),
indexPath,
model,
model.getModelBlob(),
parameters,
knnEngine
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
context,
fieldType().getDimension(),
fieldType().getSpaceType(),
getMethodComponentContext(fieldType().getKnnMethodContext())
getMethodComponentContext(fieldType().getKnnMethodContext()),
fieldType().getVectorDataType()
);
}

Expand Down Expand Up @@ -605,8 +606,13 @@ protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType
return fields;
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {
protected void parseCreateField(
ParseContext context,
int dimension,
SpaceType spaceType,
MethodComponentContext methodComponentContext,
VectorDataType vectorDataType
) throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
parseCreateField(
context,
modelMetadata.getDimension(),
modelMetadata.getSpaceType(),
modelMetadata.getMethodComponentContext(),
modelMetadata.getVectorDataType()
);
}
}
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary {
return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build()
).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build()
);

final static Faiss INSTANCE = new Faiss(
Expand Down
Loading

0 comments on commit 648f342

Please sign in to comment.