Skip to content

Commit

Permalink
Reduce heap usage for knn index writers (#13538)
Browse files Browse the repository at this point in the history
* Reduce heap usage for knn index writers

* iter

* fixing heap usage & adding changes

* javadocs
  • Loading branch information
benwtrent authored Jul 10, 2024
1 parent 026d661 commit 428fdb5
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 127 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ Optimizations

* GITHUB#13175: Stop double-checking priority queue inserts in some FacetCount classes. (Jakub Slowinski)

* GITHUB#13538: Slightly reduce heap usage for HNSW and scalar quantized vector writers. (Ben Trent)

Changes in runtime behavior
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;

/**
* Vectors' writer for a field
Expand All @@ -26,20 +29,25 @@
* @lucene.experimental
*/
public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
  // NOTE(review): this span came from a diff rendering that lost its +/- markers, interleaving the
  // removed indexing-delegate field/constructor with the added abstract accessors. Reconstructed
  // here as the post-commit version (delegate removed, accessors kept).

  /**
   * @return a list of vectors to be written
   */
  public abstract List<T> getVectors();

  /**
   * @return the docsWithFieldSet for the field writer
   */
  public abstract DocsWithFieldSet getDocsWithFieldSet();

  /**
   * indicates that this writer is done and no new vectors are allowed to be added
   *
   * @throws IOException if an I/O error occurs
   */
  public abstract void finish() throws IOException;

  /**
   * @return true if the writer is done and no new vectors are allowed to be added
   */
  public abstract boolean isFinished();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
Expand Down Expand Up @@ -46,21 +45,14 @@ public FlatVectorsScorer getFlatVectorScorer() {
}

/**
* Add a new field for indexing, allowing the user to provide a writer that the flat vectors
* writer can delegate to if additional indexing logic is required.
* Add a new field for indexing
*
* @param fieldInfo fieldInfo of the field to add
* @param indexWriter the writer to delegate to, can be null
* @return a writer for the field
* @throws IOException if an I/O error occurs when adding the field
*/
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
public abstract FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;

/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
Expand Down Expand Up @@ -111,18 +110,12 @@ public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scor
}

@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo);
fields.add(newField);
return newField;
}

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}

@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
Expand All @@ -131,6 +124,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
} else {
writeSortingField(field, maxDoc, sortMap);
}
field.finish();
}
}

Expand Down Expand Up @@ -403,22 +397,20 @@ private abstract static class FieldWriter<T> extends FlatFieldVectorsWriter<T> {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private boolean finished;

private int lastDocID = -1;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
static FieldWriter<?> create(FieldInfo fieldInfo) {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<byte[]>(fieldInfo) {
@Override
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<float[]>(fieldInfo) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
Expand All @@ -427,8 +419,8 @@ public float[] copyValue(float[] value) {
};
}

FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
super(indexWriter);
FieldWriter(FieldInfo fieldInfo) {
super();
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
Expand All @@ -437,6 +429,9 @@ public float[] copyValue(float[] value) {

@Override
public void addValue(int docID, T vectorValue) throws IOException {
if (finished) {
throw new IllegalStateException("already finished, cannot add more values");
}
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
Expand All @@ -448,17 +443,11 @@ public void addValue(int docID, T vectorValue) throws IOException {
docsWithField.add(docID);
vectors.add(copy);
lastDocID = docID;
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, copy);
}
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_RAM_BYTES_USED;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (vectors.size() == 0) return size;
return size
+ docsWithField.ramBytesUsed()
Expand All @@ -468,6 +457,29 @@ public long ramBytesUsed() {
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize;
}

    @Override
    public List<T> getVectors() {
      // Returns the live backing list (not a copy) of vectors buffered on heap for this field.
      return vectors;
    }

    @Override
    public DocsWithFieldSet getDocsWithFieldSet() {
      // Returns the live set tracking which docIDs have had a vector added via addValue.
      return docsWithField;
    }

@Override
public void finish() throws IOException {
if (finished) {
return;
}
this.finished = true;
}

    @Override
    public boolean isFinished() {
      // True once finish() has sealed this writer against further additions.
      return finished;
    }
}

static final class FlatCloseableRandomVectorScorerSupplier
Expand Down
Loading

0 comments on commit 428fdb5

Please sign in to comment.