diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index a1d319fe47..9a357dd452 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -92,7 +92,7 @@ class IndexService { int numIds, int threadCount, int64_t vectorsAddress, - std::vector & ids, + std::vector ids, std::string indexPath, std::unordered_map parameters); virtual ~IndexService() = default; @@ -167,7 +167,7 @@ class BinaryIndexService : public IndexService { int numIds, int threadCount, int64_t vectorsAddress, - std::vector & ids, + std::vector ids, std::string indexPath, std::unordered_map parameters ) override; diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8148c599cf..4d410f6458 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -171,7 +171,7 @@ void IndexService::createIndex( int numIds, int threadCount, int64_t vectorsAddress, - std::vector & ids, + std::vector ids, std::string indexPath, std::unordered_map parameters ) { @@ -327,7 +327,7 @@ void BinaryIndexService::createIndex( int numIds, int threadCount, int64_t vectorsAddress, - std::vector & ids, + std::vector ids, std::string indexPath, std::unordered_map parameters ) { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index b9c9e733b9..77e00ea3ee 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -114,12 +114,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, return; } BinaryDocValues values = valuesProducer.getBinary(field); - long num_docs = KNNCodecUtil.getTotalLiveDocsCount(values); - long totalArraySize = 0; - long totalDocsIncrement = 0; - if (isMerge) { - 
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - } + long numDocs = KNNCodecUtil.getTotalLiveDocsCount(values); // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); final KNNEngine knnEngine = getKNNEngine(field); @@ -137,44 +132,39 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will // not be marked as added to the directory. state.directory.createOutput(engineFileName, state.context).close(); + KNNCodecUtil.VectorBatch pair; + // Get first pair + if(field.attributes().containsKey(MODEL_ID) || KNNEngine.NMSLIB == knnEngine) { + pair = KNNCodecUtil.readAllVectors(values); + } else { + pair = KNNCodecUtil.readVectorBatch(values); + } + int dim = pair.getDimension(); // dimension of the first batch; feeds the merge size stats below + // This will be cleaner once support for initIndexFromTemplate is added. if (field.attributes().containsKey(MODEL_ID)) { String modelId = field.attributes().get(MODEL_ID); Model model = ModelCache.getInstance().get(modelId); if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readAllVectors(values); createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); - totalDocsIncrement += pair.docs.length; - totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - } else { - if (KNNEngine.FAISS == knnEngine) { - KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readVectorBatch(values); - long indexAddress = initIndexFromScratch(field, num_docs, pair.getDimension(), knnEngine); - while (true) { - if (pair.docs.length != 0) insertToIndex(pair, knnEngine, indexAddress); - if (isMerge) { - totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - totalDocsIncrement += pair.docs.length; - } - if (pair.finished) 
{ - break; - } - pair = KNNCodecUtil.readVectorBatch(values); - } - writeIndex(indexAddress, indexPath, knnEngine); - } else { - // Note that iterative graph construction has not yet been implemented for nmslib - KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readAllVectors(values); - createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - totalDocsIncrement += pair.docs.length; - totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + } else if (KNNEngine.FAISS == knnEngine) { + long indexAddress = initIndexFromScratch(field, numDocs, pair.getDimension(), knnEngine); + insertToIndex(pair, knnEngine, indexAddress); + while (!pair.finished) { + pair = KNNCodecUtil.readVectorBatch(values); + insertToIndex(pair, knnEngine, indexAddress); } + writeIndex(indexAddress, indexPath, knnEngine); + } else { + // Note that iterative graph construction has not yet been implemented for nmslib + createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } if (isMerge) { - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(totalDocsIncrement); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(totalArraySize); - recordMergeStats((int) totalDocsIncrement, totalArraySize); + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(calculateArraySize((int)numDocs, dim, pair.serializationMode)); + recordMergeStats((int)numDocs, calculateArraySize((int)numDocs, dim, pair.serializationMode)); } if (isRefresh) { recordRefreshStats(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index fe8200375c..edae3de0d6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -237,7 +237,7 @@ public int advance(int target) throws IOException 
{ @Override public long cost() { - return 0; + return count; } }