Skip to content

Commit

Permalink
Use data type check when create index from template and fix jni errors
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 9, 2024
1 parent 2609a68 commit 629757b
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 34 deletions.
8 changes: 0 additions & 8 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,6 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryBIndexWithFilter
* Signature: (J[BI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
12 changes: 5 additions & 7 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
}
// end parameters to pass

std::cout << "Index description in CreateIndex: " << indexDescriptionCpp << std::endl;

// Create index
indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp);
}
Expand Down Expand Up @@ -263,7 +261,10 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
Expand All @@ -285,7 +286,7 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
Expand Down Expand Up @@ -647,7 +648,6 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl;
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

Expand Down Expand Up @@ -714,7 +714,6 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::cout << "Index description in TrainIndex: " << indexDescriptionCpp << std::endl;
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));

Expand Down Expand Up @@ -809,7 +808,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
std::cout << "Index is IVFBinary" << std::endl;
indexIvf->make_direct_map();
}
if (!index->is_trained) {
Expand Down
14 changes: 1 addition & 13 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin

}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ)
{
try {
return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,22 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
}

private VectorTransfer getVectorTransfer(FieldInfo field) {
if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) {
boolean isBinary = false;

// Check if the field has a model ID and retrieve the model's vector data type
if (field.attributes().containsKey(MODEL_ID)) {
Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID));
isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY;
} else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) {
isBinary = true;
}

// Return the appropriate VectorTransfer instance based on the vector data type
if (isBinary) {
return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
} else {
return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}
return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
Expand Down Expand Up @@ -194,7 +206,7 @@ private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNN
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));

// Update index description of Faiss for binary data type
if (KNNEngine.FAISS == knnEngine && SpaceType.HAMMING_BIT.equals(model.getModelMetadata().getSpaceType())) {
if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) {
parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + model.getModelMetadata().getMethodComponentContext().getName().toUpperCase()
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public ModelMetadata(StreamInput in) throws IOException {
this.methodComponentContext = MethodComponentContext.EMPTY;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_DATA_TYPE)) {
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.METHOD_PARAMETER)) {
this.vectorDataType = VectorDataType.get(in.readOptionalString());
} else {
this.vectorDataType = VectorDataType.FLOAT;
Expand Down Expand Up @@ -461,7 +461,7 @@ public void writeTo(StreamOutput out) throws IOException {
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) {
getMethodComponentContext().writeTo(out);
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_DATA_TYPE)) {
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.METHOD_PARAMETER)) {
out.writeOptionalString(vectorDataType.getValue());
}
}
Expand All @@ -484,7 +484,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
getMethodComponentContext().toXContent(builder, params);
builder.endObject();
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_DATA_TYPE)) {
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.METHOD_PARAMETER)) {
builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
}
return builder;
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ public static void createIndexFromTemplate(
if (KNNEngine.FAISS == knnEngine) {
if (faissUtil.isBinaryIndex(parameters)) {
FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters);
return;
} else {
FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters);
return;
}
}

Expand Down

0 comments on commit 629757b

Please sign in to comment.