Skip to content

Commit

Permalink
Cleaned up addKNNBinaryField method. Will be less complex once initIn…
Browse files Browse the repository at this point in the history
…dexFromTemplate is added.

Signed-off-by: Andrew Klepchick <[email protected]>
  • Loading branch information
MrFlap committed Jul 17, 2024
1 parent 79e0fe3 commit c13925d
Showing 1 changed file with 24 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
return;
}
BinaryDocValues values = valuesProducer.getBinary(field);
long num_docs = KNNCodecUtil.getTotalLiveDocsCount(values);
long totalArraySize = 0;
long totalDocsIncrement = 0;
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
}
long numDocs = KNNCodecUtil.getTotalLiveDocsCount(values);
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
Expand All @@ -137,44 +132,39 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
KNNCodecUtil.VectorBatch pair;
// Get first pair
if(field.attributes().containsKey(MODEL_ID) || KNNEngine.NMSLIB == knnEngine) {
pair = KNNCodecUtil.readAllVectors(values);
} else {
pair = KNNCodecUtil.readVectorBatch(values);
}
int dim = 0;
// This will be cleaner once support for initIndexFromTemplate is added.
if (field.attributes().containsKey(MODEL_ID)) {
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readAllVectors(values);
createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
totalDocsIncrement += pair.docs.length;
totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
} else {
if (KNNEngine.FAISS == knnEngine) {
KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readVectorBatch(values);
long indexAddress = initIndexFromScratch(field, num_docs, pair.getDimension(), knnEngine);
while (true) {
if (pair.docs.length != 0) insertToIndex(pair, knnEngine, indexAddress);
if (isMerge) {
totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
totalDocsIncrement += pair.docs.length;
}
if (pair.finished) {
break;
}
pair = KNNCodecUtil.readVectorBatch(values);
}
writeIndex(indexAddress, indexPath, knnEngine);
} else {
// Note that iterative graph construction has not yet been implemented for nmslib
KNNCodecUtil.VectorBatch pair = KNNCodecUtil.readAllVectors(values);
createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
totalDocsIncrement += pair.docs.length;
totalArraySize += calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
} else if (KNNEngine.FAISS == knnEngine) {
long indexAddress = initIndexFromScratch(field, numDocs, pair.getDimension(), knnEngine);
insertToIndex(pair, knnEngine, indexAddress);
while (!pair.finished) {
pair = KNNCodecUtil.readVectorBatch(values);
insertToIndex(pair, knnEngine, indexAddress);
}
writeIndex(indexAddress, indexPath, knnEngine);
} else {
// Note that iterative graph construction has not yet been implemented for nmslib
createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(totalDocsIncrement);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(totalArraySize);
recordMergeStats((int) totalDocsIncrement, totalArraySize);
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(calculateArraySize((int)numDocs, dim, pair.serializationMode));
recordMergeStats((int)numDocs, calculateArraySize((int)numDocs, dim, pair.serializationMode));
}
if (isRefresh) {
recordRefreshStats();
Expand Down

0 comments on commit c13925d

Please sign in to comment.