Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce heap usage for knn index writers #13538

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;

/**
* Vectors' writer for a field
Expand All @@ -26,20 +29,11 @@
* @lucene.experimental
*/
public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
public abstract List<T> getVectors();

/**
The delegate to write to, can be null. When non-null, all vectors seen should be written to the
* delegate along with being written to the flat vectors.
*/
protected final KnnFieldVectorsWriter<T> indexingDelegate;
public abstract DocsWithFieldSet getDocsWithFieldSet();

/**
* Sole constructor that expects some indexingDelegate. All vectors seen should be written to the
* delegate along with being written to the flat vectors.
*
* @param indexingDelegate the delegate to write to, can be null
*/
protected FlatFieldVectorsWriter(KnnFieldVectorsWriter<T> indexingDelegate) {
this.indexingDelegate = indexingDelegate;
}
public abstract void finish() throws IOException;

public abstract boolean isFinished();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.lucene.codecs.hnsw;

import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
Expand Down Expand Up @@ -50,17 +49,11 @@ public FlatVectorsScorer getFlatVectorScorer() {
* writer can delegate to if additional indexing logic is required.
*
* @param fieldInfo fieldInfo of the field to add
* @param indexWriter the writer to delegate to, can be null
* @return a writer for the field
* @throws IOException if an I/O error occurs when adding the field
*/
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
public abstract FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;

/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
Expand Down Expand Up @@ -112,18 +111,12 @@ public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scor
}

@Override
public FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter<?> newField = FieldWriter.create(fieldInfo);
fields.add(newField);
return newField;
}

@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}

@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
Expand All @@ -132,6 +125,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
} else {
writeSortingField(field, maxDoc, sortMap);
}
field.finish();
}
}

Expand Down Expand Up @@ -421,22 +415,20 @@ private abstract static class FieldWriter<T> extends FlatFieldVectorsWriter<T> {
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private boolean finished;

private int lastDocID = -1;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
static FieldWriter<?> create(FieldInfo fieldInfo) {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<byte[]>(fieldInfo) {
@Override
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<float[]>(fieldInfo) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
Expand All @@ -445,8 +437,8 @@ public float[] copyValue(float[] value) {
};
}

FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
super(indexWriter);
FieldWriter(FieldInfo fieldInfo) {
super();
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
Expand All @@ -455,6 +447,9 @@ public float[] copyValue(float[] value) {

@Override
public void addValue(int docID, T vectorValue) throws IOException {
if (finished) {
throw new IllegalStateException("already finished, cannot add more values");
}
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
Expand All @@ -466,17 +461,11 @@ public void addValue(int docID, T vectorValue) throws IOException {
docsWithField.add(docID);
vectors.add(copy);
lastDocID = docID;
if (indexingDelegate != null) {
indexingDelegate.addValue(docID, copy);
}
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_RAM_BYTES_USED;
if (indexingDelegate != null) {
size += indexingDelegate.ramBytesUsed();
}
if (vectors.size() == 0) return size;
return size
+ docsWithField.ramBytesUsed()
Expand All @@ -486,6 +475,29 @@ public long ramBytesUsed() {
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize;
}

@Override
public List<T> getVectors() {
return vectors;
}

@Override
public DocsWithFieldSet getDocsWithFieldSet() {
return docsWithField;
}

@Override
public void finish() throws IOException {
if (finished) {
return;
}
this.finished = true;
}

@Override
public boolean isFinished() {
return finished;
}
}

static final class FlatCloseableRandomVectorScorerSupplier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
Expand Down Expand Up @@ -130,12 +132,13 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
FieldWriter<?> newField =
FieldWriter.create(
flatVectorWriter.getFlatVectorScorer(),
flatVectorWriter.addField(fieldInfo),
fieldInfo,
M,
beamWidth,
segmentWriteState.infoStream);
fields.add(newField);
return flatVectorWriter.addField(fieldInfo, newField);
return newField;
}

@Override
Expand Down Expand Up @@ -189,7 +192,7 @@ private void writeField(FieldWriter<?> fieldData) throws IOException {
fieldData.fieldInfo,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField.cardinality(),
fieldData.getDocsWithFieldSet().cardinality(),
graph,
graphLevelNodeOffsets);
}
Expand All @@ -198,7 +201,7 @@ private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
DocIdSetIterator iterator = fieldData.getDocsWithFieldSet().iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
Expand Down Expand Up @@ -230,7 +233,7 @@ private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
fieldData.fieldInfo,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField.cardinality(),
fieldData.getDocsWithFieldSet().cardinality(),
mockGraph,
graphLevelNodeOffsets);
}
Expand Down Expand Up @@ -542,42 +545,65 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);

private final FieldInfo fieldInfo;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FlatVectorsScorer scorer,
FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream)
throws IOException {
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<byte[]>(scorer, fieldInfo, M, beamWidth, infoStream);
case FLOAT32 -> new FieldWriter<float[]>(scorer, fieldInfo, M, beamWidth, infoStream);
case BYTE -> new FieldWriter<>(
scorer,
(FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter,
fieldInfo,
M,
beamWidth,
infoStream);
case FLOAT32 -> new FieldWriter<>(
scorer,
(FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
fieldInfo,
M,
beamWidth,
infoStream);
};
}

@SuppressWarnings("unchecked")
FieldWriter(
FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
FlatVectorsScorer scorer,
FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromBytes(
(List<byte[]>) vectors, fieldInfo.getVectorDimension()));
(List<byte[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
case FLOAT32 -> scorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
RandomAccessVectorValues.fromFloats(
(List<float[]>) vectors, fieldInfo.getVectorDimension()));
(List<float[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter);
}

@Override
Expand All @@ -588,20 +614,23 @@ public void addValue(int docID, T vectorValue) throws IOException {
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
vectors.add(vectorValue);
docsWithField.add(docID);
flatFieldVectorsWriter.addValue(docID, vectorValue);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}

public DocsWithFieldSet getDocsWithFieldSet() {
return flatFieldVectorsWriter.getDocsWithFieldSet();
}

@Override
public T copyValue(T vectorValue) {
throw new UnsupportedOperationException();
}

OnHeapHnswGraph getGraph() {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getGraph();
} else {
Expand All @@ -612,9 +641,7 @@ OnHeapHnswGraph getGraph() {
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE
+ docsWithField.ramBytesUsed()
+ (long) vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ flatFieldVectorsWriter.ramBytesUsed()
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
Expand Down
Loading